sqlspec 0.13.1__py3-none-any.whl → 0.16.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (185)
  1. sqlspec/__init__.py +71 -8
  2. sqlspec/__main__.py +12 -0
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +930 -136
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +116 -285
  10. sqlspec/adapters/adbc/driver.py +462 -340
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +202 -150
  14. sqlspec/adapters/aiosqlite/driver.py +226 -247
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -199
  18. sqlspec/adapters/asyncmy/driver.py +257 -215
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +81 -214
  22. sqlspec/adapters/asyncpg/driver.py +284 -359
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -299
  26. sqlspec/adapters/bigquery/driver.py +474 -634
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +414 -397
  30. sqlspec/adapters/duckdb/driver.py +342 -393
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -458
  34. sqlspec/adapters/oracledb/driver.py +505 -531
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -307
  38. sqlspec/adapters/psqlpy/driver.py +504 -213
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -472
  42. sqlspec/adapters/psycopg/driver.py +704 -825
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +208 -142
  46. sqlspec/adapters/sqlite/driver.py +263 -278
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder/base.py → builder/_base.py} +184 -86
  50. sqlspec/{statement/builder/column.py → builder/_column.py} +97 -60
  51. sqlspec/{statement/builder/ddl.py → builder/_ddl.py} +61 -131
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +4 -10
  53. sqlspec/{statement/builder/delete.py → builder/_delete.py} +10 -30
  54. sqlspec/builder/_insert.py +421 -0
  55. sqlspec/builder/_merge.py +71 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +49 -26
  57. sqlspec/builder/_select.py +170 -0
  58. sqlspec/{statement/builder/update.py → builder/_update.py} +16 -20
  59. sqlspec/builder/mixins/__init__.py +55 -0
  60. sqlspec/builder/mixins/_cte_and_set_ops.py +222 -0
  61. sqlspec/{statement/builder/mixins/_delete_from.py → builder/mixins/_delete_operations.py} +8 -1
  62. sqlspec/builder/mixins/_insert_operations.py +244 -0
  63. sqlspec/{statement/builder/mixins/_join.py → builder/mixins/_join_operations.py} +45 -13
  64. sqlspec/{statement/builder/mixins/_merge_clauses.py → builder/mixins/_merge_operations.py} +188 -30
  65. sqlspec/builder/mixins/_order_limit_operations.py +135 -0
  66. sqlspec/builder/mixins/_pivot_operations.py +153 -0
  67. sqlspec/builder/mixins/_select_operations.py +604 -0
  68. sqlspec/builder/mixins/_update_operations.py +202 -0
  69. sqlspec/builder/mixins/_where_clause.py +644 -0
  70. sqlspec/cli.py +247 -0
  71. sqlspec/config.py +183 -138
  72. sqlspec/core/__init__.py +63 -0
  73. sqlspec/core/cache.py +871 -0
  74. sqlspec/core/compiler.py +417 -0
  75. sqlspec/core/filters.py +830 -0
  76. sqlspec/core/hashing.py +310 -0
  77. sqlspec/core/parameters.py +1237 -0
  78. sqlspec/core/result.py +677 -0
  79. sqlspec/{statement → core}/splitter.py +321 -191
  80. sqlspec/core/statement.py +676 -0
  81. sqlspec/driver/__init__.py +7 -10
  82. sqlspec/driver/_async.py +422 -163
  83. sqlspec/driver/_common.py +545 -287
  84. sqlspec/driver/_sync.py +426 -160
  85. sqlspec/driver/mixins/__init__.py +2 -13
  86. sqlspec/driver/mixins/_result_tools.py +193 -0
  87. sqlspec/driver/mixins/_sql_translator.py +65 -14
  88. sqlspec/exceptions.py +5 -252
  89. sqlspec/extensions/aiosql/adapter.py +93 -96
  90. sqlspec/extensions/litestar/__init__.py +2 -1
  91. sqlspec/extensions/litestar/cli.py +48 -0
  92. sqlspec/extensions/litestar/config.py +0 -1
  93. sqlspec/extensions/litestar/handlers.py +15 -26
  94. sqlspec/extensions/litestar/plugin.py +21 -16
  95. sqlspec/extensions/litestar/providers.py +17 -52
  96. sqlspec/loader.py +423 -104
  97. sqlspec/migrations/__init__.py +35 -0
  98. sqlspec/migrations/base.py +414 -0
  99. sqlspec/migrations/commands.py +443 -0
  100. sqlspec/migrations/loaders.py +402 -0
  101. sqlspec/migrations/runner.py +213 -0
  102. sqlspec/migrations/tracker.py +140 -0
  103. sqlspec/migrations/utils.py +129 -0
  104. sqlspec/protocols.py +51 -186
  105. sqlspec/storage/__init__.py +1 -1
  106. sqlspec/storage/backends/base.py +37 -40
  107. sqlspec/storage/backends/fsspec.py +136 -112
  108. sqlspec/storage/backends/obstore.py +138 -160
  109. sqlspec/storage/capabilities.py +5 -4
  110. sqlspec/storage/registry.py +57 -106
  111. sqlspec/typing.py +136 -115
  112. sqlspec/utils/__init__.py +2 -2
  113. sqlspec/utils/correlation.py +0 -3
  114. sqlspec/utils/deprecation.py +6 -6
  115. sqlspec/utils/fixtures.py +6 -6
  116. sqlspec/utils/logging.py +0 -2
  117. sqlspec/utils/module_loader.py +7 -12
  118. sqlspec/utils/singleton.py +0 -1
  119. sqlspec/utils/sync_tools.py +17 -38
  120. sqlspec/utils/text.py +12 -51
  121. sqlspec/utils/type_guards.py +482 -235
  122. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/METADATA +7 -2
  123. sqlspec-0.16.2.dist-info/RECORD +134 -0
  124. sqlspec-0.16.2.dist-info/entry_points.txt +2 -0
  125. sqlspec/driver/connection.py +0 -207
  126. sqlspec/driver/mixins/_csv_writer.py +0 -91
  127. sqlspec/driver/mixins/_pipeline.py +0 -512
  128. sqlspec/driver/mixins/_result_utils.py +0 -140
  129. sqlspec/driver/mixins/_storage.py +0 -926
  130. sqlspec/driver/mixins/_type_coercion.py +0 -130
  131. sqlspec/driver/parameters.py +0 -138
  132. sqlspec/service/__init__.py +0 -4
  133. sqlspec/service/_util.py +0 -147
  134. sqlspec/service/base.py +0 -1131
  135. sqlspec/service/pagination.py +0 -26
  136. sqlspec/statement/__init__.py +0 -21
  137. sqlspec/statement/builder/insert.py +0 -288
  138. sqlspec/statement/builder/merge.py +0 -95
  139. sqlspec/statement/builder/mixins/__init__.py +0 -65
  140. sqlspec/statement/builder/mixins/_aggregate_functions.py +0 -250
  141. sqlspec/statement/builder/mixins/_case_builder.py +0 -91
  142. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -90
  143. sqlspec/statement/builder/mixins/_from.py +0 -63
  144. sqlspec/statement/builder/mixins/_group_by.py +0 -118
  145. sqlspec/statement/builder/mixins/_having.py +0 -35
  146. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -47
  147. sqlspec/statement/builder/mixins/_insert_into.py +0 -36
  148. sqlspec/statement/builder/mixins/_insert_values.py +0 -67
  149. sqlspec/statement/builder/mixins/_limit_offset.py +0 -53
  150. sqlspec/statement/builder/mixins/_order_by.py +0 -46
  151. sqlspec/statement/builder/mixins/_pivot.py +0 -79
  152. sqlspec/statement/builder/mixins/_returning.py +0 -37
  153. sqlspec/statement/builder/mixins/_select_columns.py +0 -61
  154. sqlspec/statement/builder/mixins/_set_ops.py +0 -122
  155. sqlspec/statement/builder/mixins/_unpivot.py +0 -77
  156. sqlspec/statement/builder/mixins/_update_from.py +0 -55
  157. sqlspec/statement/builder/mixins/_update_set.py +0 -94
  158. sqlspec/statement/builder/mixins/_update_table.py +0 -29
  159. sqlspec/statement/builder/mixins/_where.py +0 -401
  160. sqlspec/statement/builder/mixins/_window_functions.py +0 -86
  161. sqlspec/statement/builder/select.py +0 -221
  162. sqlspec/statement/filters.py +0 -596
  163. sqlspec/statement/parameter_manager.py +0 -220
  164. sqlspec/statement/parameters.py +0 -867
  165. sqlspec/statement/pipelines/__init__.py +0 -210
  166. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  167. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  168. sqlspec/statement/pipelines/context.py +0 -115
  169. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  170. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  171. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  172. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  173. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  174. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  175. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  176. sqlspec/statement/pipelines/validators/_performance.py +0 -718
  177. sqlspec/statement/pipelines/validators/_security.py +0 -967
  178. sqlspec/statement/result.py +0 -435
  179. sqlspec/statement/sql.py +0 -1704
  180. sqlspec/statement/sql_compiler.py +0 -140
  181. sqlspec/utils/cached_property.py +0 -25
  182. sqlspec-0.13.1.dist-info/RECORD +0 -150
  183. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/WHEEL +0 -0
  184. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/licenses/LICENSE +0 -0
  185. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/licenses/NOTICE +0 -0
@@ -1,390 +1,512 @@
1
+ """ADBC driver implementation for Arrow Database Connectivity.
2
+
3
+ This module provides ADBC driver integration with support for:
4
+ - Multi-dialect database connections through ADBC
5
+ - Arrow-native data handling with type coercion
6
+ - Parameter style conversion for different database backends
7
+ - Transaction management with proper error handling
8
+ """
9
+
1
10
  import contextlib
2
- import logging
3
- from collections.abc import Iterator
4
- from contextlib import contextmanager
5
- from dataclasses import replace
6
- from decimal import Decimal
7
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
8
-
9
- from adbc_driver_manager.dbapi import Connection, Cursor
10
-
11
- from sqlspec.driver import SyncDriverAdapterProtocol
12
- from sqlspec.driver.connection import managed_transaction_sync
13
- from sqlspec.driver.mixins import (
14
- SQLTranslatorMixin,
15
- SyncPipelinedExecutionMixin,
16
- SyncStorageMixin,
17
- ToSchemaMixin,
18
- TypeCoercionMixin,
19
- )
20
- from sqlspec.driver.parameters import normalize_parameter_sequence
21
- from sqlspec.exceptions import wrap_exceptions
22
- from sqlspec.statement.parameters import ParameterStyle
23
- from sqlspec.statement.result import ArrowResult, SQLResult
24
- from sqlspec.statement.sql import SQL, SQLConfig
25
- from sqlspec.typing import DictRow, RowT
11
+ import datetime
12
+ import decimal
13
+ from typing import TYPE_CHECKING, Any, Optional, cast
14
+
15
+ from sqlglot import exp
16
+
17
+ from sqlspec.core.cache import get_cache_config
18
+ from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
19
+ from sqlspec.core.statement import SQL, StatementConfig
20
+ from sqlspec.driver import SyncDriverAdapterBase
21
+ from sqlspec.exceptions import MissingDependencyError, SQLParsingError, SQLSpecError
22
+ from sqlspec.utils.logging import get_logger
26
23
  from sqlspec.utils.serializers import to_json
27
24
 
28
25
  if TYPE_CHECKING:
29
- from sqlglot.dialects.dialect import DialectType
26
+ from contextlib import AbstractContextManager
27
+
28
+ from adbc_driver_manager.dbapi import Cursor
29
+
30
+ from sqlspec.adapters.adbc._types import AdbcConnection
31
+ from sqlspec.core.result import SQLResult
32
+ from sqlspec.driver import ExecutionResult
33
+
34
+ __all__ = ("AdbcCursor", "AdbcDriver", "AdbcExceptionHandler", "get_adbc_statement_config")
30
35
 
31
- __all__ = ("AdbcConnection", "AdbcDriver")
36
+ logger = get_logger("adapters.adbc")
32
37
 
33
- logger = logging.getLogger("sqlspec")
38
# Substring patterns matched against the lower-cased ADBC vendor/driver names.
# Insertion order matters: _get_dialect returns the first dialect whose pattern matches.
DIALECT_PATTERNS = dict(
    postgres=["postgres", "postgresql"],
    bigquery=["bigquery"],
    sqlite=["sqlite", "flight", "flightsql"],
    duckdb=["duckdb"],
    mysql=["mysql"],
    snowflake=["snowflake"],
)
34
46
 
35
- AdbcConnection = Connection
47
# Per dialect: (default placeholder style, list of supported placeholder styles).
# get_adbc_statement_config falls back to QMARK-only for dialects not listed here.
DIALECT_PARAMETER_STYLES = dict(
    postgres=(ParameterStyle.NUMERIC, [ParameterStyle.NUMERIC]),
    postgresql=(ParameterStyle.NUMERIC, [ParameterStyle.NUMERIC]),
    bigquery=(ParameterStyle.NAMED_AT, [ParameterStyle.NAMED_AT]),
    sqlite=(ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NAMED_COLON]),
    duckdb=(ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR]),
    mysql=(ParameterStyle.POSITIONAL_PYFORMAT, [ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT]),
    snowflake=(ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC]),
)
36
56
 
37
57
 
38
- class AdbcDriver(
39
- SyncDriverAdapterProtocol["AdbcConnection", RowT],
40
- SQLTranslatorMixin,
41
- TypeCoercionMixin,
42
- SyncStorageMixin,
43
- SyncPipelinedExecutionMixin,
44
- ToSchemaMixin,
45
- ):
46
- """ADBC Sync Driver Adapter with modern architecture.
58
+ def _adbc_ast_transformer(expression: Any, parameters: Any) -> tuple[Any, Any]:
59
+ """ADBC-specific AST transformer for NULL parameter handling.
47
60
 
48
- ADBC (Arrow Database Connectivity) provides a universal interface for connecting
49
- to multiple database systems with high-performance Arrow-native data transfer.
61
+ For PostgreSQL, this transformer replaces NULL parameter placeholders with NULL literals
62
+ in the AST to prevent Arrow from inferring 'na' types which cause binding errors.
50
63
 
51
- This driver provides:
52
- - Universal connectivity across database backends (PostgreSQL, SQLite, DuckDB, etc.)
53
- - High-performance Arrow data streaming and bulk operations
54
- - Intelligent dialect detection and parameter style handling
55
- - Seamless integration with cloud databases (BigQuery, Snowflake)
56
- - Driver manager abstraction for easy multi-database support
64
+ The transformer:
65
+ 1. Detects None parameters in the parameter list
66
+ 2. Replaces corresponding placeholders in the AST with NULL literals
67
+ 3. Removes the None parameters from the list
68
+ 4. Renumbers remaining placeholders to maintain correct mapping
69
+
70
+ Args:
71
+ expression: SQLGlot AST expression
72
+ parameters: Parameter values that may contain None
73
+
74
+ Returns:
75
+ Tuple of (modified_expression, cleaned_parameters)
57
76
  """
77
+ if not parameters:
78
+ return expression, parameters
79
+
80
+ # Detect NULL parameter positions
81
+ null_positions = set()
82
+ if isinstance(parameters, (list, tuple)):
83
+ for i, param in enumerate(parameters):
84
+ if param is None:
85
+ null_positions.add(i)
86
+ elif isinstance(parameters, dict):
87
+ for key, param in parameters.items():
88
+ if param is None:
89
+ try:
90
+ if isinstance(key, str) and key.lstrip("$").isdigit():
91
+ param_num = int(key.lstrip("$"))
92
+ null_positions.add(param_num - 1)
93
+ except ValueError:
94
+ pass
95
+
96
+ if not null_positions:
97
+ return expression, parameters
98
+
99
+ # Track position for QMARK-style placeholders
100
+ qmark_position = [0]
101
+
102
+ def transform_node(node: Any) -> Any:
103
+ """Transform parameter nodes to NULL literals and renumber remaining ones."""
104
+ # Handle QMARK-style placeholders (?, ?, ?)
105
+ if isinstance(node, exp.Placeholder) and (not hasattr(node, "this") or node.this is None):
106
+ current_pos = qmark_position[0]
107
+ qmark_position[0] += 1
108
+
109
+ if current_pos in null_positions:
110
+ return exp.Null()
111
+ # Don't renumber QMARK placeholders - they stay as ?
112
+ return node
113
+
114
+ # Handle PostgreSQL-style placeholders ($1, $2, etc.)
115
+ if isinstance(node, exp.Placeholder) and hasattr(node, "this") and node.this is not None:
116
+ try:
117
+ param_str = str(node.this).lstrip("$")
118
+ param_num = int(param_str)
119
+ param_index = param_num - 1 # Convert to 0-based
120
+
121
+ if param_index in null_positions:
122
+ return exp.Null()
123
+ # Renumber placeholder to account for removed NULLs
124
+ nulls_before = sum(1 for idx in null_positions if idx < param_index)
125
+ new_param_num = param_num - nulls_before
126
+ return exp.Placeholder(this=f"${new_param_num}")
127
+ except (ValueError, AttributeError):
128
+ pass
129
+
130
+ # Handle generic parameter nodes
131
+ if isinstance(node, exp.Parameter) and hasattr(node, "this"):
132
+ try:
133
+ param_str = str(node.this)
134
+ param_num = int(param_str)
135
+ param_index = param_num - 1 # Convert to 0-based
136
+
137
+ if param_index in null_positions:
138
+ return exp.Null()
139
+ # Renumber parameter to account for removed NULLs
140
+ nulls_before = sum(1 for idx in null_positions if idx < param_index)
141
+ new_param_num = param_num - nulls_before
142
+ return exp.Parameter(this=str(new_param_num))
143
+ except (ValueError, AttributeError):
144
+ pass
145
+
146
+ return node
147
+
148
+ # Transform the AST
149
+ modified_expression = expression.transform(transform_node)
150
+
151
+ # Remove NULL parameters from the parameter list
152
+ cleaned_params: Any
153
+ if isinstance(parameters, (list, tuple)):
154
+ cleaned_params = [p for i, p in enumerate(parameters) if i not in null_positions]
155
+ elif isinstance(parameters, dict):
156
+ cleaned_params_dict = {}
157
+ new_num = 1
158
+ for val in parameters.values():
159
+ if val is not None:
160
+ cleaned_params_dict[str(new_num)] = val
161
+ new_num += 1
162
+ cleaned_params = cleaned_params_dict
163
+ else:
164
+ cleaned_params = parameters
165
+
166
+ return modified_expression, cleaned_params
167
+
168
+
169
def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
    """Build the ADBC ``StatementConfig`` for ``detected_dialect``.

    Placeholder styles come from ``DIALECT_PARAMETER_STYLES``; unknown dialects fall
    back to QMARK only. The NULL-parameter AST transformer is attached only for
    ``"postgres"`` / ``"postgresql"``, and the ``"postgresql"`` alias is normalized to
    ``"postgres"`` for the config's dialect.
    """
    default_style, supported_styles = DIALECT_PARAMETER_STYLES.get(
        detected_dialect, (ParameterStyle.QMARK, [ParameterStyle.QMARK])
    )
    coercions = get_type_coercion_map(detected_dialect)
    if detected_dialect == "postgresql":
        sqlglot_dialect = "postgres"
    else:
        sqlglot_dialect = detected_dialect
    is_postgres = detected_dialect in {"postgres", "postgresql"}

    return StatementConfig(
        dialect=sqlglot_dialect,
        parameter_config=ParameterStyleConfig(
            default_parameter_style=default_style,
            supported_parameter_styles=set(supported_styles),
            default_execution_parameter_style=default_style,
            supported_execution_parameter_styles=set(supported_styles),
            type_coercion_map=coercions,
            has_native_list_expansion=True,
            needs_static_script_compilation=False,
            preserve_parameter_format=True,
            ast_transformer=_adbc_ast_transformer if is_postgres else None,
        ),
        enable_parsing=True,
        enable_validation=True,
        enable_caching=True,
        enable_parameter_type_wrapping=True,
    )
199
+
200
+
201
+ def _convert_array_for_postgres_adbc(value: Any) -> Any:
202
+ """Convert array values for PostgreSQL ADBC compatibility."""
203
+ if isinstance(value, tuple):
204
+ return list(value)
205
+ return value
206
+
207
+
208
def get_type_coercion_map(dialect: str) -> "dict[type, Any]":
    """Build the per-type coercions applied to parameter values for ADBC.

    Most types pass through unchanged. ``Decimal`` values are converted with ``float``,
    tuples and lists go through ``_convert_array_for_postgres_adbc``, and for
    ``"postgres"`` / ``"postgresql"`` dicts are serialized with ``to_json``
    (``None`` stays ``None``).

    NOTE(review): key order is kept exactly as written (``datetime`` before ``date``,
    ``bool`` before ``int``); consumers may match types in order — verify before
    reordering.
    """

    def passthrough(value: Any) -> Any:
        return value

    def dict_to_json(value: Any) -> Any:
        return None if value is None else to_json(value)

    coercions: dict[type, Any] = {
        datetime.datetime: passthrough,
        datetime.date: passthrough,
        datetime.time: passthrough,
        decimal.Decimal: float,
        bool: passthrough,
        int: passthrough,
        float: passthrough,
        str: passthrough,
        bytes: passthrough,
        tuple: _convert_array_for_postgres_adbc,
        list: _convert_array_for_postgres_adbc,
        dict: passthrough,
    }
    if dialect in ("postgres", "postgresql"):
        coercions[dict] = dict_to_json
    return coercions
229
+
230
+
231
class AdbcCursor:
    """Context manager that opens an ADBC cursor on enter and closes it on exit.

    Errors raised while closing the cursor are suppressed. ``__exit__`` returns None,
    so exceptions raised inside the ``with`` body still propagate.
    """

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "AdbcConnection") -> None:
        self.connection = connection
        self.cursor: Optional[Cursor] = None

    def __enter__(self) -> "Cursor":
        opened = self.connection.cursor()
        self.cursor = opened
        return opened

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        del exc_type, exc_val, exc_tb  # unused: body exceptions are never swallowed
        opened = self.cursor
        if opened is None:
            return
        with contextlib.suppress(Exception):
            opened.close()  # type: ignore[no-untyped-call]
249
+
58
250
 
59
- supports_native_arrow_import: ClassVar[bool] = True
60
- supports_native_arrow_export: ClassVar[bool] = True
61
- supports_native_parquet_export: ClassVar[bool] = False # Not implemented yet
62
- supports_native_parquet_import: ClassVar[bool] = True
63
- __slots__ = ("default_parameter_style", "dialect", "supported_parameter_styles")
251
class AdbcExceptionHandler:
    """Sync context manager that maps ADBC database exceptions to sqlspec exceptions.

    - sqlspec's own exceptions (``SQLSpecError`` and its subclasses) propagate
      unchanged, so their specific type is not lost to re-wrapping.
    - ADBC DB-API errors (integrity, programming, operational, database) are wrapped
      in ``SQLSpecError``; programming errors mentioning "syntax"/"parse" become
      ``SQLParsingError``.
    - Any other ``Exception`` is wrapped as well. ``BaseException`` subclasses that
      are not ``Exception`` (e.g. ``KeyboardInterrupt``) propagate untouched.
    """

    __slots__ = ()

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if exc_type is None:
            return
        if issubclass(exc_type, SQLSpecError):
            # Already a sqlspec error (e.g. raised while compiling the statement);
            # wrapping it again would turn it into a generic SQLSpecError.
            return

        try:
            from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
        except ImportError:
            # ADBC driver manager unavailable: fall through to generic wrapping.
            pass
        else:
            if issubclass(exc_type, IntegrityError):
                msg = f"ADBC integrity constraint violation: {exc_val}"
                raise SQLSpecError(msg) from exc_val
            if issubclass(exc_type, ProgrammingError):
                error_msg = str(exc_val).lower()
                if "syntax" in error_msg or "parse" in error_msg:
                    msg = f"ADBC SQL syntax error: {exc_val}"
                    raise SQLParsingError(msg) from exc_val
                msg = f"ADBC programming error: {exc_val}"
                raise SQLSpecError(msg) from exc_val
            if issubclass(exc_type, OperationalError):
                msg = f"ADBC operational error: {exc_val}"
                raise SQLSpecError(msg) from exc_val
            if issubclass(exc_type, DatabaseError):
                msg = f"ADBC database error: {exc_val}"
                raise SQLSpecError(msg) from exc_val

        if issubclass(exc_type, Exception):
            error_msg = str(exc_val).lower()
            if "parse" in error_msg or "syntax" in error_msg:
                msg = f"SQL parsing failed: {exc_val}"
                raise SQLParsingError(msg) from exc_val
            msg = f"Unexpected database operation error: {exc_val}"
            raise SQLSpecError(msg) from exc_val
296
+
297
+
298
+ class AdbcDriver(SyncDriverAdapterBase):
299
+ """ADBC driver for Arrow Database Connectivity.
300
+
301
+ Provides database connectivity through ADBC with support for:
302
+ - Multi-database dialect support with automatic detection
303
+ - Arrow-native data handling with type coercion
304
+ - Parameter style conversion for different backends
305
+ - Transaction management with proper error handling
306
+ """
307
+
308
+ __slots__ = ("_detected_dialect", "dialect")
64
309
 
65
310
def __init__(
    self,
    connection: "AdbcConnection",
    statement_config: "Optional[StatementConfig]" = None,
    driver_features: "Optional[dict[str, Any]]" = None,
) -> None:
    """Create the driver and detect the dialect from ``connection``.

    When ``statement_config`` is omitted, one is built for the detected dialect.
    Caching follows the global compiled-cache setting, and parsing/validation are
    enabled. ``self.dialect`` is taken from the effective statement config.
    """
    self._detected_dialect = self._get_dialect(connection)

    if statement_config is None:
        compiled_cache_enabled = get_cache_config().compiled_cache_enabled
        statement_config = get_adbc_statement_config(self._detected_dialect).replace(
            enable_caching=compiled_cache_enabled, enable_parsing=True, enable_validation=True
        )

    super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
    self.dialect = statement_config.dialect
105
328
 
106
329
@staticmethod
def _ensure_pyarrow_installed() -> None:
    """Verify that PyArrow is available for Arrow operations.

    Raises:
        MissingDependencyError: If ``sqlspec.typing.PYARROW_INSTALLED`` is false.
    """
    from sqlspec.typing import PYARROW_INSTALLED

    if PYARROW_INSTALLED:
        return
    raise MissingDependencyError(package="pyarrow", install_package="arrow")
112
336
 
113
- Returns:
114
- The database dialect.
115
- """
337
@staticmethod
def _get_dialect(connection: "AdbcConnection") -> str:
    """Detect the database dialect from ``connection.adbc_get_info()``.

    The lower-cased ``vendor_name`` and ``driver_name`` entries are checked against
    ``DIALECT_PATTERNS`` in order, and the first dialect with a matching substring is
    returned. If nothing matches or the lookup raises, a warning is logged and
    ``"postgres"`` is returned.
    """
    try:
        info = connection.adbc_get_info()
        vendor = info.get("vendor_name", "").lower()
        driver = info.get("driver_name", "").lower()

        for candidate, substrings in DIALECT_PATTERNS.items():
            for needle in substrings:
                if needle in vendor or needle in driver:
                    logger.debug("ADBC dialect detected: %s (from %s/%s)", candidate, vendor, driver)
                    return candidate
    except Exception as e:
        logger.debug("ADBC dialect detection failed: %s", e)

    logger.warning("Could not reliably determine ADBC dialect from driver info. Defaulting to 'postgres'.")
    return "postgres"
138
354
 
139
- @staticmethod
140
- def _get_parameter_style_for_dialect(dialect: str) -> ParameterStyle:
141
- """Get the parameter style for a given dialect."""
142
- dialect_style_map = {
143
- "postgres": ParameterStyle.NUMERIC,
144
- "postgresql": ParameterStyle.NUMERIC,
145
- "bigquery": ParameterStyle.NAMED_AT,
146
- "sqlite": ParameterStyle.QMARK,
147
- "duckdb": ParameterStyle.QMARK,
148
- "mysql": ParameterStyle.POSITIONAL_PYFORMAT,
149
- "snowflake": ParameterStyle.QMARK,
150
- }
151
- return dialect_style_map.get(dialect, ParameterStyle.QMARK)
355
+ def _handle_postgres_rollback(self, cursor: "Cursor") -> None:
356
+ """Execute rollback for PostgreSQL after transaction failure."""
357
+ if self.dialect == "postgres":
358
+ with contextlib.suppress(Exception):
359
+ cursor.execute("ROLLBACK")
360
+ logger.debug("PostgreSQL rollback executed after ADBC transaction failure")
152
361
 
153
- @staticmethod
154
- def _get_supported_parameter_styles_for_dialect(dialect: str) -> "tuple[ParameterStyle, ...]":
155
- """Get the supported parameter styles for a given dialect.
362
+ def _handle_postgres_empty_parameters(self, parameters: Any) -> Any:
363
+ """Process empty parameters for PostgreSQL compatibility."""
364
+ if self.dialect == "postgres" and isinstance(parameters, dict) and not parameters:
365
+ return None
366
+ return parameters
367
+
368
def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor":
    """Return an ``AdbcCursor`` context manager bound to ``connection``."""
    cursor_manager = AdbcCursor(connection)
    return cursor_manager
371
+
372
def handle_database_exceptions(self) -> "AbstractContextManager[None]":
    """Return a fresh ``AdbcExceptionHandler`` that wraps ADBC errors in sqlspec exceptions."""
    handler = AdbcExceptionHandler()
    return handler
156
375
 
157
- Each ADBC driver supports different parameter styles based on the underlying database.
376
def _try_special_handling(self, cursor: "Cursor", statement: SQL) -> "Optional[SQLResult]":
    """Hook for ADBC-specific operations; ADBC handles none, so this is a no-op.

    Args:
        cursor: ADBC cursor object (unused)
        statement: SQL statement to analyze (unused)

    Returns:
        Always None, meaning the statement goes through standard execution.
    """
    del cursor, statement  # unused: no special-cased operations for ADBC
    return None
388
+
389
+ def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
390
+ """Execute SQL with multiple parameter sets using batch processing."""
391
+ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
169
392
 
170
- @staticmethod
171
- @contextmanager
172
- def _get_cursor(connection: "AdbcConnection") -> Iterator["Cursor"]:
173
- cursor = connection.cursor()
174
393
  try:
175
- yield cursor
176
- finally:
177
- with contextlib.suppress(Exception):
178
- cursor.close() # type: ignore[no-untyped-call]
394
+ if not prepared_parameters:
395
+ cursor._rowcount = 0
396
+ row_count = 0
397
+ elif isinstance(prepared_parameters, list) and prepared_parameters:
398
+ processed_params = []
399
+ for param_set in prepared_parameters:
400
+ postgres_compatible = self._handle_postgres_empty_parameters(param_set)
401
+ formatted_params = self.prepare_driver_parameters(
402
+ postgres_compatible, self.statement_config, is_many=False
403
+ )
404
+ processed_params.append(formatted_params)
179
405
 
180
- def _execute_statement(
181
- self, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
182
- ) -> SQLResult[RowT]:
183
- if statement.is_script:
184
- sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
185
- return self._execute_script(sql, connection=connection, **kwargs)
186
-
187
- detected_styles = {p.style for p in statement.parameter_info}
188
-
189
- target_style = self.default_parameter_style
190
- unsupported_styles = detected_styles - set(self.supported_parameter_styles)
191
-
192
- if unsupported_styles:
193
- target_style = self.default_parameter_style
194
- elif detected_styles:
195
- for style in detected_styles:
196
- if style in self.supported_parameter_styles:
197
- target_style = style
198
- break
199
-
200
- sql, params = statement.compile(placeholder_style=target_style)
201
- params = self._process_parameters(params)
202
- if statement.is_many:
203
- return self._execute_many(sql, params, connection=connection, **kwargs)
204
-
205
- return self._execute(sql, params, statement, connection=connection, **kwargs)
206
-
207
- def _execute(
208
- self, sql: str, parameters: Any, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
209
- ) -> SQLResult[RowT]:
210
- # Use provided connection or driver's default connection
211
- conn = connection if connection is not None else self._connection(None)
212
-
213
- with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
214
- normalized_params = normalize_parameter_sequence(parameters)
215
- if normalized_params is not None and not isinstance(normalized_params, (list, tuple)):
216
- cursor_params = [normalized_params]
406
+ cursor.executemany(sql, processed_params)
407
+ row_count = cursor.rowcount if cursor.rowcount is not None else -1
217
408
  else:
218
- cursor_params = normalized_params
409
+ cursor.executemany(sql, prepared_parameters)
410
+ row_count = cursor.rowcount if cursor.rowcount is not None else -1
219
411
 
220
- with self._get_cursor(txn_conn) as cursor:
221
- try:
222
- cursor.execute(sql, cursor_params or [])
223
- except Exception as e:
224
- # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
225
- if self.dialect == "postgres":
226
- with contextlib.suppress(Exception):
227
- cursor.execute("ROLLBACK")
228
- raise e from e
229
-
230
- if self.returns_rows(statement.expression):
231
- fetched_data = cursor.fetchall()
232
- column_names = [col[0] for col in cursor.description or []]
233
-
234
- if fetched_data and isinstance(fetched_data[0], tuple):
235
- dict_data: list[dict[Any, Any]] = [dict(zip(column_names, row)) for row in fetched_data]
236
- else:
237
- dict_data = fetched_data # type: ignore[assignment]
238
-
239
- return SQLResult(
240
- statement=statement,
241
- data=cast("list[RowT]", dict_data),
242
- column_names=column_names,
243
- rows_affected=len(dict_data),
244
- operation_type="SELECT",
245
- )
246
-
247
- operation_type = self._determine_operation_type(statement)
248
- return SQLResult(
249
- statement=statement,
250
- data=cast("list[RowT]", []),
251
- rows_affected=cursor.rowcount,
252
- operation_type=operation_type,
253
- metadata={"status_message": "OK"},
254
- )
255
-
256
- def _execute_many(
257
- self, sql: str, param_list: Any, connection: Optional["AdbcConnection"] = None, **kwargs: Any
258
- ) -> SQLResult[RowT]:
259
- # Use provided connection or driver's default connection
260
- conn = connection if connection is not None else self._connection(None)
261
-
262
- with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
263
- # Normalize parameter list using consolidated utility
264
- normalized_param_list = normalize_parameter_sequence(param_list)
265
-
266
- with self._get_cursor(txn_conn) as cursor:
267
- try:
268
- cursor.executemany(sql, normalized_param_list or [])
269
- except Exception as e:
270
- if self.dialect == "postgres":
271
- with contextlib.suppress(Exception):
272
- cursor.execute("ROLLBACK")
273
- # Always re-raise the original exception
274
- raise e from e
275
-
276
- return SQLResult(
277
- statement=SQL(sql, _dialect=self.dialect),
278
- data=[],
279
- rows_affected=cursor.rowcount,
280
- operation_type="EXECUTE",
281
- metadata={"status_message": "OK"},
282
- )
283
-
284
- def _execute_script(
285
- self, script: str, connection: Optional["AdbcConnection"] = None, **kwargs: Any
286
- ) -> SQLResult[RowT]:
287
- # Use provided connection or driver's default connection
288
- conn = connection if connection is not None else self._connection(None)
289
-
290
- with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
291
- # ADBC drivers don't support multiple statements in a single execute
292
- statements = self._split_script_statements(script)
293
-
294
- executed_count = 0
295
- with self._get_cursor(txn_conn) as cursor:
296
- for statement in statements:
297
- if statement.strip():
298
- self._execute_single_script_statement(cursor, statement)
299
- executed_count += 1
300
-
301
- return SQLResult(
302
- statement=SQL(script, _dialect=self.dialect).as_script(),
303
- data=[],
304
- rows_affected=0,
305
- operation_type="SCRIPT",
306
- metadata={"status_message": "SCRIPT EXECUTED"},
307
- total_statements=executed_count,
308
- successful_statements=executed_count,
309
- )
412
+ except Exception:
413
+ self._handle_postgres_rollback(cursor)
414
+ logger.exception("ADBC executemany failed")
415
+ raise
310
416
 
311
- def _execute_single_script_statement(self, cursor: "Cursor", statement: str) -> int:
312
- """Execute a single statement from a script and handle errors.
417
+ return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
313
418
 
314
- Args:
315
- cursor: The database cursor
316
- statement: The SQL statement to execute
419
+ def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
420
+ """Execute single SQL statement with ADBC-specific data handling."""
421
+ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
317
422
 
318
- Returns:
319
- 1 if successful, 0 if failed
320
- """
321
423
  try:
322
- cursor.execute(statement)
323
- except Exception as e:
324
- # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
325
- if self.dialect == "postgres":
326
- with contextlib.suppress(Exception):
327
- cursor.execute("ROLLBACK")
328
- raise e from e
329
- else:
330
- return 1
424
+ postgres_compatible_params = self._handle_postgres_empty_parameters(prepared_parameters)
425
+ cursor.execute(sql, parameters=postgres_compatible_params)
331
426
 
332
- def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
333
- """ADBC native Arrow table fetching.
427
+ except Exception:
428
+ self._handle_postgres_rollback(cursor)
429
+ raise
334
430
 
335
- ADBC has excellent native Arrow support through cursor.fetch_arrow_table()
336
- This provides zero-copy data transfer for optimal performance.
431
+ if statement.returns_rows():
432
+ fetched_data = cursor.fetchall()
433
+ column_names = [col[0] for col in cursor.description or []]
337
434
 
338
- Args:
339
- sql: Processed SQL object
340
- connection: Optional connection override
341
- **kwargs: Additional options (e.g., batch_size for streaming)
435
+ if fetched_data and isinstance(fetched_data[0], tuple):
436
+ dict_data: list[dict[Any, Any]] = [dict(zip(column_names, row)) for row in fetched_data]
437
+ else:
438
+ dict_data = fetched_data # type: ignore[assignment]
439
+
440
+ return self.create_execution_result(
441
+ cursor,
442
+ selected_data=cast("list[dict[str, Any]]", dict_data),
443
+ column_names=column_names,
444
+ data_row_count=len(dict_data),
445
+ is_select_result=True,
446
+ )
342
447
 
343
- Returns:
344
- ArrowResult with native Arrow table
345
- """
346
- self._ensure_pyarrow_installed()
347
- conn = self._connection(connection)
448
+ row_count = cursor.rowcount if cursor.rowcount is not None else -1
449
+ return self.create_execution_result(cursor, rowcount_override=row_count)
348
450
 
349
- with wrap_exceptions(), self._get_cursor(conn) as cursor:
350
- # Execute the query
351
- params = sql.get_parameters(style=self.default_parameter_style)
352
- # ADBC expects parameters as a list for most drivers
353
- cursor_params = [params] if params is not None and not isinstance(params, (list, tuple)) else params
354
- cursor.execute(sql.to_sql(placeholder_style=self.default_parameter_style), cursor_params or [])
355
- arrow_table = cursor.fetch_arrow_table()
356
- return ArrowResult(statement=sql, data=arrow_table)
451
+ def _execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult":
452
+ """Execute SQL script with ADBC-specific transaction handling."""
453
+ if statement.is_script:
454
+ sql = statement._raw_sql
455
+ prepared_parameters: list[Any] = []
456
+ else:
457
+ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
357
458
 
358
- def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
359
- """ADBC-optimized Arrow table ingestion using native bulk insert.
459
+ statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True)
360
460
 
361
- ADBC drivers often support native Arrow table ingestion for high-performance
362
- bulk loading operations.
461
+ successful_count = 0
462
+ last_rowcount = 0
363
463
 
364
- Args:
365
- table: Arrow table to ingest
366
- table_name: Target database table name
367
- mode: Ingestion mode ('append', 'replace', 'create')
368
- **options: Additional ADBC-specific options
464
+ try:
465
+ for stmt in statements:
466
+ if prepared_parameters:
467
+ postgres_compatible_params = self._handle_postgres_empty_parameters(prepared_parameters)
468
+ cursor.execute(stmt, parameters=postgres_compatible_params)
469
+ else:
470
+ cursor.execute(stmt)
471
+ successful_count += 1
472
+ if cursor.rowcount is not None:
473
+ last_rowcount = cursor.rowcount
474
+ except Exception:
475
+ self._handle_postgres_rollback(cursor)
476
+ logger.exception("ADBC script execution failed")
477
+ raise
478
+
479
+ return self.create_execution_result(
480
+ cursor,
481
+ statement_count=len(statements),
482
+ successful_statements=successful_count,
483
+ rowcount_override=last_rowcount,
484
+ is_script_result=True,
485
+ )
486
+
487
+ def begin(self) -> None:
488
+ """Begin database transaction."""
489
+ try:
490
+ with self.with_cursor(self.connection) as cursor:
491
+ cursor.execute("BEGIN")
492
+ except Exception as e:
493
+ msg = f"Failed to begin ADBC transaction: {e}"
494
+ raise SQLSpecError(msg) from e
369
495
 
370
- Returns:
371
- Number of rows ingested
372
- """
373
- self._ensure_pyarrow_installed()
374
-
375
- conn = self._connection(None)
376
- with self._get_cursor(conn) as cursor:
377
- if mode == "replace":
378
- cursor.execute(
379
- SQL(f"TRUNCATE TABLE {table_name}", _dialect=self.dialect).to_sql(
380
- placeholder_style=ParameterStyle.STATIC
381
- )
382
- )
383
- elif mode == "create":
384
- msg = "'create' mode is not supported for ADBC ingestion"
385
- raise NotImplementedError(msg)
386
- return cursor.adbc_ingest(table_name, table, mode=mode, **options) # type: ignore[arg-type]
387
-
388
- def _connection(self, connection: Optional["AdbcConnection"] = None) -> "AdbcConnection":
389
- """Get the connection to use for the operation."""
390
- return connection or self.connection
496
+ def rollback(self) -> None:
497
+ """Rollback database transaction."""
498
+ try:
499
+ with self.with_cursor(self.connection) as cursor:
500
+ cursor.execute("ROLLBACK")
501
+ except Exception as e:
502
+ msg = f"Failed to rollback ADBC transaction: {e}"
503
+ raise SQLSpecError(msg) from e
504
+
505
+ def commit(self) -> None:
506
+ """Commit database transaction."""
507
+ try:
508
+ with self.with_cursor(self.connection) as cursor:
509
+ cursor.execute("COMMIT")
510
+ except Exception as e:
511
+ msg = f"Failed to commit ADBC transaction: {e}"
512
+ raise SQLSpecError(msg) from e