sqlspec-0.12.1-py3-none-any.whl → sqlspec-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,7 +1,7 @@
1
1
  import io
2
2
  from collections.abc import AsyncGenerator, Generator
3
3
  from contextlib import asynccontextmanager, contextmanager
4
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
4
+ from typing import TYPE_CHECKING, Any, Optional, cast
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from psycopg.abc import Query
@@ -11,6 +11,7 @@ from psycopg.rows import DictRow as PsycopgDictRow
11
11
  from sqlglot.dialects.dialect import DialectType
12
12
 
13
13
  from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
14
+ from sqlspec.driver.connection import managed_transaction_async, managed_transaction_sync
14
15
  from sqlspec.driver.mixins import (
15
16
  AsyncPipelinedExecutionMixin,
16
17
  AsyncStorageMixin,
@@ -20,11 +21,13 @@ from sqlspec.driver.mixins import (
20
21
  ToSchemaMixin,
21
22
  TypeCoercionMixin,
22
23
  )
23
- from sqlspec.statement.parameters import ParameterStyle
24
- from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
24
+ from sqlspec.driver.parameters import normalize_parameter_sequence
25
+ from sqlspec.exceptions import PipelineExecutionError
26
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
27
+ from sqlspec.statement.result import ArrowResult, SQLResult
25
28
  from sqlspec.statement.splitter import split_sql_script
26
29
  from sqlspec.statement.sql import SQL, SQLConfig
27
- from sqlspec.typing import DictRow, ModelDTOT, RowT, is_dict_with_field
30
+ from sqlspec.typing import DictRow, RowT
28
31
  from sqlspec.utils.logging import get_logger
29
32
 
30
33
  if TYPE_CHECKING:
@@ -72,12 +75,18 @@ class PsycopgSyncDriver(
72
75
 
73
76
  def _execute_statement(
74
77
  self, statement: SQL, connection: Optional[PsycopgSyncConnection] = None, **kwargs: Any
75
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
78
+ ) -> SQLResult[RowT]:
76
79
  if statement.is_script:
77
80
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
78
81
  return self._execute_script(sql, connection=connection, **kwargs)
79
82
 
80
- detected_styles = {p.style for p in statement.parameter_info}
83
+ detected_styles = set()
84
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
85
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
86
+ param_infos = validator.extract_parameters(sql_str)
87
+ if param_infos:
88
+ detected_styles = {p.style for p in param_infos}
89
+
81
90
  target_style = self.default_parameter_style
82
91
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
83
92
  if unsupported_styles:
@@ -89,20 +98,39 @@ class PsycopgSyncDriver(
89
98
  break
90
99
 
91
100
  if statement.is_many:
92
- sql, params = statement.compile(placeholder_style=target_style)
93
- # For execute_many, check if parameters were passed via kwargs (legacy support)
94
- # Otherwise use the parameters from the SQL object
101
+ # Check if parameters were provided in kwargs first
95
102
  kwargs_params = kwargs.get("parameters")
96
103
  if kwargs_params is not None:
104
+ # Use the SQL string directly if parameters come from kwargs
105
+ sql = statement.to_sql(placeholder_style=target_style)
97
106
  params = kwargs_params
107
+ else:
108
+ sql, params = statement.compile(placeholder_style=target_style)
98
109
  if params is not None:
99
110
  processed_params = [self._process_parameters(param_set) for param_set in params]
100
111
  params = processed_params
101
- return self._execute_many(sql, params, connection=connection, **kwargs)
102
-
103
- sql, params = statement.compile(placeholder_style=target_style)
112
+ # Remove 'parameters' from kwargs to avoid conflicts in _execute_many method signature
113
+ exec_kwargs = {k: v for k, v in kwargs.items() if k != "parameters"}
114
+ return self._execute_many(sql, params, connection=connection, **exec_kwargs)
115
+
116
+ # Check if parameters were provided in kwargs (user-provided parameters)
117
+ kwargs_params = kwargs.get("parameters")
118
+ if kwargs_params is not None:
119
+ # Use the SQL string directly if parameters come from kwargs
120
+ sql = statement.to_sql(placeholder_style=target_style)
121
+ params = kwargs_params
122
+ else:
123
+ sql, params = statement.compile(placeholder_style=target_style)
104
124
  params = self._process_parameters(params)
105
- return self._execute(sql, params, statement, connection=connection, **kwargs)
125
+
126
+ # Fix over-nested parameters for Psycopg
127
+ # If params is a tuple containing a single tuple or dict, flatten it
128
+ if isinstance(params, tuple) and len(params) == 1 and isinstance(params[0], (tuple, dict, list)):
129
+ params = params[0]
130
+
131
+ # Remove 'parameters' from kwargs to avoid conflicts in _execute method signature
132
+ exec_kwargs = {k: v for k, v in kwargs.items() if k != "parameters"}
133
+ return self._execute(sql, params, statement, connection=connection, **exec_kwargs)
106
134
 
107
135
  def _execute(
108
136
  self,
@@ -111,43 +139,133 @@ class PsycopgSyncDriver(
111
139
  statement: SQL,
112
140
  connection: Optional[PsycopgSyncConnection] = None,
113
141
  **kwargs: Any,
114
- ) -> Union[SelectResultDict, DMLResultDict]:
115
- conn = self._connection(connection)
116
- with conn.cursor() as cursor:
117
- cursor.execute(cast("Query", sql), parameters)
118
- # Check if the statement returns rows by checking cursor.description
119
- # This is more reliable than parsing when parsing is disabled
120
- if cursor.description is not None:
121
- fetched_data = cursor.fetchall()
122
- column_names = [col.name for col in cursor.description]
123
- return {"data": fetched_data, "column_names": column_names, "rows_affected": len(fetched_data)}
124
- return {"rows_affected": cursor.rowcount, "status_message": cursor.statusmessage or "OK"}
142
+ ) -> SQLResult[RowT]:
143
+ # Use provided connection or driver's default connection
144
+ conn = connection if connection is not None else self._connection(None)
145
+
146
+ # Handle COPY commands separately (they don't use transactions)
147
+ sql_upper = sql.strip().upper()
148
+ if sql_upper.startswith("COPY") and ("FROM STDIN" in sql_upper or "TO STDOUT" in sql_upper):
149
+ return self._handle_copy_command(sql, parameters, conn)
150
+
151
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
152
+ # For Psycopg, pass parameters directly to the driver
153
+ final_params = parameters
154
+
155
+ # Debug logging
156
+ logger.debug("Executing SQL: %r with parameters: %r", sql, final_params)
157
+
158
+ with txn_conn.cursor() as cursor:
159
+ cursor.execute(cast("Query", sql), final_params)
160
+ if cursor.description is not None:
161
+ fetched_data = cursor.fetchall()
162
+ column_names = [col.name for col in cursor.description]
163
+ return SQLResult(
164
+ statement=statement,
165
+ data=cast("list[RowT]", fetched_data),
166
+ column_names=column_names,
167
+ rows_affected=len(fetched_data),
168
+ operation_type="SELECT",
169
+ )
170
+ operation_type = self._determine_operation_type(statement)
171
+ return SQLResult(
172
+ statement=statement,
173
+ data=[],
174
+ rows_affected=cursor.rowcount or 0,
175
+ operation_type=operation_type,
176
+ metadata={"status_message": cursor.statusmessage or "OK"},
177
+ )
178
+
179
+ def _handle_copy_command(self, sql: str, data: Any, connection: PsycopgSyncConnection) -> SQLResult[RowT]:
180
+ """Handle PostgreSQL COPY commands using cursor.copy() method."""
181
+ sql_upper = sql.strip().upper()
182
+
183
+ # Handle case where data is wrapped in a single-element tuple (from positional args)
184
+ if isinstance(data, tuple) and len(data) == 1:
185
+ data = data[0]
186
+
187
+ with connection.cursor() as cursor:
188
+ if "TO STDOUT" in sql_upper:
189
+ # COPY TO STDOUT - read data from the database
190
+ output_data: list[Any] = []
191
+ with cursor.copy(cast("Query", sql)) as copy:
192
+ output_data.extend(row for row in copy)
193
+
194
+ return SQLResult(
195
+ statement=SQL(sql, _dialect=self.dialect),
196
+ data=cast("list[RowT]", output_data),
197
+ column_names=["copy_data"],
198
+ rows_affected=len(output_data),
199
+ operation_type="SELECT",
200
+ )
201
+ # COPY FROM STDIN - write data to the database
202
+ with cursor.copy(cast("Query", sql)) as copy:
203
+ if data:
204
+ # If data is provided, write it to the copy stream
205
+ if isinstance(data, str):
206
+ copy.write(data.encode("utf-8"))
207
+ elif isinstance(data, bytes):
208
+ copy.write(data)
209
+ elif isinstance(data, (list, tuple)):
210
+ # If data is a list/tuple of rows, write each row
211
+ for row in data:
212
+ copy.write_row(row)
213
+ else:
214
+ # Single row
215
+ copy.write_row(data)
216
+
217
+ # For COPY operations, cursor.rowcount contains the number of rows affected
218
+ return SQLResult(
219
+ statement=SQL(sql, _dialect=self.dialect),
220
+ data=[],
221
+ rows_affected=cursor.rowcount or -1,
222
+ operation_type="EXECUTE",
223
+ metadata={"status_message": cursor.statusmessage or "COPY COMPLETE"},
224
+ )
125
225
 
126
226
  def _execute_many(
127
227
  self, sql: str, param_list: Any, connection: Optional[PsycopgSyncConnection] = None, **kwargs: Any
128
- ) -> DMLResultDict:
129
- conn = self._connection(connection)
130
- with self._get_cursor(conn) as cursor:
131
- cursor.executemany(sql, param_list or [])
132
- # psycopg's executemany might return -1 or 0 for rowcount
133
- # In that case, use the length of param_list for DML operations
134
- rows_affected = cursor.rowcount
135
- if rows_affected <= 0 and param_list:
136
- rows_affected = len(param_list)
137
- result: DMLResultDict = {"rows_affected": rows_affected, "status_message": cursor.statusmessage or "OK"}
138
- return result
228
+ ) -> SQLResult[RowT]:
229
+ # Use provided connection or driver's default connection
230
+ conn = connection if connection is not None else self._connection(None)
231
+
232
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
233
+ # Normalize parameter list using consolidated utility
234
+ normalized_param_list = normalize_parameter_sequence(param_list)
235
+ final_param_list = normalized_param_list or []
236
+
237
+ with self._get_cursor(txn_conn) as cursor:
238
+ cursor.executemany(sql, final_param_list)
239
+ # psycopg's executemany might return -1 or 0 for rowcount
240
+ # In that case, use the length of param_list for DML operations
241
+ rows_affected = cursor.rowcount
242
+ if rows_affected <= 0 and final_param_list:
243
+ rows_affected = len(final_param_list)
244
+ return SQLResult(
245
+ statement=SQL(sql, _dialect=self.dialect),
246
+ data=[],
247
+ rows_affected=rows_affected,
248
+ operation_type="EXECUTE",
249
+ metadata={"status_message": cursor.statusmessage or "OK"},
250
+ )
139
251
 
140
252
  def _execute_script(
141
253
  self, script: str, connection: Optional[PsycopgSyncConnection] = None, **kwargs: Any
142
- ) -> ScriptResultDict:
143
- conn = self._connection(connection)
144
- with self._get_cursor(conn) as cursor:
254
+ ) -> SQLResult[RowT]:
255
+ # Use provided connection or driver's default connection
256
+ conn = connection if connection is not None else self._connection(None)
257
+
258
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
145
259
  cursor.execute(script)
146
- result: ScriptResultDict = {
147
- "statements_executed": -1,
148
- "status_message": cursor.statusmessage or "SCRIPT EXECUTED",
149
- }
150
- return result
260
+ return SQLResult(
261
+ statement=SQL(script, _dialect=self.dialect).as_script(),
262
+ data=[],
263
+ rows_affected=0,
264
+ operation_type="SCRIPT",
265
+ metadata={"status_message": cursor.statusmessage or "SCRIPT EXECUTED"},
266
+ total_statements=1,
267
+ successful_statements=1,
268
+ )
151
269
 
152
270
  def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
153
271
  self._ensure_pyarrow_installed()
@@ -170,61 +288,6 @@ class PsycopgSyncDriver(
170
288
 
171
289
  return cursor.rowcount if cursor.rowcount is not None else -1
172
290
 
173
- def _wrap_select_result(
174
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
175
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
176
- rows_as_dicts: list[dict[str, Any]] = [dict(row) for row in result["data"]]
177
-
178
- if schema_type:
179
- return SQLResult[ModelDTOT](
180
- statement=statement,
181
- data=list(self.to_schema(data=result["data"], schema_type=schema_type)),
182
- column_names=result["column_names"],
183
- rows_affected=result["rows_affected"],
184
- operation_type="SELECT",
185
- )
186
- return SQLResult[RowT](
187
- statement=statement,
188
- data=rows_as_dicts,
189
- column_names=result["column_names"],
190
- rows_affected=result["rows_affected"],
191
- operation_type="SELECT",
192
- )
193
-
194
- def _wrap_execute_result(
195
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
196
- ) -> SQLResult[RowT]:
197
- operation_type = "UNKNOWN"
198
- if statement.expression:
199
- operation_type = str(statement.expression.key).upper()
200
-
201
- # Handle case where we got a SelectResultDict but it was routed here due to parsing being disabled
202
- if is_dict_with_field(result, "data") and is_dict_with_field(result, "column_names"):
203
- # This is actually a SELECT result, wrap it properly
204
- return self._wrap_select_result(statement, cast("SelectResultDict", result), **kwargs)
205
-
206
- if is_dict_with_field(result, "statements_executed"):
207
- return SQLResult[RowT](
208
- statement=statement,
209
- data=[],
210
- rows_affected=0,
211
- operation_type="SCRIPT",
212
- metadata={"status_message": result.get("status_message", "")},
213
- )
214
-
215
- if is_dict_with_field(result, "rows_affected"):
216
- return SQLResult[RowT](
217
- statement=statement,
218
- data=[],
219
- rows_affected=cast("int", result.get("rows_affected", -1)),
220
- operation_type=operation_type,
221
- metadata={"status_message": result.get("status_message", "")},
222
- )
223
-
224
- # This shouldn't happen with TypedDict approach
225
- msg = f"Unexpected result type: {type(result)}"
226
- raise ValueError(msg)
227
-
228
291
  def _connection(self, connection: Optional[PsycopgSyncConnection] = None) -> PsycopgSyncConnection:
229
292
  """Get the connection to use for the operation."""
230
293
  return connection or self.connection
@@ -242,7 +305,6 @@ class PsycopgSyncDriver(
242
305
  Returns:
243
306
  List of SQLResult objects from all operations
244
307
  """
245
- from sqlspec.exceptions import PipelineExecutionError
246
308
 
247
309
  results = []
248
310
  connection = self._connection()
@@ -268,7 +330,6 @@ class PsycopgSyncDriver(
268
330
  from sqlspec.exceptions import PipelineExecutionError
269
331
 
270
332
  try:
271
- # Prepare SQL and parameters
272
333
  filtered_sql = self._apply_operation_filters(operation.sql, operation.filters)
273
334
  sql_str = filtered_sql.to_sql(placeholder_style=self.default_parameter_style)
274
335
  params = self._convert_psycopg_params(filtered_sql.parameters)
@@ -317,7 +378,7 @@ class PsycopgSyncDriver(
317
378
  statement=sql,
318
379
  data=cast("list[RowT]", []),
319
380
  rows_affected=cursor.rowcount,
320
- operation_type="execute_many",
381
+ operation_type="EXECUTE",
321
382
  metadata={"status_message": "OK"},
322
383
  )
323
384
 
@@ -332,7 +393,7 @@ class PsycopgSyncDriver(
332
393
  statement=sql,
333
394
  data=cast("list[RowT]", data),
334
395
  rows_affected=len(data),
335
- operation_type="select",
396
+ operation_type="SELECT",
336
397
  metadata={"column_names": column_names},
337
398
  )
338
399
 
@@ -353,7 +414,7 @@ class PsycopgSyncDriver(
353
414
  statement=sql,
354
415
  data=cast("list[RowT]", []),
355
416
  rows_affected=total_affected,
356
- operation_type="execute_script",
417
+ operation_type="SCRIPT",
357
418
  metadata={"status_message": "SCRIPT EXECUTED", "statements_executed": len(script_statements)},
358
419
  )
359
420
 
@@ -365,7 +426,7 @@ class PsycopgSyncDriver(
365
426
  statement=sql,
366
427
  data=cast("list[RowT]", []),
367
428
  rows_affected=cursor.rowcount or 0,
368
- operation_type="execute",
429
+ operation_type="EXECUTE",
369
430
  metadata={"status_message": "OK"},
370
431
  )
371
432
 
@@ -386,7 +447,6 @@ class PsycopgSyncDriver(
386
447
  # Psycopg handles dict parameters directly for named placeholders
387
448
  return params
388
449
  if isinstance(params, (list, tuple)):
389
- # Convert to tuple for positional parameters
390
450
  return tuple(params)
391
451
  # Single parameter
392
452
  return (params,)
@@ -406,7 +466,6 @@ class PsycopgSyncDriver(
406
466
  def _split_script_statements(self, script: str, strip_trailing_semicolon: bool = False) -> "list[str]":
407
467
  """Split a SQL script into individual statements."""
408
468
 
409
- # Use the sophisticated splitter with PostgreSQL dialect
410
469
  return split_sql_script(script=script, dialect="postgresql", strip_trailing_semicolon=strip_trailing_semicolon)
411
470
 
412
471
 
@@ -444,22 +503,24 @@ class PsycopgAsyncDriver(
444
503
 
445
504
  async def _execute_statement(
446
505
  self, statement: SQL, connection: Optional[PsycopgAsyncConnection] = None, **kwargs: Any
447
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
506
+ ) -> SQLResult[RowT]:
448
507
  if statement.is_script:
449
508
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
450
509
  return await self._execute_script(sql, connection=connection, **kwargs)
451
510
 
452
- # Determine if we need to convert parameter style
453
- detected_styles = {p.style for p in statement.parameter_info}
511
+ detected_styles = set()
512
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
513
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
514
+ param_infos = validator.extract_parameters(sql_str)
515
+ if param_infos:
516
+ detected_styles = {p.style for p in param_infos}
517
+
454
518
  target_style = self.default_parameter_style
455
519
 
456
- # Check if any detected style is not supported
457
520
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
458
521
  if unsupported_styles:
459
- # Convert to default style if we have unsupported styles
460
522
  target_style = self.default_parameter_style
461
523
  elif detected_styles:
462
- # Use the first detected style if all are supported
463
524
  # Prefer the first supported style found
464
525
  for style in detected_styles:
465
526
  if style in self.supported_parameter_styles:
@@ -467,18 +528,49 @@ class PsycopgAsyncDriver(
467
528
  break
468
529
 
469
530
  if statement.is_many:
470
- sql, _ = statement.compile(placeholder_style=target_style)
471
- # For execute_many, use the parameters passed via kwargs
472
- params = kwargs.get("parameters")
531
+ # Check if parameters were provided in kwargs first
532
+ kwargs_params = kwargs.get("parameters")
533
+ if kwargs_params is not None:
534
+ # Use the SQL string directly if parameters come from kwargs
535
+ sql = statement.to_sql(placeholder_style=target_style)
536
+ params = kwargs_params
537
+ else:
538
+ sql, _ = statement.compile(placeholder_style=target_style)
539
+ params = statement.parameters
473
540
  if params is not None:
474
- # Process each parameter set individually
475
541
  processed_params = [self._process_parameters(param_set) for param_set in params]
476
542
  params = processed_params
477
- return await self._execute_many(sql, params, connection=connection, **kwargs)
478
543
 
479
- sql, params = statement.compile(placeholder_style=target_style)
544
+ # Fix over-nested parameters for each param set
545
+ fixed_params = []
546
+ for param_set in params:
547
+ if isinstance(param_set, tuple) and len(param_set) == 1:
548
+ fixed_params.append(param_set[0])
549
+ else:
550
+ fixed_params.append(param_set)
551
+ params = fixed_params
552
+ # Remove 'parameters' from kwargs to avoid conflicts in _execute_many method signature
553
+ exec_kwargs = {k: v for k, v in kwargs.items() if k != "parameters"}
554
+ return await self._execute_many(sql, params, connection=connection, **exec_kwargs)
555
+
556
+ # Check if parameters were provided in kwargs (user-provided parameters)
557
+ kwargs_params = kwargs.get("parameters")
558
+ if kwargs_params is not None:
559
+ # Use the SQL string directly if parameters come from kwargs
560
+ sql = statement.to_sql(placeholder_style=target_style)
561
+ params = kwargs_params
562
+ else:
563
+ sql, params = statement.compile(placeholder_style=target_style)
480
564
  params = self._process_parameters(params)
481
- return await self._execute(sql, params, statement, connection=connection, **kwargs)
565
+
566
+ # Fix over-nested parameters for Psycopg
567
+ # If params is a tuple containing a single tuple or dict, flatten it
568
+ if isinstance(params, tuple) and len(params) == 1 and isinstance(params[0], (tuple, dict, list)):
569
+ params = params[0]
570
+
571
+ # Remove 'parameters' from kwargs to avoid conflicts in _execute method signature
572
+ exec_kwargs = {k: v for k, v in kwargs.items() if k != "parameters"}
573
+ return await self._execute(sql, params, statement, connection=connection, **exec_kwargs)
482
574
 
483
575
  async def _execute(
484
576
  self,
@@ -487,48 +579,140 @@ class PsycopgAsyncDriver(
487
579
  statement: SQL,
488
580
  connection: Optional[PsycopgAsyncConnection] = None,
489
581
  **kwargs: Any,
490
- ) -> Union[SelectResultDict, DMLResultDict]:
491
- conn = self._connection(connection)
492
- async with conn.cursor() as cursor:
493
- await cursor.execute(cast("Query", sql), parameters)
494
-
495
- # When parsing is disabled, expression will be None, so check SQL directly
496
- if statement.expression and self.returns_rows(statement.expression):
497
- # For SELECT statements, extract data while cursor is open
498
- fetched_data = await cursor.fetchall()
499
- column_names = [col.name for col in cursor.description or []]
500
- return {"data": fetched_data, "column_names": column_names, "rows_affected": len(fetched_data)}
501
- if not statement.expression and sql.strip().upper().startswith("SELECT"):
502
- # For SELECT statements when parsing is disabled
503
- fetched_data = await cursor.fetchall()
504
- column_names = [col.name for col in cursor.description or []]
505
- return {"data": fetched_data, "column_names": column_names, "rows_affected": len(fetched_data)}
506
- # For DML statements
507
- dml_result: DMLResultDict = {
508
- "rows_affected": cursor.rowcount,
509
- "status_message": cursor.statusmessage or "OK",
510
- }
511
- return dml_result
582
+ ) -> SQLResult[RowT]:
583
+ # Use provided connection or driver's default connection
584
+ conn = connection if connection is not None else self._connection(None)
585
+
586
+ # Handle COPY commands separately (they don't use transactions)
587
+ sql_upper = sql.strip().upper()
588
+ if sql_upper.startswith("COPY") and ("FROM STDIN" in sql_upper or "TO STDOUT" in sql_upper):
589
+ return await self._handle_copy_command(sql, parameters, conn)
590
+
591
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
592
+ # For Psycopg, pass parameters directly to the driver
593
+ final_params = parameters
594
+
595
+ async with txn_conn.cursor() as cursor:
596
+ await cursor.execute(cast("Query", sql), final_params)
597
+
598
+ # When parsing is disabled, expression will be None, so check SQL directly
599
+ if statement.expression and self.returns_rows(statement.expression):
600
+ # For SELECT statements, extract data while cursor is open
601
+ fetched_data = await cursor.fetchall()
602
+ column_names = [col.name for col in cursor.description or []]
603
+ return SQLResult(
604
+ statement=statement,
605
+ data=cast("list[RowT]", fetched_data),
606
+ column_names=column_names,
607
+ rows_affected=len(fetched_data),
608
+ operation_type="SELECT",
609
+ )
610
+ if not statement.expression and sql.strip().upper().startswith("SELECT"):
611
+ # For SELECT statements when parsing is disabled
612
+ fetched_data = await cursor.fetchall()
613
+ column_names = [col.name for col in cursor.description or []]
614
+ return SQLResult(
615
+ statement=statement,
616
+ data=cast("list[RowT]", fetched_data),
617
+ column_names=column_names,
618
+ rows_affected=len(fetched_data),
619
+ operation_type="SELECT",
620
+ )
621
+ # For DML statements
622
+ operation_type = self._determine_operation_type(statement)
623
+ return SQLResult(
624
+ statement=statement,
625
+ data=[],
626
+ rows_affected=cursor.rowcount or 0,
627
+ operation_type=operation_type,
628
+ metadata={"status_message": cursor.statusmessage or "OK"},
629
+ )
630
+
631
+ async def _handle_copy_command(self, sql: str, data: Any, connection: PsycopgAsyncConnection) -> SQLResult[RowT]:
632
+ """Handle PostgreSQL COPY commands using cursor.copy() method."""
633
+ sql_upper = sql.strip().upper()
634
+
635
+ # Handle case where data is wrapped in a single-element tuple (from positional args)
636
+ if isinstance(data, tuple) and len(data) == 1:
637
+ data = data[0]
638
+
639
+ async with connection.cursor() as cursor:
640
+ if "TO STDOUT" in sql_upper:
641
+ # COPY TO STDOUT - read data from the database
642
+ output_data = []
643
+ async with cursor.copy(cast("Query", sql)) as copy:
644
+ output_data.extend([row async for row in copy])
645
+
646
+ return SQLResult(
647
+ statement=SQL(sql, _dialect=self.dialect),
648
+ data=cast("list[RowT]", output_data),
649
+ column_names=["copy_data"],
650
+ rows_affected=len(output_data),
651
+ operation_type="SELECT",
652
+ )
653
+ # COPY FROM STDIN - write data to the database
654
+ async with cursor.copy(cast("Query", sql)) as copy:
655
+ if data:
656
+ # If data is provided, write it to the copy stream
657
+ if isinstance(data, str):
658
+ await copy.write(data.encode("utf-8"))
659
+ elif isinstance(data, bytes):
660
+ await copy.write(data)
661
+ elif isinstance(data, (list, tuple)):
662
+ # If data is a list/tuple of rows, write each row
663
+ for row in data:
664
+ await copy.write_row(row)
665
+ else:
666
+ # Single row
667
+ await copy.write_row(data)
668
+
669
+ # For COPY operations, cursor.rowcount contains the number of rows affected
670
+ return SQLResult(
671
+ statement=SQL(sql, _dialect=self.dialect),
672
+ data=[],
673
+ rows_affected=cursor.rowcount or -1,
674
+ operation_type="EXECUTE",
675
+ metadata={"status_message": cursor.statusmessage or "COPY COMPLETE"},
676
+ )
512
677
 
513
678
  async def _execute_many(
514
679
  self, sql: str, param_list: Any, connection: Optional[PsycopgAsyncConnection] = None, **kwargs: Any
515
- ) -> DMLResultDict:
516
- conn = self._connection(connection)
517
- async with conn.cursor() as cursor:
518
- await cursor.executemany(cast("Query", sql), param_list or [])
519
- return {"rows_affected": cursor.rowcount, "status_message": cursor.statusmessage or "OK"}
680
+ ) -> SQLResult[RowT]:
681
+ # Use provided connection or driver's default connection
682
+ conn = connection if connection is not None else self._connection(None)
683
+
684
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
685
+ # Normalize parameter list using consolidated utility
686
+ normalized_param_list = normalize_parameter_sequence(param_list)
687
+ final_param_list = normalized_param_list or []
688
+
689
+ async with txn_conn.cursor() as cursor:
690
+ await cursor.executemany(cast("Query", sql), final_param_list)
691
+ return SQLResult(
692
+ statement=SQL(sql, _dialect=self.dialect),
693
+ data=[],
694
+ rows_affected=cursor.rowcount,
695
+ operation_type="EXECUTE",
696
+ metadata={"status_message": cursor.statusmessage or "OK"},
697
+ )
520
698
 
521
699
  async def _execute_script(
522
700
  self, script: str, connection: Optional[PsycopgAsyncConnection] = None, **kwargs: Any
523
- ) -> ScriptResultDict:
524
- conn = self._connection(connection)
525
- async with conn.cursor() as cursor:
701
+ ) -> SQLResult[RowT]:
702
+ # Use provided connection or driver's default connection
703
+ conn = connection if connection is not None else self._connection(None)
704
+
705
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn, txn_conn.cursor() as cursor:
526
706
  await cursor.execute(cast("Query", script))
527
- # For scripts, return script result format
528
- return {
529
- "statements_executed": -1, # Psycopg doesn't provide this info
530
- "status_message": cursor.statusmessage or "SCRIPT EXECUTED",
531
- }
707
+ return SQLResult(
708
+ statement=SQL(script, _dialect=self.dialect).as_script(),
709
+ data=[],
710
+ rows_affected=0,
711
+ operation_type="SCRIPT",
712
+ metadata={"status_message": cursor.statusmessage or "SCRIPT EXECUTED"},
713
+ total_statements=1,
714
+ successful_statements=1,
715
+ )
532
716
 
533
717
  async def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
534
718
  self._ensure_pyarrow_installed()
@@ -563,59 +747,6 @@ class PsycopgAsyncDriver(
563
747
 
564
748
  return cursor.rowcount if cursor.rowcount is not None else -1
565
749
 
566
- async def _wrap_select_result(
567
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
568
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
569
- # result must be a dict with keys: data, column_names, rows_affected
570
- fetched_data = result["data"]
571
- column_names = result["column_names"]
572
- rows_affected = result["rows_affected"]
573
- rows_as_dicts: list[dict[str, Any]] = [dict(row) for row in fetched_data]
574
-
575
- if schema_type:
576
- return SQLResult[ModelDTOT](
577
- statement=statement,
578
- data=list(self.to_schema(data=fetched_data, schema_type=schema_type)),
579
- column_names=column_names,
580
- rows_affected=rows_affected,
581
- operation_type="SELECT",
582
- )
583
- return SQLResult[RowT](
584
- statement=statement,
585
- data=rows_as_dicts,
586
- column_names=column_names,
587
- rows_affected=rows_affected,
588
- operation_type="SELECT",
589
- )
590
-
591
- async def _wrap_execute_result(
592
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
593
- ) -> SQLResult[RowT]:
594
- operation_type = "UNKNOWN"
595
- if statement.expression:
596
- operation_type = str(statement.expression.key).upper()
597
-
598
- if is_dict_with_field(result, "statements_executed"):
599
- return SQLResult[RowT](
600
- statement=statement,
601
- data=[],
602
- rows_affected=0,
603
- operation_type="SCRIPT",
604
- metadata={"status_message": result.get("status_message", "")},
605
- )
606
-
607
- if is_dict_with_field(result, "rows_affected"):
608
- return SQLResult[RowT](
609
- statement=statement,
610
- data=[],
611
- rows_affected=cast("int", result.get("rows_affected", -1)),
612
- operation_type=operation_type,
613
- metadata={"status_message": result.get("status_message", "")},
614
- )
615
- # This shouldn't happen with TypedDict approach
616
- msg = f"Unexpected result type: {type(result)}"
617
- raise ValueError(msg)
618
-
619
750
  def _connection(self, connection: Optional[PsycopgAsyncConnection] = None) -> PsycopgAsyncConnection:
620
751
  """Get the connection to use for the operation."""
621
752
  return connection or self.connection
@@ -648,7 +779,6 @@ class PsycopgAsyncDriver(
648
779
  from sqlspec.exceptions import PipelineExecutionError
649
780
 
650
781
  try:
651
- # Prepare SQL and parameters
652
782
  filtered_sql = self._apply_operation_filters(operation.sql, operation.filters)
653
783
  sql_str = filtered_sql.to_sql(placeholder_style=self.default_parameter_style)
654
784
  params = self._convert_psycopg_params(filtered_sql.parameters)
@@ -670,7 +800,6 @@ class PsycopgAsyncDriver(
670
800
  msg, operation_index=index, partial_results=[], failed_operation=operation
671
801
  ) from e
672
802
  else:
673
- # Add pipeline context
674
803
  result.operation_index = index
675
804
  result.pipeline_sql = operation.sql
676
805
  return result
@@ -698,7 +827,7 @@ class PsycopgAsyncDriver(
698
827
  statement=sql,
699
828
  data=cast("list[RowT]", []),
700
829
  rows_affected=cursor.rowcount,
701
- operation_type="execute_many",
830
+ operation_type="EXECUTE",
702
831
  metadata={"status_message": "OK"},
703
832
  )
704
833
 
@@ -715,7 +844,7 @@ class PsycopgAsyncDriver(
715
844
  statement=sql,
716
845
  data=cast("list[RowT]", data),
717
846
  rows_affected=len(data),
718
- operation_type="select",
847
+ operation_type="SELECT",
719
848
  metadata={"column_names": column_names},
720
849
  )
721
850
 
@@ -736,7 +865,7 @@ class PsycopgAsyncDriver(
736
865
  statement=sql,
737
866
  data=cast("list[RowT]", []),
738
867
  rows_affected=total_affected,
739
- operation_type="execute_script",
868
+ operation_type="SCRIPT",
740
869
  metadata={"status_message": "SCRIPT EXECUTED", "statements_executed": len(script_statements)},
741
870
  )
742
871
 
@@ -750,7 +879,7 @@ class PsycopgAsyncDriver(
750
879
  statement=sql,
751
880
  data=cast("list[RowT]", []),
752
881
  rows_affected=cursor.rowcount or 0,
753
- operation_type="execute",
882
+ operation_type="EXECUTE",
754
883
  metadata={"status_message": "OK"},
755
884
  )
756
885
 
@@ -771,7 +900,6 @@ class PsycopgAsyncDriver(
771
900
  # Psycopg handles dict parameters directly for named placeholders
772
901
  return params
773
902
  if isinstance(params, (list, tuple)):
774
- # Convert to tuple for positional parameters
775
903
  return tuple(params)
776
904
  # Single parameter
777
905
  return (params,)