sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +18 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,7 +1,7 @@
1
1
  import io
2
2
  from collections.abc import AsyncGenerator, Generator
3
3
  from contextlib import asynccontextmanager, contextmanager
4
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
4
+ from typing import TYPE_CHECKING, Any, Optional, cast
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from psycopg.abc import Query
@@ -11,6 +11,7 @@ from psycopg.rows import DictRow as PsycopgDictRow
11
11
  from sqlglot.dialects.dialect import DialectType
12
12
 
13
13
  from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
14
+ from sqlspec.driver.connection import managed_transaction_async, managed_transaction_sync
14
15
  from sqlspec.driver.mixins import (
15
16
  AsyncPipelinedExecutionMixin,
16
17
  AsyncStorageMixin,
@@ -20,12 +21,13 @@ from sqlspec.driver.mixins import (
20
21
  ToSchemaMixin,
21
22
  TypeCoercionMixin,
22
23
  )
24
+ from sqlspec.driver.parameters import normalize_parameter_sequence
23
25
  from sqlspec.exceptions import PipelineExecutionError
24
- from sqlspec.statement.parameters import ParameterStyle
25
- from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
26
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
27
+ from sqlspec.statement.result import ArrowResult, SQLResult
26
28
  from sqlspec.statement.splitter import split_sql_script
27
29
  from sqlspec.statement.sql import SQL, SQLConfig
28
- from sqlspec.typing import DictRow, ModelDTOT, RowT, is_dict_with_field
30
+ from sqlspec.typing import DictRow, RowT
29
31
  from sqlspec.utils.logging import get_logger
30
32
 
31
33
  if TYPE_CHECKING:
@@ -73,12 +75,18 @@ class PsycopgSyncDriver(
73
75
 
74
76
  def _execute_statement(
75
77
  self, statement: SQL, connection: Optional[PsycopgSyncConnection] = None, **kwargs: Any
76
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
78
+ ) -> SQLResult[RowT]:
77
79
  if statement.is_script:
78
80
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
79
81
  return self._execute_script(sql, connection=connection, **kwargs)
80
82
 
81
- detected_styles = {p.style for p in statement.parameter_info}
83
+ detected_styles = set()
84
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
85
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
86
+ param_infos = validator.extract_parameters(sql_str)
87
+ if param_infos:
88
+ detected_styles = {p.style for p in param_infos}
89
+
82
90
  target_style = self.default_parameter_style
83
91
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
84
92
  if unsupported_styles:
@@ -90,20 +98,39 @@ class PsycopgSyncDriver(
90
98
  break
91
99
 
92
100
  if statement.is_many:
93
- sql, params = statement.compile(placeholder_style=target_style)
94
- # For execute_many, check if parameters were passed via kwargs (legacy support)
95
- # Otherwise use the parameters from the SQL object
101
+ # Check if parameters were provided in kwargs first
96
102
  kwargs_params = kwargs.get("parameters")
97
103
  if kwargs_params is not None:
104
+ # Use the SQL string directly if parameters come from kwargs
105
+ sql = statement.to_sql(placeholder_style=target_style)
98
106
  params = kwargs_params
107
+ else:
108
+ sql, params = statement.compile(placeholder_style=target_style)
99
109
  if params is not None:
100
110
  processed_params = [self._process_parameters(param_set) for param_set in params]
101
111
  params = processed_params
102
- return self._execute_many(sql, params, connection=connection, **kwargs)
103
-
104
- sql, params = statement.compile(placeholder_style=target_style)
112
+ # Remove 'parameters' from kwargs to avoid conflicts in _execute_many method signature
113
+ exec_kwargs = {k: v for k, v in kwargs.items() if k != "parameters"}
114
+ return self._execute_many(sql, params, connection=connection, **exec_kwargs)
115
+
116
+ # Check if parameters were provided in kwargs (user-provided parameters)
117
+ kwargs_params = kwargs.get("parameters")
118
+ if kwargs_params is not None:
119
+ # Use the SQL string directly if parameters come from kwargs
120
+ sql = statement.to_sql(placeholder_style=target_style)
121
+ params = kwargs_params
122
+ else:
123
+ sql, params = statement.compile(placeholder_style=target_style)
105
124
  params = self._process_parameters(params)
106
- return self._execute(sql, params, statement, connection=connection, **kwargs)
125
+
126
+ # Fix over-nested parameters for Psycopg
127
+ # If params is a tuple containing a single tuple or dict, flatten it
128
+ if isinstance(params, tuple) and len(params) == 1 and isinstance(params[0], (tuple, dict, list)):
129
+ params = params[0]
130
+
131
+ # Remove 'parameters' from kwargs to avoid conflicts in _execute method signature
132
+ exec_kwargs = {k: v for k, v in kwargs.items() if k != "parameters"}
133
+ return self._execute(sql, params, statement, connection=connection, **exec_kwargs)
107
134
 
108
135
  def _execute(
109
136
  self,
@@ -112,30 +139,51 @@ class PsycopgSyncDriver(
112
139
  statement: SQL,
113
140
  connection: Optional[PsycopgSyncConnection] = None,
114
141
  **kwargs: Any,
115
- ) -> Union[SelectResultDict, DMLResultDict]:
116
- conn = self._connection(connection)
142
+ ) -> SQLResult[RowT]:
143
+ # Use provided connection or driver's default connection
144
+ conn = connection if connection is not None else self._connection(None)
117
145
 
118
- # Check if this is a COPY command
146
+ # Handle COPY commands separately (they don't use transactions)
119
147
  sql_upper = sql.strip().upper()
120
148
  if sql_upper.startswith("COPY") and ("FROM STDIN" in sql_upper or "TO STDOUT" in sql_upper):
121
149
  return self._handle_copy_command(sql, parameters, conn)
122
150
 
123
- with conn.cursor() as cursor:
124
- cursor.execute(cast("Query", sql), parameters)
125
- # Check if the statement returns rows by checking cursor.description
126
- # This is more reliable than parsing when parsing is disabled
127
- if cursor.description is not None:
128
- fetched_data = cursor.fetchall()
129
- column_names = [col.name for col in cursor.description]
130
- return {"data": fetched_data, "column_names": column_names, "rows_affected": len(fetched_data)}
131
- return {"rows_affected": cursor.rowcount, "status_message": cursor.statusmessage or "OK"}
132
-
133
- def _handle_copy_command(
134
- self, sql: str, data: Any, connection: PsycopgSyncConnection
135
- ) -> Union[SelectResultDict, DMLResultDict]:
151
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
152
+ # For Psycopg, pass parameters directly to the driver
153
+ final_params = parameters
154
+
155
+ # Debug logging
156
+ logger.debug("Executing SQL: %r with parameters: %r", sql, final_params)
157
+
158
+ with txn_conn.cursor() as cursor:
159
+ cursor.execute(cast("Query", sql), final_params)
160
+ if cursor.description is not None:
161
+ fetched_data = cursor.fetchall()
162
+ column_names = [col.name for col in cursor.description]
163
+ return SQLResult(
164
+ statement=statement,
165
+ data=cast("list[RowT]", fetched_data),
166
+ column_names=column_names,
167
+ rows_affected=len(fetched_data),
168
+ operation_type="SELECT",
169
+ )
170
+ operation_type = self._determine_operation_type(statement)
171
+ return SQLResult(
172
+ statement=statement,
173
+ data=[],
174
+ rows_affected=cursor.rowcount or 0,
175
+ operation_type=operation_type,
176
+ metadata={"status_message": cursor.statusmessage or "OK"},
177
+ )
178
+
179
+ def _handle_copy_command(self, sql: str, data: Any, connection: PsycopgSyncConnection) -> SQLResult[RowT]:
136
180
  """Handle PostgreSQL COPY commands using cursor.copy() method."""
137
181
  sql_upper = sql.strip().upper()
138
182
 
183
+ # Handle case where data is wrapped in a single-element tuple (from positional args)
184
+ if isinstance(data, tuple) and len(data) == 1:
185
+ data = data[0]
186
+
139
187
  with connection.cursor() as cursor:
140
188
  if "TO STDOUT" in sql_upper:
141
189
  # COPY TO STDOUT - read data from the database
@@ -143,13 +191,20 @@ class PsycopgSyncDriver(
143
191
  with cursor.copy(cast("Query", sql)) as copy:
144
192
  output_data.extend(row for row in copy)
145
193
 
146
- # Return as SelectResultDict with the raw COPY data
147
- return {"data": output_data, "column_names": ["copy_data"], "rows_affected": len(output_data)}
194
+ return SQLResult(
195
+ statement=SQL(sql, _dialect=self.dialect),
196
+ data=cast("list[RowT]", output_data),
197
+ column_names=["copy_data"],
198
+ rows_affected=len(output_data),
199
+ operation_type="SELECT",
200
+ )
148
201
  # COPY FROM STDIN - write data to the database
149
202
  with cursor.copy(cast("Query", sql)) as copy:
150
203
  if data:
151
204
  # If data is provided, write it to the copy stream
152
- if isinstance(data, (str, bytes)):
205
+ if isinstance(data, str):
206
+ copy.write(data.encode("utf-8"))
207
+ elif isinstance(data, bytes):
153
208
  copy.write(data)
154
209
  elif isinstance(data, (list, tuple)):
155
210
  # If data is a list/tuple of rows, write each row
@@ -160,33 +215,57 @@ class PsycopgSyncDriver(
160
215
  copy.write_row(data)
161
216
 
162
217
  # For COPY operations, cursor.rowcount contains the number of rows affected
163
- return {"rows_affected": cursor.rowcount or -1, "status_message": cursor.statusmessage or "COPY COMPLETE"}
218
+ return SQLResult(
219
+ statement=SQL(sql, _dialect=self.dialect),
220
+ data=[],
221
+ rows_affected=cursor.rowcount or -1,
222
+ operation_type="EXECUTE",
223
+ metadata={"status_message": cursor.statusmessage or "COPY COMPLETE"},
224
+ )
164
225
 
165
226
  def _execute_many(
166
227
  self, sql: str, param_list: Any, connection: Optional[PsycopgSyncConnection] = None, **kwargs: Any
167
- ) -> DMLResultDict:
168
- conn = self._connection(connection)
169
- with self._get_cursor(conn) as cursor:
170
- cursor.executemany(sql, param_list or [])
171
- # psycopg's executemany might return -1 or 0 for rowcount
172
- # In that case, use the length of param_list for DML operations
173
- rows_affected = cursor.rowcount
174
- if rows_affected <= 0 and param_list:
175
- rows_affected = len(param_list)
176
- result: DMLResultDict = {"rows_affected": rows_affected, "status_message": cursor.statusmessage or "OK"}
177
- return result
228
+ ) -> SQLResult[RowT]:
229
+ # Use provided connection or driver's default connection
230
+ conn = connection if connection is not None else self._connection(None)
231
+
232
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
233
+ # Normalize parameter list using consolidated utility
234
+ normalized_param_list = normalize_parameter_sequence(param_list)
235
+ final_param_list = normalized_param_list or []
236
+
237
+ with self._get_cursor(txn_conn) as cursor:
238
+ cursor.executemany(sql, final_param_list)
239
+ # psycopg's executemany might return -1 or 0 for rowcount
240
+ # In that case, use the length of param_list for DML operations
241
+ rows_affected = cursor.rowcount
242
+ if rows_affected <= 0 and final_param_list:
243
+ rows_affected = len(final_param_list)
244
+ return SQLResult(
245
+ statement=SQL(sql, _dialect=self.dialect),
246
+ data=[],
247
+ rows_affected=rows_affected,
248
+ operation_type="EXECUTE",
249
+ metadata={"status_message": cursor.statusmessage or "OK"},
250
+ )
178
251
 
179
252
  def _execute_script(
180
253
  self, script: str, connection: Optional[PsycopgSyncConnection] = None, **kwargs: Any
181
- ) -> ScriptResultDict:
182
- conn = self._connection(connection)
183
- with self._get_cursor(conn) as cursor:
254
+ ) -> SQLResult[RowT]:
255
+ # Use provided connection or driver's default connection
256
+ conn = connection if connection is not None else self._connection(None)
257
+
258
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
184
259
  cursor.execute(script)
185
- result: ScriptResultDict = {
186
- "statements_executed": -1,
187
- "status_message": cursor.statusmessage or "SCRIPT EXECUTED",
188
- }
189
- return result
260
+ return SQLResult(
261
+ statement=SQL(script, _dialect=self.dialect).as_script(),
262
+ data=[],
263
+ rows_affected=0,
264
+ operation_type="SCRIPT",
265
+ metadata={"status_message": cursor.statusmessage or "SCRIPT EXECUTED"},
266
+ total_statements=1,
267
+ successful_statements=1,
268
+ )
190
269
 
191
270
  def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
192
271
  self._ensure_pyarrow_installed()
@@ -209,61 +288,6 @@ class PsycopgSyncDriver(
209
288
 
210
289
  return cursor.rowcount if cursor.rowcount is not None else -1
211
290
 
212
- def _wrap_select_result(
213
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
214
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
215
- rows_as_dicts: list[dict[str, Any]] = [dict(row) for row in result["data"]]
216
-
217
- if schema_type:
218
- return SQLResult[ModelDTOT](
219
- statement=statement,
220
- data=list(self.to_schema(data=result["data"], schema_type=schema_type)),
221
- column_names=result["column_names"],
222
- rows_affected=result["rows_affected"],
223
- operation_type="SELECT",
224
- )
225
- return SQLResult[RowT](
226
- statement=statement,
227
- data=rows_as_dicts,
228
- column_names=result["column_names"],
229
- rows_affected=result["rows_affected"],
230
- operation_type="SELECT",
231
- )
232
-
233
- def _wrap_execute_result(
234
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
235
- ) -> SQLResult[RowT]:
236
- operation_type = "UNKNOWN"
237
- if statement.expression:
238
- operation_type = str(statement.expression.key).upper()
239
-
240
- # Handle case where we got a SelectResultDict but it was routed here due to parsing being disabled
241
- if is_dict_with_field(result, "data") and is_dict_with_field(result, "column_names"):
242
- # This is actually a SELECT result, wrap it properly
243
- return self._wrap_select_result(statement, cast("SelectResultDict", result), **kwargs)
244
-
245
- if is_dict_with_field(result, "statements_executed"):
246
- return SQLResult[RowT](
247
- statement=statement,
248
- data=[],
249
- rows_affected=0,
250
- operation_type="SCRIPT",
251
- metadata={"status_message": result.get("status_message", "")},
252
- )
253
-
254
- if is_dict_with_field(result, "rows_affected"):
255
- return SQLResult[RowT](
256
- statement=statement,
257
- data=[],
258
- rows_affected=cast("int", result.get("rows_affected", -1)),
259
- operation_type=operation_type,
260
- metadata={"status_message": result.get("status_message", "")},
261
- )
262
-
263
- # This shouldn't happen with TypedDict approach
264
- msg = f"Unexpected result type: {type(result)}"
265
- raise ValueError(msg)
266
-
267
291
  def _connection(self, connection: Optional[PsycopgSyncConnection] = None) -> PsycopgSyncConnection:
268
292
  """Get the connection to use for the operation."""
269
293
  return connection or self.connection
@@ -306,7 +330,6 @@ class PsycopgSyncDriver(
306
330
  from sqlspec.exceptions import PipelineExecutionError
307
331
 
308
332
  try:
309
- # Prepare SQL and parameters
310
333
  filtered_sql = self._apply_operation_filters(operation.sql, operation.filters)
311
334
  sql_str = filtered_sql.to_sql(placeholder_style=self.default_parameter_style)
312
335
  params = self._convert_psycopg_params(filtered_sql.parameters)
@@ -355,7 +378,7 @@ class PsycopgSyncDriver(
355
378
  statement=sql,
356
379
  data=cast("list[RowT]", []),
357
380
  rows_affected=cursor.rowcount,
358
- operation_type="execute_many",
381
+ operation_type="EXECUTE",
359
382
  metadata={"status_message": "OK"},
360
383
  )
361
384
 
@@ -370,7 +393,7 @@ class PsycopgSyncDriver(
370
393
  statement=sql,
371
394
  data=cast("list[RowT]", data),
372
395
  rows_affected=len(data),
373
- operation_type="select",
396
+ operation_type="SELECT",
374
397
  metadata={"column_names": column_names},
375
398
  )
376
399
 
@@ -391,7 +414,7 @@ class PsycopgSyncDriver(
391
414
  statement=sql,
392
415
  data=cast("list[RowT]", []),
393
416
  rows_affected=total_affected,
394
- operation_type="execute_script",
417
+ operation_type="SCRIPT",
395
418
  metadata={"status_message": "SCRIPT EXECUTED", "statements_executed": len(script_statements)},
396
419
  )
397
420
 
@@ -403,7 +426,7 @@ class PsycopgSyncDriver(
403
426
  statement=sql,
404
427
  data=cast("list[RowT]", []),
405
428
  rows_affected=cursor.rowcount or 0,
406
- operation_type="execute",
429
+ operation_type="EXECUTE",
407
430
  metadata={"status_message": "OK"},
408
431
  )
409
432
 
@@ -424,7 +447,6 @@ class PsycopgSyncDriver(
424
447
  # Psycopg handles dict parameters directly for named placeholders
425
448
  return params
426
449
  if isinstance(params, (list, tuple)):
427
- # Convert to tuple for positional parameters
428
450
  return tuple(params)
429
451
  # Single parameter
430
452
  return (params,)
@@ -444,7 +466,6 @@ class PsycopgSyncDriver(
444
466
  def _split_script_statements(self, script: str, strip_trailing_semicolon: bool = False) -> "list[str]":
445
467
  """Split a SQL script into individual statements."""
446
468
 
447
- # Use the sophisticated splitter with PostgreSQL dialect
448
469
  return split_sql_script(script=script, dialect="postgresql", strip_trailing_semicolon=strip_trailing_semicolon)
449
470
 
450
471
 
@@ -482,22 +503,24 @@ class PsycopgAsyncDriver(
482
503
 
483
504
  async def _execute_statement(
484
505
  self, statement: SQL, connection: Optional[PsycopgAsyncConnection] = None, **kwargs: Any
485
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
506
+ ) -> SQLResult[RowT]:
486
507
  if statement.is_script:
487
508
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
488
509
  return await self._execute_script(sql, connection=connection, **kwargs)
489
510
 
490
- # Determine if we need to convert parameter style
491
- detected_styles = {p.style for p in statement.parameter_info}
511
+ detected_styles = set()
512
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
513
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
514
+ param_infos = validator.extract_parameters(sql_str)
515
+ if param_infos:
516
+ detected_styles = {p.style for p in param_infos}
517
+
492
518
  target_style = self.default_parameter_style
493
519
 
494
- # Check if any detected style is not supported
495
520
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
496
521
  if unsupported_styles:
497
- # Convert to default style if we have unsupported styles
498
522
  target_style = self.default_parameter_style
499
523
  elif detected_styles:
500
- # Use the first detected style if all are supported
501
524
  # Prefer the first supported style found
502
525
  for style in detected_styles:
503
526
  if style in self.supported_parameter_styles:
@@ -505,18 +528,49 @@ class PsycopgAsyncDriver(
505
528
  break
506
529
 
507
530
  if statement.is_many:
508
- sql, _ = statement.compile(placeholder_style=target_style)
509
- # For execute_many, use the parameters passed via kwargs
510
- params = kwargs.get("parameters")
531
+ # Check if parameters were provided in kwargs first
532
+ kwargs_params = kwargs.get("parameters")
533
+ if kwargs_params is not None:
534
+ # Use the SQL string directly if parameters come from kwargs
535
+ sql = statement.to_sql(placeholder_style=target_style)
536
+ params = kwargs_params
537
+ else:
538
+ sql, _ = statement.compile(placeholder_style=target_style)
539
+ params = statement.parameters
511
540
  if params is not None:
512
- # Process each parameter set individually
513
541
  processed_params = [self._process_parameters(param_set) for param_set in params]
514
542
  params = processed_params
515
- return await self._execute_many(sql, params, connection=connection, **kwargs)
516
543
 
517
- sql, params = statement.compile(placeholder_style=target_style)
544
+ # Fix over-nested parameters for each param set
545
+ fixed_params = []
546
+ for param_set in params:
547
+ if isinstance(param_set, tuple) and len(param_set) == 1:
548
+ fixed_params.append(param_set[0])
549
+ else:
550
+ fixed_params.append(param_set)
551
+ params = fixed_params
552
+ # Remove 'parameters' from kwargs to avoid conflicts in _execute_many method signature
553
+ exec_kwargs = {k: v for k, v in kwargs.items() if k != "parameters"}
554
+ return await self._execute_many(sql, params, connection=connection, **exec_kwargs)
555
+
556
+ # Check if parameters were provided in kwargs (user-provided parameters)
557
+ kwargs_params = kwargs.get("parameters")
558
+ if kwargs_params is not None:
559
+ # Use the SQL string directly if parameters come from kwargs
560
+ sql = statement.to_sql(placeholder_style=target_style)
561
+ params = kwargs_params
562
+ else:
563
+ sql, params = statement.compile(placeholder_style=target_style)
518
564
  params = self._process_parameters(params)
519
- return await self._execute(sql, params, statement, connection=connection, **kwargs)
565
+
566
+ # Fix over-nested parameters for Psycopg
567
+ # If params is a tuple containing a single tuple or dict, flatten it
568
+ if isinstance(params, tuple) and len(params) == 1 and isinstance(params[0], (tuple, dict, list)):
569
+ params = params[0]
570
+
571
+ # Remove 'parameters' from kwargs to avoid conflicts in _execute method signature
572
+ exec_kwargs = {k: v for k, v in kwargs.items() if k != "parameters"}
573
+ return await self._execute(sql, params, statement, connection=connection, **exec_kwargs)
520
574
 
521
575
  async def _execute(
522
576
  self,
@@ -525,41 +579,63 @@ class PsycopgAsyncDriver(
525
579
  statement: SQL,
526
580
  connection: Optional[PsycopgAsyncConnection] = None,
527
581
  **kwargs: Any,
528
- ) -> Union[SelectResultDict, DMLResultDict]:
529
- conn = self._connection(connection)
582
+ ) -> SQLResult[RowT]:
583
+ # Use provided connection or driver's default connection
584
+ conn = connection if connection is not None else self._connection(None)
530
585
 
531
- # Check if this is a COPY command
586
+ # Handle COPY commands separately (they don't use transactions)
532
587
  sql_upper = sql.strip().upper()
533
588
  if sql_upper.startswith("COPY") and ("FROM STDIN" in sql_upper or "TO STDOUT" in sql_upper):
534
589
  return await self._handle_copy_command(sql, parameters, conn)
535
590
 
536
- async with conn.cursor() as cursor:
537
- await cursor.execute(cast("Query", sql), parameters)
538
-
539
- # When parsing is disabled, expression will be None, so check SQL directly
540
- if statement.expression and self.returns_rows(statement.expression):
541
- # For SELECT statements, extract data while cursor is open
542
- fetched_data = await cursor.fetchall()
543
- column_names = [col.name for col in cursor.description or []]
544
- return {"data": fetched_data, "column_names": column_names, "rows_affected": len(fetched_data)}
545
- if not statement.expression and sql.strip().upper().startswith("SELECT"):
546
- # For SELECT statements when parsing is disabled
547
- fetched_data = await cursor.fetchall()
548
- column_names = [col.name for col in cursor.description or []]
549
- return {"data": fetched_data, "column_names": column_names, "rows_affected": len(fetched_data)}
550
- # For DML statements
551
- dml_result: DMLResultDict = {
552
- "rows_affected": cursor.rowcount,
553
- "status_message": cursor.statusmessage or "OK",
554
- }
555
- return dml_result
556
-
557
- async def _handle_copy_command(
558
- self, sql: str, data: Any, connection: PsycopgAsyncConnection
559
- ) -> Union[SelectResultDict, DMLResultDict]:
591
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
592
+ # For Psycopg, pass parameters directly to the driver
593
+ final_params = parameters
594
+
595
+ async with txn_conn.cursor() as cursor:
596
+ await cursor.execute(cast("Query", sql), final_params)
597
+
598
+ # When parsing is disabled, expression will be None, so check SQL directly
599
+ if statement.expression and self.returns_rows(statement.expression):
600
+ # For SELECT statements, extract data while cursor is open
601
+ fetched_data = await cursor.fetchall()
602
+ column_names = [col.name for col in cursor.description or []]
603
+ return SQLResult(
604
+ statement=statement,
605
+ data=cast("list[RowT]", fetched_data),
606
+ column_names=column_names,
607
+ rows_affected=len(fetched_data),
608
+ operation_type="SELECT",
609
+ )
610
+ if not statement.expression and sql.strip().upper().startswith("SELECT"):
611
+ # For SELECT statements when parsing is disabled
612
+ fetched_data = await cursor.fetchall()
613
+ column_names = [col.name for col in cursor.description or []]
614
+ return SQLResult(
615
+ statement=statement,
616
+ data=cast("list[RowT]", fetched_data),
617
+ column_names=column_names,
618
+ rows_affected=len(fetched_data),
619
+ operation_type="SELECT",
620
+ )
621
+ # For DML statements
622
+ operation_type = self._determine_operation_type(statement)
623
+ return SQLResult(
624
+ statement=statement,
625
+ data=[],
626
+ rows_affected=cursor.rowcount or 0,
627
+ operation_type=operation_type,
628
+ metadata={"status_message": cursor.statusmessage or "OK"},
629
+ )
630
+
631
+ async def _handle_copy_command(self, sql: str, data: Any, connection: PsycopgAsyncConnection) -> SQLResult[RowT]:
560
632
  """Handle PostgreSQL COPY commands using cursor.copy() method."""
561
633
  sql_upper = sql.strip().upper()
562
634
 
635
+ # Handle case where data is wrapped in a single-element tuple (from positional args)
636
+ if isinstance(data, tuple) and len(data) == 1:
637
+ data = data[0]
638
+
563
639
  async with connection.cursor() as cursor:
564
640
  if "TO STDOUT" in sql_upper:
565
641
  # COPY TO STDOUT - read data from the database
@@ -567,13 +643,20 @@ class PsycopgAsyncDriver(
567
643
  async with cursor.copy(cast("Query", sql)) as copy:
568
644
  output_data.extend([row async for row in copy])
569
645
 
570
- # Return as SelectResultDict with the raw COPY data
571
- return {"data": output_data, "column_names": ["copy_data"], "rows_affected": len(output_data)}
646
+ return SQLResult(
647
+ statement=SQL(sql, _dialect=self.dialect),
648
+ data=cast("list[RowT]", output_data),
649
+ column_names=["copy_data"],
650
+ rows_affected=len(output_data),
651
+ operation_type="SELECT",
652
+ )
572
653
  # COPY FROM STDIN - write data to the database
573
654
  async with cursor.copy(cast("Query", sql)) as copy:
574
655
  if data:
575
656
  # If data is provided, write it to the copy stream
576
- if isinstance(data, (str, bytes)):
657
+ if isinstance(data, str):
658
+ await copy.write(data.encode("utf-8"))
659
+ elif isinstance(data, bytes):
577
660
  await copy.write(data)
578
661
  elif isinstance(data, (list, tuple)):
579
662
  # If data is a list/tuple of rows, write each row
@@ -584,27 +667,52 @@ class PsycopgAsyncDriver(
584
667
  await copy.write_row(data)
585
668
 
586
669
  # For COPY operations, cursor.rowcount contains the number of rows affected
587
- return {"rows_affected": cursor.rowcount or -1, "status_message": cursor.statusmessage or "COPY COMPLETE"}
670
+ return SQLResult(
671
+ statement=SQL(sql, _dialect=self.dialect),
672
+ data=[],
673
+ rows_affected=cursor.rowcount or -1,
674
+ operation_type="EXECUTE",
675
+ metadata={"status_message": cursor.statusmessage or "COPY COMPLETE"},
676
+ )
588
677
 
589
678
  async def _execute_many(
590
679
  self, sql: str, param_list: Any, connection: Optional[PsycopgAsyncConnection] = None, **kwargs: Any
591
- ) -> DMLResultDict:
592
- conn = self._connection(connection)
593
- async with conn.cursor() as cursor:
594
- await cursor.executemany(cast("Query", sql), param_list or [])
595
- return {"rows_affected": cursor.rowcount, "status_message": cursor.statusmessage or "OK"}
680
+ ) -> SQLResult[RowT]:
681
+ # Use provided connection or driver's default connection
682
+ conn = connection if connection is not None else self._connection(None)
683
+
684
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
685
+ # Normalize parameter list using consolidated utility
686
+ normalized_param_list = normalize_parameter_sequence(param_list)
687
+ final_param_list = normalized_param_list or []
688
+
689
+ async with txn_conn.cursor() as cursor:
690
+ await cursor.executemany(cast("Query", sql), final_param_list)
691
+ return SQLResult(
692
+ statement=SQL(sql, _dialect=self.dialect),
693
+ data=[],
694
+ rows_affected=cursor.rowcount,
695
+ operation_type="EXECUTE",
696
+ metadata={"status_message": cursor.statusmessage or "OK"},
697
+ )
596
698
 
597
699
  async def _execute_script(
598
700
  self, script: str, connection: Optional[PsycopgAsyncConnection] = None, **kwargs: Any
599
- ) -> ScriptResultDict:
600
- conn = self._connection(connection)
601
- async with conn.cursor() as cursor:
701
+ ) -> SQLResult[RowT]:
702
+ # Use provided connection or driver's default connection
703
+ conn = connection if connection is not None else self._connection(None)
704
+
705
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn, txn_conn.cursor() as cursor:
602
706
  await cursor.execute(cast("Query", script))
603
- # For scripts, return script result format
604
- return {
605
- "statements_executed": -1, # Psycopg doesn't provide this info
606
- "status_message": cursor.statusmessage or "SCRIPT EXECUTED",
607
- }
707
+ return SQLResult(
708
+ statement=SQL(script, _dialect=self.dialect).as_script(),
709
+ data=[],
710
+ rows_affected=0,
711
+ operation_type="SCRIPT",
712
+ metadata={"status_message": cursor.statusmessage or "SCRIPT EXECUTED"},
713
+ total_statements=1,
714
+ successful_statements=1,
715
+ )
608
716
 
609
717
  async def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
610
718
  self._ensure_pyarrow_installed()
@@ -639,64 +747,6 @@ class PsycopgAsyncDriver(
639
747
 
640
748
  return cursor.rowcount if cursor.rowcount is not None else -1
641
749
 
642
- async def _wrap_select_result(
643
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
644
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
645
- # result must be a dict with keys: data, column_names, rows_affected
646
- fetched_data = result["data"]
647
- column_names = result["column_names"]
648
- rows_affected = result["rows_affected"]
649
- rows_as_dicts: list[dict[str, Any]] = [dict(row) for row in fetched_data]
650
-
651
- if schema_type:
652
- return SQLResult[ModelDTOT](
653
- statement=statement,
654
- data=list(self.to_schema(data=fetched_data, schema_type=schema_type)),
655
- column_names=column_names,
656
- rows_affected=rows_affected,
657
- operation_type="SELECT",
658
- )
659
- return SQLResult[RowT](
660
- statement=statement,
661
- data=rows_as_dicts,
662
- column_names=column_names,
663
- rows_affected=rows_affected,
664
- operation_type="SELECT",
665
- )
666
-
667
- async def _wrap_execute_result(
668
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
669
- ) -> SQLResult[RowT]:
670
- operation_type = "UNKNOWN"
671
- if statement.expression:
672
- operation_type = str(statement.expression.key).upper()
673
-
674
- # Handle case where we got a SelectResultDict but it was routed here due to parsing being disabled
675
- if is_dict_with_field(result, "data") and is_dict_with_field(result, "column_names"):
676
- # This is actually a SELECT result, wrap it properly
677
- return await self._wrap_select_result(statement, cast("SelectResultDict", result), **kwargs)
678
-
679
- if is_dict_with_field(result, "statements_executed"):
680
- return SQLResult[RowT](
681
- statement=statement,
682
- data=[],
683
- rows_affected=0,
684
- operation_type="SCRIPT",
685
- metadata={"status_message": result.get("status_message", "")},
686
- )
687
-
688
- if is_dict_with_field(result, "rows_affected"):
689
- return SQLResult[RowT](
690
- statement=statement,
691
- data=[],
692
- rows_affected=cast("int", result.get("rows_affected", -1)),
693
- operation_type=operation_type,
694
- metadata={"status_message": result.get("status_message", "")},
695
- )
696
- # This shouldn't happen with TypedDict approach
697
- msg = f"Unexpected result type: {type(result)}"
698
- raise ValueError(msg)
699
-
700
750
  def _connection(self, connection: Optional[PsycopgAsyncConnection] = None) -> PsycopgAsyncConnection:
701
751
  """Get the connection to use for the operation."""
702
752
  return connection or self.connection
@@ -729,7 +779,6 @@ class PsycopgAsyncDriver(
729
779
  from sqlspec.exceptions import PipelineExecutionError
730
780
 
731
781
  try:
732
- # Prepare SQL and parameters
733
782
  filtered_sql = self._apply_operation_filters(operation.sql, operation.filters)
734
783
  sql_str = filtered_sql.to_sql(placeholder_style=self.default_parameter_style)
735
784
  params = self._convert_psycopg_params(filtered_sql.parameters)
@@ -751,7 +800,6 @@ class PsycopgAsyncDriver(
751
800
  msg, operation_index=index, partial_results=[], failed_operation=operation
752
801
  ) from e
753
802
  else:
754
- # Add pipeline context
755
803
  result.operation_index = index
756
804
  result.pipeline_sql = operation.sql
757
805
  return result
@@ -779,7 +827,7 @@ class PsycopgAsyncDriver(
779
827
  statement=sql,
780
828
  data=cast("list[RowT]", []),
781
829
  rows_affected=cursor.rowcount,
782
- operation_type="execute_many",
830
+ operation_type="EXECUTE",
783
831
  metadata={"status_message": "OK"},
784
832
  )
785
833
 
@@ -796,7 +844,7 @@ class PsycopgAsyncDriver(
796
844
  statement=sql,
797
845
  data=cast("list[RowT]", data),
798
846
  rows_affected=len(data),
799
- operation_type="select",
847
+ operation_type="SELECT",
800
848
  metadata={"column_names": column_names},
801
849
  )
802
850
 
@@ -817,7 +865,7 @@ class PsycopgAsyncDriver(
817
865
  statement=sql,
818
866
  data=cast("list[RowT]", []),
819
867
  rows_affected=total_affected,
820
- operation_type="execute_script",
868
+ operation_type="SCRIPT",
821
869
  metadata={"status_message": "SCRIPT EXECUTED", "statements_executed": len(script_statements)},
822
870
  )
823
871
 
@@ -831,7 +879,7 @@ class PsycopgAsyncDriver(
831
879
  statement=sql,
832
880
  data=cast("list[RowT]", []),
833
881
  rows_affected=cursor.rowcount or 0,
834
- operation_type="execute",
882
+ operation_type="EXECUTE",
835
883
  metadata={"status_message": "OK"},
836
884
  )
837
885
 
@@ -852,7 +900,6 @@ class PsycopgAsyncDriver(
852
900
  # Psycopg handles dict parameters directly for named placeholders
853
901
  return params
854
902
  if isinstance(params, (list, tuple)):
855
- # Convert to tuple for positional parameters
856
903
  return tuple(params)
857
904
  # Single parameter
858
905
  return (params,)