sqlspec-0.12.1-py3-none-any.whl → sqlspec-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the package's registry page for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -2,12 +2,14 @@ import contextlib
2
2
  import uuid
3
3
  from collections.abc import Generator
4
4
  from contextlib import contextmanager
5
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
6
7
 
7
8
  from duckdb import DuckDBPyConnection
8
9
  from sqlglot import exp
9
10
 
10
11
  from sqlspec.driver import SyncDriverAdapterProtocol
12
+ from sqlspec.driver.connection import managed_transaction_sync
11
13
  from sqlspec.driver.mixins import (
12
14
  SQLTranslatorMixin,
13
15
  SyncPipelinedExecutionMixin,
@@ -15,10 +17,11 @@ from sqlspec.driver.mixins import (
15
17
  ToSchemaMixin,
16
18
  TypeCoercionMixin,
17
19
  )
20
+ from sqlspec.driver.parameters import normalize_parameter_sequence
18
21
  from sqlspec.statement.parameters import ParameterStyle
19
- from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
22
+ from sqlspec.statement.result import ArrowResult, SQLResult
20
23
  from sqlspec.statement.sql import SQL, SQLConfig
21
- from sqlspec.typing import ArrowTable, DictRow, ModelDTOT, RowT
24
+ from sqlspec.typing import ArrowTable, DictRow, RowT
22
25
  from sqlspec.utils.logging import get_logger
23
26
 
24
27
  if TYPE_CHECKING:
@@ -81,136 +84,129 @@ class DuckDBDriver(
81
84
 
82
85
  def _execute_statement(
83
86
  self, statement: SQL, connection: Optional["DuckDBConnection"] = None, **kwargs: Any
84
- ) -> "Union[SelectResultDict, DMLResultDict, ScriptResultDict]":
87
+ ) -> SQLResult[RowT]:
85
88
  if statement.is_script:
86
89
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
87
90
  return self._execute_script(sql, connection=connection, **kwargs)
88
91
 
92
+ sql, params = statement.compile(placeholder_style=self.default_parameter_style)
93
+ params = self._process_parameters(params)
94
+
89
95
  if statement.is_many:
90
- sql, params = statement.compile(placeholder_style=self.default_parameter_style)
91
- params = self._process_parameters(params)
92
96
  return self._execute_many(sql, params, connection=connection, **kwargs)
93
97
 
94
- sql, params = statement.compile(placeholder_style=self.default_parameter_style)
95
- params = self._process_parameters(params)
96
98
  return self._execute(sql, params, statement, connection=connection, **kwargs)
97
99
 
98
100
  def _execute(
99
101
  self, sql: str, parameters: Any, statement: SQL, connection: Optional["DuckDBConnection"] = None, **kwargs: Any
100
- ) -> "Union[SelectResultDict, DMLResultDict]":
101
- conn = self._connection(connection)
102
+ ) -> SQLResult[RowT]:
103
+ # Use provided connection or driver's default connection
104
+ conn = connection if connection is not None else self._connection(None)
105
+
106
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
107
+ # Normalize parameters using consolidated utility
108
+ normalized_params = normalize_parameter_sequence(parameters)
109
+ final_params = normalized_params or []
110
+
111
+ if self.returns_rows(statement.expression):
112
+ result = txn_conn.execute(sql, final_params)
113
+ fetched_data = result.fetchall()
114
+ column_names = [col[0] for col in result.description or []]
115
+
116
+ if fetched_data and isinstance(fetched_data[0], tuple):
117
+ dict_data = [dict(zip(column_names, row)) for row in fetched_data]
118
+ else:
119
+ dict_data = fetched_data
120
+
121
+ return SQLResult[RowT](
122
+ statement=statement,
123
+ data=dict_data, # type: ignore[arg-type]
124
+ column_names=column_names,
125
+ rows_affected=len(dict_data),
126
+ operation_type="SELECT",
127
+ )
102
128
 
103
- if self.returns_rows(statement.expression):
104
- result = conn.execute(sql, parameters or [])
105
- fetched_data = result.fetchall()
106
- column_names = [col[0] for col in result.description or []]
107
- return {"data": fetched_data, "column_names": column_names, "rows_affected": len(fetched_data)}
108
-
109
- with self._get_cursor(conn) as cursor:
110
- cursor.execute(sql, parameters or [])
111
- # DuckDB returns -1 for rowcount on DML operations
112
- # However, fetchone() returns the actual affected row count as (count,)
113
- rows_affected = cursor.rowcount
114
- if rows_affected < 0:
115
- try:
116
- # Get actual affected row count from fetchone()
117
- fetch_result = cursor.fetchone()
118
- if fetch_result and isinstance(fetch_result, (tuple, list)) and len(fetch_result) > 0:
119
- rows_affected = fetch_result[0]
120
- else:
121
- rows_affected = 0
122
- except Exception:
123
- # Fallback to 1 if fetchone fails
124
- rows_affected = 1
125
- return {"rows_affected": rows_affected}
129
+ with self._get_cursor(txn_conn) as cursor:
130
+ cursor.execute(sql, final_params)
131
+ # DuckDB returns -1 for rowcount on DML operations
132
+ # However, fetchone() returns the actual affected row count as (count,)
133
+ rows_affected = cursor.rowcount
134
+ if rows_affected < 0:
135
+ try:
136
+ fetch_result = cursor.fetchone()
137
+ if fetch_result and isinstance(fetch_result, (tuple, list)) and len(fetch_result) > 0:
138
+ rows_affected = fetch_result[0]
139
+ else:
140
+ rows_affected = 0
141
+ except Exception:
142
+ rows_affected = 1
143
+
144
+ return SQLResult(
145
+ statement=statement,
146
+ data=[],
147
+ rows_affected=rows_affected,
148
+ operation_type=self._determine_operation_type(statement),
149
+ metadata={"status_message": "OK"},
150
+ )
126
151
 
127
152
  def _execute_many(
128
153
  self, sql: str, param_list: Any, connection: Optional["DuckDBConnection"] = None, **kwargs: Any
129
- ) -> "DMLResultDict":
130
- conn = self._connection(connection)
131
- param_list = param_list or []
132
-
133
- # DuckDB throws an error if executemany is called with empty parameter list
134
- if not param_list:
135
- return {"rows_affected": 0}
136
- with self._get_cursor(conn) as cursor:
137
- cursor.executemany(sql, param_list)
138
- # DuckDB returns -1 for rowcount on DML operations
139
- # For executemany, fetchone() only returns the count from the last operation,
140
- # so use parameter list length as the most accurate estimate
141
- rows_affected = cursor.rowcount if cursor.rowcount >= 0 else len(param_list)
142
- return {"rows_affected": rows_affected}
154
+ ) -> SQLResult[RowT]:
155
+ # Use provided connection or driver's default connection
156
+ conn = connection if connection is not None else self._connection(None)
157
+
158
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
159
+ # Normalize parameter list using consolidated utility
160
+ normalized_param_list = normalize_parameter_sequence(param_list)
161
+ final_param_list = normalized_param_list or []
162
+
163
+ # DuckDB throws an error if executemany is called with empty parameter list
164
+ if not final_param_list:
165
+ return SQLResult(
166
+ statement=SQL(sql, _dialect=self.dialect),
167
+ data=[],
168
+ rows_affected=0,
169
+ operation_type="EXECUTE",
170
+ metadata={"status_message": "OK"},
171
+ )
172
+
173
+ with self._get_cursor(txn_conn) as cursor:
174
+ cursor.executemany(sql, final_param_list)
175
+ # DuckDB returns -1 for rowcount on DML operations
176
+ # For executemany, fetchone() only returns the count from the last operation,
177
+ # so use parameter list length as the most accurate estimate
178
+ rows_affected = cursor.rowcount if cursor.rowcount >= 0 else len(final_param_list)
179
+ return SQLResult(
180
+ statement=SQL(sql, _dialect=self.dialect),
181
+ data=[],
182
+ rows_affected=rows_affected,
183
+ operation_type="EXECUTE",
184
+ metadata={"status_message": "OK"},
185
+ )
143
186
 
144
187
  def _execute_script(
145
188
  self, script: str, connection: Optional["DuckDBConnection"] = None, **kwargs: Any
146
- ) -> "ScriptResultDict":
147
- conn = self._connection(connection)
148
- with self._get_cursor(conn) as cursor:
149
- cursor.execute(script)
150
-
151
- return {
152
- "statements_executed": -1,
153
- "status_message": "Script executed successfully.",
154
- "description": "The script was sent to the database.",
155
- }
156
-
157
- def _wrap_select_result(
158
- self, statement: SQL, result: "SelectResultDict", schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
159
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
160
- fetched_tuples = result["data"]
161
- column_names = result["column_names"]
162
- rows_affected = result["rows_affected"]
163
-
164
- rows_as_dicts: list[dict[str, Any]] = [dict(zip(column_names, row)) for row in fetched_tuples]
165
-
166
- logger.debug("Query returned %d rows", len(rows_as_dicts))
167
-
168
- if schema_type:
169
- converted_data = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
170
- return SQLResult[ModelDTOT](
171
- statement=statement,
172
- data=list(converted_data),
173
- column_names=column_names,
174
- rows_affected=rows_affected,
175
- operation_type="SELECT",
176
- )
189
+ ) -> SQLResult[RowT]:
190
+ # Use provided connection or driver's default connection
191
+ conn = connection if connection is not None else self._connection(None)
177
192
 
178
- return SQLResult[RowT](
179
- statement=statement,
180
- data=rows_as_dicts,
181
- column_names=column_names,
182
- rows_affected=rows_affected,
183
- operation_type="SELECT",
184
- )
193
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
194
+ with self._get_cursor(txn_conn) as cursor:
195
+ cursor.execute(script)
185
196
 
186
- def _wrap_execute_result(
187
- self, statement: SQL, result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
188
- ) -> SQLResult[RowT]:
189
- operation_type = "UNKNOWN"
190
- if statement.expression:
191
- operation_type = str(statement.expression.key).upper()
192
-
193
- if "statements_executed" in result:
194
- script_result = cast("ScriptResultDict", result)
195
- return SQLResult[RowT](
196
- statement=statement,
197
+ return SQLResult(
198
+ statement=SQL(script, _dialect=self.dialect).as_script(),
197
199
  data=[],
198
200
  rows_affected=0,
199
- operation_type=operation_type or "SCRIPT",
200
- metadata={"status_message": script_result.get("status_message", "")},
201
+ operation_type="SCRIPT",
202
+ metadata={
203
+ "status_message": "Script executed successfully.",
204
+ "description": "The script was sent to the database.",
205
+ },
206
+ total_statements=-1,
207
+ successful_statements=-1,
201
208
  )
202
209
 
203
- dml_result = cast("DMLResultDict", result)
204
- rows_affected = dml_result.get("rows_affected", -1)
205
- status_message = dml_result.get("status_message", "")
206
- return SQLResult[RowT](
207
- statement=statement,
208
- data=[],
209
- rows_affected=rows_affected,
210
- operation_type=operation_type,
211
- metadata={"status_message": status_message},
212
- )
213
-
214
210
  # ============================================================================
215
211
  # DuckDB Native Arrow Support
216
212
  # ============================================================================
@@ -251,7 +247,7 @@ class DuckDBDriver(
251
247
  return True
252
248
  return False
253
249
 
254
- def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
250
+ def _export_native(self, query: str, destination_uri: Union[str, Path], format: str, **options: Any) -> int:
255
251
  conn = self._connection(None)
256
252
  copy_options: list[str] = []
257
253
 
@@ -283,19 +279,21 @@ class DuckDBDriver(
283
279
  raise ValueError(msg)
284
280
 
285
281
  options_str = f"({', '.join(copy_options)})" if copy_options else ""
286
- copy_sql = f"COPY ({query}) TO '{destination_uri}' {options_str}"
282
+ copy_sql = f"COPY ({query}) TO '{destination_uri!s}' {options_str}"
287
283
  result_rel = conn.execute(copy_sql)
288
284
  result = result_rel.fetchone() if result_rel else None
289
285
  return result[0] if result else 0
290
286
 
291
- def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
287
+ def _import_native(
288
+ self, source_uri: Union[str, Path], table_name: str, format: str, mode: str, **options: Any
289
+ ) -> int:
292
290
  conn = self._connection(None)
293
291
  if format == "parquet":
294
- read_func = f"read_parquet('{source_uri}')"
292
+ read_func = f"read_parquet('{source_uri!s}')"
295
293
  elif format == "csv":
296
- read_func = f"read_csv_auto('{source_uri}')"
294
+ read_func = f"read_csv_auto('{source_uri!s}')"
297
295
  elif format == "json":
298
- read_func = f"read_json_auto('{source_uri}')"
296
+ read_func = f"read_json_auto('{source_uri!s}')"
299
297
  else:
300
298
  msg = f"Unsupported format for DuckDB native import: {format}"
301
299
  raise ValueError(msg)
@@ -320,16 +318,16 @@ class DuckDBDriver(
320
318
  return int(count_result[0]) if count_result else 0
321
319
 
322
320
  def _read_parquet_native(
323
- self, source_uri: str, columns: Optional[list[str]] = None, **options: Any
321
+ self, source_uri: Union[str, Path], columns: Optional[list[str]] = None, **options: Any
324
322
  ) -> "SQLResult[dict[str, Any]]":
325
323
  conn = self._connection(None)
326
324
  if isinstance(source_uri, list):
327
325
  file_list = "[" + ", ".join(f"'{f}'" for f in source_uri) + "]"
328
326
  read_func = f"read_parquet({file_list})"
329
- elif "*" in source_uri or "?" in source_uri:
330
- read_func = f"read_parquet('{source_uri}')"
327
+ elif "*" in str(source_uri) or "?" in str(source_uri):
328
+ read_func = f"read_parquet('{source_uri!s}')"
331
329
  else:
332
- read_func = f"read_parquet('{source_uri}')"
330
+ read_func = f"read_parquet('{source_uri!s}')"
333
331
 
334
332
  column_list = ", ".join(columns) if columns else "*"
335
333
  query = f"SELECT {column_list} FROM {read_func}"
@@ -350,10 +348,16 @@ class DuckDBDriver(
350
348
  rows = [{col: arrow_dict[col][i] for col in column_names} for i in range(num_rows)]
351
349
 
352
350
  return SQLResult[dict[str, Any]](
353
- statement=SQL(query), data=rows, column_names=column_names, rows_affected=num_rows, operation_type="SELECT"
351
+ statement=SQL(query, _dialect=self.dialect),
352
+ data=rows,
353
+ column_names=column_names,
354
+ rows_affected=num_rows,
355
+ operation_type="SELECT",
354
356
  )
355
357
 
356
- def _write_parquet_native(self, data: Union[str, "ArrowTable"], destination_uri: str, **options: Any) -> None:
358
+ def _write_parquet_native(
359
+ self, data: Union[str, "ArrowTable"], destination_uri: Union[str, Path], **options: Any
360
+ ) -> None:
357
361
  conn = self._connection(None)
358
362
  copy_options: list[str] = ["FORMAT PARQUET"]
359
363
  if "compression" in options:
@@ -364,18 +368,22 @@ class DuckDBDriver(
364
368
  options_str = f"({', '.join(copy_options)})"
365
369
 
366
370
  if isinstance(data, str):
367
- copy_sql = f"COPY ({data}) TO '{destination_uri}' {options_str}"
371
+ copy_sql = f"COPY ({data}) TO '{destination_uri!s}' {options_str}"
368
372
  conn.execute(copy_sql)
369
373
  else:
370
374
  temp_name = f"_arrow_data_{uuid.uuid4().hex[:8]}"
371
375
  conn.register(temp_name, data)
372
376
  try:
373
- copy_sql = f"COPY {temp_name} TO '{destination_uri}' {options_str}"
377
+ copy_sql = f"COPY {temp_name} TO '{destination_uri!s}' {options_str}"
374
378
  conn.execute(copy_sql)
375
379
  finally:
376
380
  with contextlib.suppress(Exception):
377
381
  conn.unregister(temp_name)
378
382
 
383
+ def _connection(self, connection: Optional["DuckDBConnection"] = None) -> "DuckDBConnection":
384
+ """Get the connection to use for the operation."""
385
+ return connection or self.connection
386
+
379
387
  def _ingest_arrow_table(self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any) -> int:
380
388
  """DuckDB-optimized Arrow table ingestion using native registration."""
381
389
  self._ensure_pyarrow_installed()
@@ -404,7 +412,7 @@ class DuckDBDriver(
404
412
  msg = f"Unsupported mode: {mode}"
405
413
  raise ValueError(msg)
406
414
 
407
- result = self.execute(SQL(sql_expr.sql(dialect=self.dialect)))
415
+ result = self.execute(SQL(sql_expr.sql(dialect=self.dialect), _dialect=self.dialect))
408
416
  return result.rows_affected or table.num_rows
409
417
  finally:
410
418
  with contextlib.suppress(Exception):
@@ -4,7 +4,6 @@ import contextlib
4
4
  import logging
5
5
  from collections.abc import AsyncGenerator
6
6
  from contextlib import asynccontextmanager
7
- from dataclasses import replace
8
7
  from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
9
8
 
10
9
  import oracledb
@@ -293,15 +292,16 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool"
293
292
  An OracleSyncDriver instance.
294
293
  """
295
294
  with self.provide_connection(*args, **kwargs) as conn:
296
- # Create statement config with parameter style info if not already set
297
295
  statement_config = self.statement_config
296
+ # Inject parameter style info if not already set
298
297
  if statement_config.allowed_parameter_styles is None:
298
+ from dataclasses import replace
299
+
299
300
  statement_config = replace(
300
301
  statement_config,
301
302
  allowed_parameter_styles=self.supported_parameter_styles,
302
303
  target_parameter_style=self.preferred_parameter_style,
303
304
  )
304
-
305
305
  driver = self.driver_type(connection=conn, config=statement_config)
306
306
  yield driver
307
307
 
@@ -602,15 +602,16 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "AsyncConnect
602
602
  An OracleAsyncDriver instance.
603
603
  """
604
604
  async with self.provide_connection(*args, **kwargs) as conn:
605
- # Create statement config with parameter style info if not already set
606
605
  statement_config = self.statement_config
606
+ # Inject parameter style info if not already set
607
607
  if statement_config.allowed_parameter_styles is None:
608
+ from dataclasses import replace
609
+
608
610
  statement_config = replace(
609
611
  statement_config,
610
612
  allowed_parameter_styles=self.supported_parameter_styles,
611
613
  target_parameter_style=self.preferred_parameter_style,
612
614
  )
613
-
614
615
  driver = self.driver_type(connection=conn, config=statement_config)
615
616
  yield driver
616
617