sqlspec-0.12.1-py3-none-any.whl → sqlspec-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's page for this release for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,11 +1,12 @@
1
1
  from collections.abc import AsyncGenerator, Generator
2
2
  from contextlib import asynccontextmanager, contextmanager
3
- from typing import Any, ClassVar, Optional, Union, cast
3
+ from typing import Any, ClassVar, Optional, cast
4
4
 
5
5
  from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor
6
6
  from sqlglot.dialects.dialect import DialectType
7
7
 
8
8
  from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
9
+ from sqlspec.driver.connection import managed_transaction_async, managed_transaction_sync
9
10
  from sqlspec.driver.mixins import (
10
11
  AsyncPipelinedExecutionMixin,
11
12
  AsyncStorageMixin,
@@ -15,10 +16,11 @@ from sqlspec.driver.mixins import (
15
16
  ToSchemaMixin,
16
17
  TypeCoercionMixin,
17
18
  )
18
- from sqlspec.statement.parameters import ParameterStyle
19
- from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
19
+ from sqlspec.driver.parameters import normalize_parameter_sequence
20
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
21
+ from sqlspec.statement.result import ArrowResult, SQLResult
20
22
  from sqlspec.statement.sql import SQL, SQLConfig
21
- from sqlspec.typing import DictRow, ModelDTOT, RowT, SQLParameterType
23
+ from sqlspec.typing import DictRow, RowT, SQLParameterType
22
24
  from sqlspec.utils.logging import get_logger
23
25
  from sqlspec.utils.sync_tools import ensure_async_
24
26
 
@@ -41,30 +43,21 @@ def _process_oracle_parameters(params: Any) -> Any:
41
43
  if params is None:
42
44
  return None
43
45
 
44
- # Handle TypedParameter objects
45
46
  if isinstance(params, TypedParameter):
46
47
  return _process_oracle_parameters(params.value)
47
48
 
48
49
  if isinstance(params, tuple):
49
- # Convert single tuple to list and process each element
50
50
  return [_process_oracle_parameters(item) for item in params]
51
51
  if isinstance(params, list):
52
- # Process list of parameter sets
53
52
  processed = []
54
53
  for param_set in params:
55
- if isinstance(param_set, tuple):
56
- # Convert tuple to list and process each element
57
- processed.append([_process_oracle_parameters(item) for item in param_set])
58
- elif isinstance(param_set, list):
59
- # Process each element in the list
54
+ if isinstance(param_set, (tuple, list)):
60
55
  processed.append([_process_oracle_parameters(item) for item in param_set])
61
56
  else:
62
57
  processed.append(_process_oracle_parameters(param_set))
63
58
  return processed
64
59
  if isinstance(params, dict):
65
- # Process dict values
66
60
  return {key: _process_oracle_parameters(value) for key, value in params.items()}
67
- # Return as-is for other types
68
61
  return params
69
62
 
70
63
 
@@ -114,22 +107,24 @@ class OracleSyncDriver(
114
107
 
115
108
  def _execute_statement(
116
109
  self, statement: SQL, connection: Optional[OracleSyncConnection] = None, **kwargs: Any
117
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
110
+ ) -> SQLResult[RowT]:
118
111
  if statement.is_script:
119
112
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
120
113
  return self._execute_script(sql, connection=connection, **kwargs)
121
114
 
122
- # Determine if we need to convert parameter style
123
- detected_styles = {p.style for p in statement.parameter_info}
115
+ detected_styles = set()
116
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
117
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
118
+ param_infos = validator.extract_parameters(sql_str)
119
+ if param_infos:
120
+ detected_styles = {p.style for p in param_infos}
121
+
124
122
  target_style = self.default_parameter_style
125
123
 
126
- # Check if any detected style is not supported
127
124
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
128
125
  if unsupported_styles:
129
- # Convert to default style if we have unsupported styles
130
126
  target_style = self.default_parameter_style
131
127
  elif detected_styles:
132
- # Use the first detected style if all are supported
133
128
  # Prefer the first supported style found
134
129
  for style in detected_styles:
135
130
  if style in self.supported_parameter_styles:
@@ -138,32 +133,10 @@ class OracleSyncDriver(
138
133
 
139
134
  if statement.is_many:
140
135
  sql, params = statement.compile(placeholder_style=target_style)
141
- # Process parameters to convert tuples to lists for Oracle
142
136
  params = self._process_parameters(params)
143
- # Oracle doesn't like underscores in bind parameter names
144
- if isinstance(params, list) and params and isinstance(params[0], dict):
145
- # Fix the SQL and parameters
146
- for key in list(params[0].keys()):
147
- if key.startswith("_arg_"):
148
- # Remove leading underscore: _arg_0 -> arg0
149
- new_key = key[1:].replace("_", "")
150
- sql = sql.replace(f":{key}", f":{new_key}")
151
- # Update all parameter sets
152
- for param_set in params:
153
- if isinstance(param_set, dict) and key in param_set:
154
- param_set[new_key] = param_set.pop(key)
155
137
  return self._execute_many(sql, params, connection=connection, **kwargs)
156
138
 
157
139
  sql, params = statement.compile(placeholder_style=target_style)
158
- # Oracle doesn't like underscores in bind parameter names
159
- if isinstance(params, dict):
160
- # Fix the SQL and parameters
161
- for key in list(params.keys()):
162
- if key.startswith("_arg_"):
163
- # Remove leading underscore: _arg_0 -> arg0
164
- new_key = key[1:].replace("_", "")
165
- sql = sql.replace(f":{key}", f":{new_key}")
166
- params[new_key] = params.pop(key)
167
140
  return self._execute(sql, params, statement, connection=connection, **kwargs)
168
141
 
169
142
  def _execute(
@@ -173,65 +146,130 @@ class OracleSyncDriver(
173
146
  statement: SQL,
174
147
  connection: Optional[OracleSyncConnection] = None,
175
148
  **kwargs: Any,
176
- ) -> Union[SelectResultDict, DMLResultDict]:
149
+ ) -> SQLResult[RowT]:
150
+ # Use provided connection or driver's default connection
177
151
  conn = self._connection(connection)
178
- with self._get_cursor(conn) as cursor:
179
- # Process parameters to extract values from TypedParameter objects
180
- processed_params = self._process_parameters(parameters) if parameters else []
181
- cursor.execute(sql, processed_params)
182
-
183
- if self.returns_rows(statement.expression):
184
- fetched_data = cursor.fetchall()
185
- column_names = [col[0] for col in cursor.description or []]
186
- return {"data": fetched_data, "column_names": column_names, "rows_affected": cursor.rowcount}
187
152
 
188
- return {"rows_affected": cursor.rowcount, "status_message": "OK"}
153
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
154
+ # Oracle requires special parameter handling
155
+ processed_params = self._process_parameters(parameters) if parameters is not None else []
156
+
157
+ with self._get_cursor(txn_conn) as cursor:
158
+ cursor.execute(sql, processed_params)
159
+
160
+ if self.returns_rows(statement.expression):
161
+ fetched_data = cursor.fetchall()
162
+ column_names = [col[0] for col in cursor.description or []]
163
+
164
+ # Convert to dict if default_row_type is dict
165
+ if self.default_row_type == DictRow or issubclass(self.default_row_type, dict):
166
+ data = cast("list[RowT]", [dict(zip(column_names, row)) for row in fetched_data])
167
+ else:
168
+ data = cast("list[RowT]", fetched_data)
169
+
170
+ return SQLResult(
171
+ statement=statement,
172
+ data=data,
173
+ column_names=column_names,
174
+ rows_affected=cursor.rowcount,
175
+ operation_type="SELECT",
176
+ )
177
+
178
+ return SQLResult(
179
+ statement=statement,
180
+ data=[],
181
+ rows_affected=cursor.rowcount,
182
+ operation_type=self._determine_operation_type(statement),
183
+ metadata={"status_message": "OK"},
184
+ )
189
185
 
190
186
  def _execute_many(
191
187
  self, sql: str, param_list: Any, connection: Optional[OracleSyncConnection] = None, **kwargs: Any
192
- ) -> DMLResultDict:
188
+ ) -> SQLResult[RowT]:
189
+ # Use provided connection or driver's default connection
193
190
  conn = self._connection(connection)
194
- with self._get_cursor(conn) as cursor:
195
- # Handle None or empty param_list
196
- if param_list is None:
197
- param_list = []
198
- # Ensure param_list is a list of parameter sets
199
- elif param_list and not isinstance(param_list, list):
191
+
192
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
193
+ # Normalize parameter list using consolidated utility
194
+ normalized_param_list = normalize_parameter_sequence(param_list)
195
+
196
+ # Process parameters for Oracle
197
+ if normalized_param_list is None:
198
+ processed_param_list = []
199
+ elif normalized_param_list and not isinstance(normalized_param_list, list):
200
200
  # Single parameter set, wrap it
201
- param_list = [param_list]
202
- elif param_list and not isinstance(param_list[0], (list, tuple, dict)):
201
+ processed_param_list = [normalized_param_list]
202
+ elif normalized_param_list and not isinstance(normalized_param_list[0], (list, tuple, dict)):
203
203
  # Already a flat list, likely from incorrect usage
204
- param_list = [param_list]
204
+ processed_param_list = [normalized_param_list]
205
+ else:
206
+ processed_param_list = normalized_param_list
207
+
205
208
  # Parameters have already been processed in _execute_statement
206
- cursor.executemany(sql, param_list)
207
- return {"rows_affected": cursor.rowcount, "status_message": "OK"}
209
+ with self._get_cursor(txn_conn) as cursor:
210
+ cursor.executemany(sql, processed_param_list or [])
211
+ return SQLResult(
212
+ statement=SQL(sql, _dialect=self.dialect),
213
+ data=[],
214
+ rows_affected=cursor.rowcount,
215
+ operation_type="EXECUTE",
216
+ metadata={"status_message": "OK"},
217
+ )
208
218
 
209
219
  def _execute_script(
210
220
  self, script: str, connection: Optional[OracleSyncConnection] = None, **kwargs: Any
211
- ) -> ScriptResultDict:
221
+ ) -> SQLResult[RowT]:
222
+ # Use provided connection or driver's default connection
212
223
  conn = self._connection(connection)
213
- statements = self._split_script_statements(script, strip_trailing_semicolon=True)
214
- with self._get_cursor(conn) as cursor:
215
- for statement in statements:
216
- if statement and statement.strip():
217
- cursor.execute(statement.strip())
218
224
 
219
- return {"statements_executed": len(statements), "status_message": "SCRIPT EXECUTED"}
225
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
226
+ statements = self._split_script_statements(script, strip_trailing_semicolon=True)
227
+ with self._get_cursor(txn_conn) as cursor:
228
+ for statement in statements:
229
+ if statement and statement.strip():
230
+ cursor.execute(statement.strip())
231
+
232
+ return SQLResult(
233
+ statement=SQL(script, _dialect=self.dialect).as_script(),
234
+ data=[],
235
+ rows_affected=0,
236
+ operation_type="SCRIPT",
237
+ metadata={"status_message": "SCRIPT EXECUTED"},
238
+ total_statements=len(statements),
239
+ successful_statements=len(statements),
240
+ )
220
241
 
221
242
  def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
222
243
  self._ensure_pyarrow_installed()
223
244
  conn = self._connection(connection)
224
245
 
225
- # Get SQL and parameters using compile to ensure they match
226
- # For fetch_arrow_table, we need to use POSITIONAL_COLON style since the SQL has :1 placeholders
227
- sql_str, params = sql.compile(placeholder_style=ParameterStyle.POSITIONAL_COLON)
228
- if params is None:
229
- params = []
246
+ # Use the exact same parameter style detection logic as _execute_statement
247
+ detected_styles = set()
248
+ sql_str = sql.to_sql(placeholder_style=None) # Get raw SQL
249
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
250
+ param_infos = validator.extract_parameters(sql_str)
251
+ if param_infos:
252
+ detected_styles = {p.style for p in param_infos}
230
253
 
231
- # Process parameters to extract values from TypedParameter objects
232
- processed_params = self._process_parameters(params) if params else []
254
+ target_style = self.default_parameter_style
255
+
256
+ unsupported_styles = detected_styles - set(self.supported_parameter_styles)
257
+ if unsupported_styles:
258
+ target_style = self.default_parameter_style
259
+ elif detected_styles:
260
+ # Prefer the first supported style found
261
+ for style in detected_styles:
262
+ if style in self.supported_parameter_styles:
263
+ target_style = style
264
+ break
265
+
266
+ sql_str, params = sql.compile(placeholder_style=target_style)
267
+ processed_params = self._process_parameters(params) if params is not None else []
268
+
269
+ # Use proper transaction management like other methods
270
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
271
+ oracle_df = txn_conn.fetch_df_all(sql_str, processed_params)
233
272
 
234
- oracle_df = conn.fetch_df_all(sql_str, processed_params)
235
273
  from pyarrow.interchange.from_dataframe import from_dataframe
236
274
 
237
275
  arrow_table = from_dataframe(oracle_df)
@@ -242,7 +280,8 @@ class OracleSyncDriver(
242
280
  self._ensure_pyarrow_installed()
243
281
  conn = self._connection(None)
244
282
 
245
- with self._get_cursor(conn) as cursor:
283
+ # Use proper transaction management like other methods
284
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
246
285
  if mode == "replace":
247
286
  cursor.execute(f"TRUNCATE TABLE {table_name}")
248
287
  elif mode == "create":
@@ -260,57 +299,9 @@ class OracleSyncDriver(
260
299
  cursor.executemany(sql, data_for_ingest)
261
300
  return cursor.rowcount
262
301
 
263
- def _wrap_select_result(
264
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
265
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
266
- fetched_tuples = result.get("data", [])
267
- column_names = result.get("column_names", [])
268
-
269
- if not fetched_tuples:
270
- return SQLResult[RowT](statement=statement, data=[], column_names=column_names, operation_type="SELECT")
271
-
272
- rows_as_dicts: list[dict[str, Any]] = [dict(zip(column_names, row_tuple)) for row_tuple in fetched_tuples]
273
-
274
- if schema_type:
275
- converted_data = self.to_schema(rows_as_dicts, schema_type=schema_type)
276
- return SQLResult[ModelDTOT](
277
- statement=statement, data=list(converted_data), column_names=column_names, operation_type="SELECT"
278
- )
279
-
280
- return SQLResult[RowT](
281
- statement=statement, data=rows_as_dicts, column_names=column_names, operation_type="SELECT"
282
- )
283
-
284
- def _wrap_execute_result(
285
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
286
- ) -> SQLResult[RowT]:
287
- operation_type = "UNKNOWN"
288
- if statement.expression:
289
- operation_type = str(statement.expression.key).upper()
290
-
291
- if "statements_executed" in result:
292
- script_result = cast("ScriptResultDict", result)
293
- return SQLResult[RowT](
294
- statement=statement,
295
- data=[],
296
- rows_affected=0,
297
- operation_type="SCRIPT",
298
- metadata={
299
- "status_message": script_result.get("status_message", ""),
300
- "statements_executed": script_result.get("statements_executed", -1),
301
- },
302
- )
303
-
304
- dml_result = cast("DMLResultDict", result)
305
- rows_affected = dml_result.get("rows_affected", -1)
306
- status_message = dml_result.get("status_message", "")
307
- return SQLResult[RowT](
308
- statement=statement,
309
- data=[],
310
- rows_affected=rows_affected,
311
- operation_type=operation_type,
312
- metadata={"status_message": status_message},
313
- )
302
+ def _connection(self, connection: Optional[OracleSyncConnection] = None) -> OracleSyncConnection:
303
+ """Get the connection to use for the operation."""
304
+ return connection or self.connection
314
305
 
315
306
 
316
307
  class OracleAsyncDriver(
@@ -362,22 +353,24 @@ class OracleAsyncDriver(
362
353
 
363
354
  async def _execute_statement(
364
355
  self, statement: SQL, connection: Optional[OracleAsyncConnection] = None, **kwargs: Any
365
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
356
+ ) -> SQLResult[RowT]:
366
357
  if statement.is_script:
367
358
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
368
359
  return await self._execute_script(sql, connection=connection, **kwargs)
369
360
 
370
- # Determine if we need to convert parameter style
371
- detected_styles = {p.style for p in statement.parameter_info}
361
+ detected_styles = set()
362
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
363
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
364
+ param_infos = validator.extract_parameters(sql_str)
365
+ if param_infos:
366
+ detected_styles = {p.style for p in param_infos}
367
+
372
368
  target_style = self.default_parameter_style
373
369
 
374
- # Check if any detected style is not supported
375
370
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
376
371
  if unsupported_styles:
377
- # Convert to default style if we have unsupported styles
378
372
  target_style = self.default_parameter_style
379
373
  elif detected_styles:
380
- # Use the first detected style if all are supported
381
374
  # Prefer the first supported style found
382
375
  for style in detected_styles:
383
376
  if style in self.supported_parameter_styles:
@@ -386,32 +379,20 @@ class OracleAsyncDriver(
386
379
 
387
380
  if statement.is_many:
388
381
  sql, params = statement.compile(placeholder_style=target_style)
389
- # Process parameters to convert tuples to lists for Oracle
390
382
  params = self._process_parameters(params)
391
383
  # Oracle doesn't like underscores in bind parameter names
392
384
  if isinstance(params, list) and params and isinstance(params[0], dict):
393
385
  # Fix the SQL and parameters
394
386
  for key in list(params[0].keys()):
395
387
  if key.startswith("_arg_"):
396
- # Remove leading underscore: _arg_0 -> arg0
397
388
  new_key = key[1:].replace("_", "")
398
389
  sql = sql.replace(f":{key}", f":{new_key}")
399
- # Update all parameter sets
400
390
  for param_set in params:
401
391
  if isinstance(param_set, dict) and key in param_set:
402
392
  param_set[new_key] = param_set.pop(key)
403
393
  return await self._execute_many(sql, params, connection=connection, **kwargs)
404
394
 
405
395
  sql, params = statement.compile(placeholder_style=target_style)
406
- # Oracle doesn't like underscores in bind parameter names
407
- if isinstance(params, dict):
408
- # Fix the SQL and parameters
409
- for key in list(params.keys()):
410
- if key.startswith("_arg_"):
411
- # Remove leading underscore: _arg_0 -> arg0
412
- new_key = key[1:].replace("_", "")
413
- sql = sql.replace(f":{key}", f":{new_key}")
414
- params[new_key] = params.pop(key)
415
396
  return await self._execute(sql, params, statement, connection=connection, **kwargs)
416
397
 
417
398
  async def _execute(
@@ -421,77 +402,132 @@ class OracleAsyncDriver(
421
402
  statement: SQL,
422
403
  connection: Optional[OracleAsyncConnection] = None,
423
404
  **kwargs: Any,
424
- ) -> Union[SelectResultDict, DMLResultDict]:
405
+ ) -> SQLResult[RowT]:
406
+ # Use provided connection or driver's default connection
425
407
  conn = self._connection(connection)
426
- async with self._get_cursor(conn) as cursor:
427
- if parameters is None:
428
- await cursor.execute(sql)
429
- else:
430
- # Process parameters to extract values from TypedParameter objects
431
- processed_params = self._process_parameters(parameters)
432
- await cursor.execute(sql, processed_params)
433
-
434
- # For SELECT statements, extract data while cursor is open
435
- if self.returns_rows(statement.expression):
436
- fetched_data = await cursor.fetchall()
437
- column_names = [col[0] for col in cursor.description or []]
438
- result: SelectResultDict = {
439
- "data": fetched_data,
440
- "column_names": column_names,
441
- "rows_affected": cursor.rowcount,
442
- }
443
- return result
444
- dml_result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
445
- return dml_result
408
+
409
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
410
+ # Oracle requires special parameter handling
411
+ processed_params = self._process_parameters(parameters) if parameters is not None else []
412
+
413
+ async with self._get_cursor(txn_conn) as cursor:
414
+ if parameters is None:
415
+ await cursor.execute(sql)
416
+ else:
417
+ await cursor.execute(sql, processed_params)
418
+
419
+ # For SELECT statements, extract data while cursor is open
420
+ if self.returns_rows(statement.expression):
421
+ fetched_data = await cursor.fetchall()
422
+ column_names = [col[0] for col in cursor.description or []]
423
+
424
+ # Convert to dict if default_row_type is dict
425
+ if self.default_row_type == DictRow or issubclass(self.default_row_type, dict):
426
+ data = cast("list[RowT]", [dict(zip(column_names, row)) for row in fetched_data])
427
+ else:
428
+ data = cast("list[RowT]", fetched_data)
429
+
430
+ return SQLResult(
431
+ statement=statement,
432
+ data=data,
433
+ column_names=column_names,
434
+ rows_affected=cursor.rowcount,
435
+ operation_type="SELECT",
436
+ )
437
+
438
+ return SQLResult(
439
+ statement=statement,
440
+ data=[],
441
+ rows_affected=cursor.rowcount,
442
+ operation_type=self._determine_operation_type(statement),
443
+ metadata={"status_message": "OK"},
444
+ )
446
445
 
447
446
  async def _execute_many(
448
447
  self, sql: str, param_list: Any, connection: Optional[OracleAsyncConnection] = None, **kwargs: Any
449
- ) -> DMLResultDict:
448
+ ) -> SQLResult[RowT]:
449
+ # Use provided connection or driver's default connection
450
450
  conn = self._connection(connection)
451
- async with self._get_cursor(conn) as cursor:
452
- # Handle None or empty param_list
453
- if param_list is None:
454
- param_list = []
455
- # Ensure param_list is a list of parameter sets
456
- elif param_list and not isinstance(param_list, list):
451
+
452
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
453
+ # Normalize parameter list using consolidated utility
454
+ normalized_param_list = normalize_parameter_sequence(param_list)
455
+
456
+ # Process parameters for Oracle
457
+ if normalized_param_list is None:
458
+ processed_param_list = []
459
+ elif normalized_param_list and not isinstance(normalized_param_list, list):
457
460
  # Single parameter set, wrap it
458
- param_list = [param_list]
459
- elif param_list and not isinstance(param_list[0], (list, tuple, dict)):
461
+ processed_param_list = [normalized_param_list]
462
+ elif normalized_param_list and not isinstance(normalized_param_list[0], (list, tuple, dict)):
460
463
  # Already a flat list, likely from incorrect usage
461
- param_list = [param_list]
464
+ processed_param_list = [normalized_param_list]
465
+ else:
466
+ processed_param_list = normalized_param_list
467
+
462
468
  # Parameters have already been processed in _execute_statement
463
- await cursor.executemany(sql, param_list)
464
- result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
465
- return result
469
+ async with self._get_cursor(txn_conn) as cursor:
470
+ await cursor.executemany(sql, processed_param_list or [])
471
+ return SQLResult(
472
+ statement=SQL(sql, _dialect=self.dialect),
473
+ data=[],
474
+ rows_affected=cursor.rowcount,
475
+ operation_type="EXECUTE",
476
+ metadata={"status_message": "OK"},
477
+ )
466
478
 
467
479
  async def _execute_script(
468
480
  self, script: str, connection: Optional[OracleAsyncConnection] = None, **kwargs: Any
469
- ) -> ScriptResultDict:
481
+ ) -> SQLResult[RowT]:
482
+ # Use provided connection or driver's default connection
470
483
  conn = self._connection(connection)
471
- # Oracle doesn't support multi-statement scripts in a single execute
472
- # The splitter now handles PL/SQL blocks correctly when strip_trailing_semicolon=True
473
- statements = self._split_script_statements(script, strip_trailing_semicolon=True)
474
484
 
475
- async with self._get_cursor(conn) as cursor:
476
- for statement in statements:
477
- if statement and statement.strip():
478
- await cursor.execute(statement.strip())
485
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
486
+ # Oracle doesn't support multi-statement scripts in a single execute
487
+ # The splitter now handles PL/SQL blocks correctly when strip_trailing_semicolon=True
488
+ statements = self._split_script_statements(script, strip_trailing_semicolon=True)
479
489
 
480
- result: ScriptResultDict = {"statements_executed": len(statements), "status_message": "SCRIPT EXECUTED"}
481
- return result
490
+ async with self._get_cursor(txn_conn) as cursor:
491
+ for statement in statements:
492
+ if statement and statement.strip():
493
+ await cursor.execute(statement.strip())
494
+
495
+ return SQLResult(
496
+ statement=SQL(script, _dialect=self.dialect).as_script(),
497
+ data=[],
498
+ rows_affected=0,
499
+ operation_type="SCRIPT",
500
+ metadata={"status_message": "SCRIPT EXECUTED"},
501
+ total_statements=len(statements),
502
+ successful_statements=len(statements),
503
+ )
482
504
 
483
505
  async def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
484
506
  self._ensure_pyarrow_installed()
485
507
  conn = self._connection(connection)
486
508
 
487
- # Get SQL and parameters using compile to ensure they match
488
- # For fetch_arrow_table, we need to use POSITIONAL_COLON style since the SQL has :1 placeholders
489
- sql_str, params = sql.compile(placeholder_style=ParameterStyle.POSITIONAL_COLON)
490
- if params is None:
491
- params = []
509
+ # Use the exact same parameter style detection logic as _execute_statement
510
+ detected_styles = set()
511
+ sql_str = sql.to_sql(placeholder_style=None) # Get raw SQL
512
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
513
+ param_infos = validator.extract_parameters(sql_str)
514
+ if param_infos:
515
+ detected_styles = {p.style for p in param_infos}
492
516
 
493
- # Process parameters to extract values from TypedParameter objects
494
- processed_params = self._process_parameters(params) if params else []
517
+ target_style = self.default_parameter_style
518
+
519
+ unsupported_styles = detected_styles - set(self.supported_parameter_styles)
520
+ if unsupported_styles:
521
+ target_style = self.default_parameter_style
522
+ elif detected_styles:
523
+ # Prefer the first supported style found
524
+ for style in detected_styles:
525
+ if style in self.supported_parameter_styles:
526
+ target_style = style
527
+ break
528
+
529
+ sql_str, params = sql.compile(placeholder_style=target_style)
530
+ processed_params = self._process_parameters(params) if params is not None else []
495
531
 
496
532
  oracle_df = await conn.fetch_df_all(sql_str, processed_params)
497
533
  from pyarrow.interchange.from_dataframe import from_dataframe
@@ -504,7 +540,8 @@ class OracleAsyncDriver(
504
540
  self._ensure_pyarrow_installed()
505
541
  conn = self._connection(None)
506
542
 
507
- async with self._get_cursor(conn) as cursor:
543
+ # Use proper transaction management like other methods
544
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
508
545
  if mode == "replace":
509
546
  await cursor.execute(f"TRUNCATE TABLE {table_name}")
510
547
  elif mode == "create":
@@ -522,60 +559,6 @@ class OracleAsyncDriver(
522
559
  await cursor.executemany(sql, data_for_ingest)
523
560
  return cursor.rowcount
524
561
 
525
- async def _wrap_select_result(
526
- self,
527
- statement: SQL,
528
- result: SelectResultDict,
529
- schema_type: Optional[type[ModelDTOT]] = None,
530
- **kwargs: Any, # pyright: ignore[reportUnusedParameter]
531
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
532
- fetched_tuples = result["data"]
533
- column_names = result["column_names"]
534
-
535
- if not fetched_tuples:
536
- return SQLResult[RowT](statement=statement, data=[], column_names=column_names, operation_type="SELECT")
537
-
538
- rows_as_dicts: list[dict[str, Any]] = [dict(zip(column_names, row_tuple)) for row_tuple in fetched_tuples]
539
-
540
- if schema_type:
541
- converted_data = self.to_schema(rows_as_dicts, schema_type=schema_type)
542
- return SQLResult[ModelDTOT](
543
- statement=statement, data=list(converted_data), column_names=column_names, operation_type="SELECT"
544
- )
545
- return SQLResult[RowT](
546
- statement=statement, data=rows_as_dicts, column_names=column_names, operation_type="SELECT"
547
- )
548
-
549
- async def _wrap_execute_result(
550
- self,
551
- statement: SQL,
552
- result: Union[DMLResultDict, ScriptResultDict],
553
- **kwargs: Any, # pyright: ignore[reportUnusedParameter]
554
- ) -> SQLResult[RowT]:
555
- operation_type = "UNKNOWN"
556
- if statement.expression:
557
- operation_type = str(statement.expression.key).upper()
558
-
559
- if "statements_executed" in result:
560
- script_result = cast("ScriptResultDict", result)
561
- return SQLResult[RowT](
562
- statement=statement,
563
- data=[],
564
- rows_affected=0,
565
- operation_type="SCRIPT",
566
- metadata={
567
- "status_message": script_result.get("status_message", ""),
568
- "statements_executed": script_result.get("statements_executed", -1),
569
- },
570
- )
571
-
572
- dml_result = cast("DMLResultDict", result)
573
- rows_affected = dml_result.get("rows_affected", -1)
574
- status_message = dml_result.get("status_message", "")
575
- return SQLResult[RowT](
576
- statement=statement,
577
- data=[],
578
- rows_affected=rows_affected,
579
- operation_type=operation_type,
580
- metadata={"status_message": status_message},
581
- )
562
+ def _connection(self, connection: Optional[OracleAsyncConnection] = None) -> OracleAsyncConnection:
563
+ """Get the connection to use for the operation."""
564
+ return connection or self.connection