sqlspec 0.14.1__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +50 -25
- sqlspec/__main__.py +1 -1
- sqlspec/__metadata__.py +1 -3
- sqlspec/_serialization.py +1 -2
- sqlspec/_sql.py +480 -121
- sqlspec/_typing.py +278 -142
- sqlspec/adapters/adbc/__init__.py +4 -3
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +115 -260
- sqlspec/adapters/adbc/driver.py +462 -367
- sqlspec/adapters/aiosqlite/__init__.py +18 -3
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +199 -129
- sqlspec/adapters/aiosqlite/driver.py +230 -269
- sqlspec/adapters/asyncmy/__init__.py +18 -3
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +80 -168
- sqlspec/adapters/asyncmy/driver.py +260 -225
- sqlspec/adapters/asyncpg/__init__.py +19 -4
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +82 -181
- sqlspec/adapters/asyncpg/driver.py +285 -383
- sqlspec/adapters/bigquery/__init__.py +17 -3
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +191 -258
- sqlspec/adapters/bigquery/driver.py +474 -646
- sqlspec/adapters/duckdb/__init__.py +14 -3
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +415 -351
- sqlspec/adapters/duckdb/driver.py +343 -413
- sqlspec/adapters/oracledb/__init__.py +19 -5
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +123 -379
- sqlspec/adapters/oracledb/driver.py +507 -560
- sqlspec/adapters/psqlpy/__init__.py +13 -3
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +93 -254
- sqlspec/adapters/psqlpy/driver.py +505 -234
- sqlspec/adapters/psycopg/__init__.py +19 -5
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +143 -403
- sqlspec/adapters/psycopg/driver.py +706 -872
- sqlspec/adapters/sqlite/__init__.py +14 -3
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +202 -118
- sqlspec/adapters/sqlite/driver.py +264 -303
- sqlspec/base.py +105 -9
- sqlspec/{statement/builder → builder}/__init__.py +12 -14
- sqlspec/{statement/builder → builder}/_base.py +120 -55
- sqlspec/{statement/builder → builder}/_column.py +17 -6
- sqlspec/{statement/builder → builder}/_ddl.py +46 -79
- sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
- sqlspec/{statement/builder → builder}/_delete.py +6 -25
- sqlspec/{statement/builder → builder}/_insert.py +18 -65
- sqlspec/builder/_merge.py +56 -0
- sqlspec/{statement/builder → builder}/_parsing_utils.py +8 -11
- sqlspec/{statement/builder → builder}/_select.py +11 -56
- sqlspec/{statement/builder → builder}/_update.py +12 -18
- sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
- sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
- sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +34 -18
- sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
- sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +19 -9
- sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
- sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
- sqlspec/{statement/builder → builder}/mixins/_select_operations.py +25 -38
- sqlspec/{statement/builder → builder}/mixins/_update_operations.py +15 -16
- sqlspec/{statement/builder → builder}/mixins/_where_clause.py +210 -137
- sqlspec/cli.py +4 -5
- sqlspec/config.py +180 -133
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.py +873 -0
- sqlspec/core/compiler.py +396 -0
- sqlspec/core/filters.py +830 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.py +1209 -0
- sqlspec/core/result.py +664 -0
- sqlspec/{statement → core}/splitter.py +321 -191
- sqlspec/core/statement.py +666 -0
- sqlspec/driver/__init__.py +7 -10
- sqlspec/driver/_async.py +387 -176
- sqlspec/driver/_common.py +527 -289
- sqlspec/driver/_sync.py +390 -172
- sqlspec/driver/mixins/__init__.py +2 -19
- sqlspec/driver/mixins/_result_tools.py +164 -0
- sqlspec/driver/mixins/_sql_translator.py +6 -3
- sqlspec/exceptions.py +5 -252
- sqlspec/extensions/aiosql/adapter.py +93 -96
- sqlspec/extensions/litestar/cli.py +1 -1
- sqlspec/extensions/litestar/config.py +0 -1
- sqlspec/extensions/litestar/handlers.py +15 -26
- sqlspec/extensions/litestar/plugin.py +18 -16
- sqlspec/extensions/litestar/providers.py +17 -52
- sqlspec/loader.py +424 -105
- sqlspec/migrations/__init__.py +12 -0
- sqlspec/migrations/base.py +92 -68
- sqlspec/migrations/commands.py +24 -106
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +49 -51
- sqlspec/migrations/tracker.py +31 -44
- sqlspec/migrations/utils.py +64 -24
- sqlspec/protocols.py +7 -183
- sqlspec/storage/__init__.py +1 -1
- sqlspec/storage/backends/base.py +37 -40
- sqlspec/storage/backends/fsspec.py +136 -112
- sqlspec/storage/backends/obstore.py +138 -160
- sqlspec/storage/capabilities.py +5 -4
- sqlspec/storage/registry.py +57 -106
- sqlspec/typing.py +136 -115
- sqlspec/utils/__init__.py +2 -3
- sqlspec/utils/correlation.py +0 -3
- sqlspec/utils/deprecation.py +6 -6
- sqlspec/utils/fixtures.py +6 -6
- sqlspec/utils/logging.py +0 -2
- sqlspec/utils/module_loader.py +7 -12
- sqlspec/utils/singleton.py +0 -1
- sqlspec/utils/sync_tools.py +17 -38
- sqlspec/utils/text.py +12 -51
- sqlspec/utils/type_guards.py +443 -232
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/METADATA +7 -2
- sqlspec-0.16.0.dist-info/RECORD +134 -0
- sqlspec/adapters/adbc/transformers.py +0 -108
- sqlspec/driver/connection.py +0 -207
- sqlspec/driver/mixins/_cache.py +0 -114
- sqlspec/driver/mixins/_csv_writer.py +0 -91
- sqlspec/driver/mixins/_pipeline.py +0 -508
- sqlspec/driver/mixins/_query_tools.py +0 -796
- sqlspec/driver/mixins/_result_utils.py +0 -138
- sqlspec/driver/mixins/_storage.py +0 -912
- sqlspec/driver/mixins/_type_coercion.py +0 -128
- sqlspec/driver/parameters.py +0 -138
- sqlspec/statement/__init__.py +0 -21
- sqlspec/statement/builder/_merge.py +0 -95
- sqlspec/statement/cache.py +0 -50
- sqlspec/statement/filters.py +0 -625
- sqlspec/statement/parameters.py +0 -956
- sqlspec/statement/pipelines/__init__.py +0 -210
- sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
- sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
- sqlspec/statement/pipelines/context.py +0 -109
- sqlspec/statement/pipelines/transformers/__init__.py +0 -7
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
- sqlspec/statement/pipelines/validators/__init__.py +0 -23
- sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
- sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
- sqlspec/statement/pipelines/validators/_performance.py +0 -714
- sqlspec/statement/pipelines/validators/_security.py +0 -967
- sqlspec/statement/result.py +0 -435
- sqlspec/statement/sql.py +0 -1774
- sqlspec/utils/cached_property.py +0 -25
- sqlspec/utils/statement_hashing.py +0 -203
- sqlspec-0.14.1.dist-info/RECORD +0 -145
- /sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/adbc/driver.py
CHANGED
|
@@ -1,417 +1,512 @@
|
|
|
1
|
+
"""ADBC driver implementation for Arrow Database Connectivity.
|
|
2
|
+
|
|
3
|
+
This module provides ADBC driver integration with support for:
|
|
4
|
+
- Multi-dialect database connections through ADBC
|
|
5
|
+
- Arrow-native data handling with type coercion
|
|
6
|
+
- Parameter style conversion for different database backends
|
|
7
|
+
- Transaction management with proper error handling
|
|
8
|
+
"""
|
|
9
|
+
|
|
1
10
|
import contextlib
|
|
2
|
-
import
|
|
3
|
-
|
|
4
|
-
from
|
|
5
|
-
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
from
|
|
10
|
-
|
|
11
|
-
from sqlspec.driver import
|
|
12
|
-
from sqlspec.
|
|
13
|
-
from sqlspec.
|
|
14
|
-
SQLTranslatorMixin,
|
|
15
|
-
SyncAdapterCacheMixin,
|
|
16
|
-
SyncPipelinedExecutionMixin,
|
|
17
|
-
SyncStorageMixin,
|
|
18
|
-
ToSchemaMixin,
|
|
19
|
-
TypeCoercionMixin,
|
|
20
|
-
)
|
|
21
|
-
from sqlspec.driver.parameters import convert_parameter_sequence
|
|
22
|
-
from sqlspec.exceptions import wrap_exceptions
|
|
23
|
-
from sqlspec.statement.parameters import ParameterStyle
|
|
24
|
-
from sqlspec.statement.result import ArrowResult, SQLResult
|
|
25
|
-
from sqlspec.statement.sql import SQL, SQLConfig
|
|
26
|
-
from sqlspec.typing import DictRow, RowT
|
|
11
|
+
import datetime
|
|
12
|
+
import decimal
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
14
|
+
|
|
15
|
+
from sqlglot import exp
|
|
16
|
+
|
|
17
|
+
from sqlspec.core.cache import get_cache_config
|
|
18
|
+
from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
|
|
19
|
+
from sqlspec.core.statement import SQL, StatementConfig
|
|
20
|
+
from sqlspec.driver import SyncDriverAdapterBase
|
|
21
|
+
from sqlspec.exceptions import MissingDependencyError, SQLParsingError, SQLSpecError
|
|
22
|
+
from sqlspec.utils.logging import get_logger
|
|
27
23
|
from sqlspec.utils.serializers import to_json
|
|
28
24
|
|
|
29
25
|
if TYPE_CHECKING:
|
|
30
|
-
from
|
|
26
|
+
from contextlib import AbstractContextManager
|
|
27
|
+
|
|
28
|
+
from adbc_driver_manager.dbapi import Cursor
|
|
29
|
+
|
|
30
|
+
from sqlspec.adapters.adbc._types import AdbcConnection
|
|
31
|
+
from sqlspec.core.result import SQLResult
|
|
32
|
+
from sqlspec.driver import ExecutionResult
|
|
33
|
+
|
|
34
|
+
__all__ = (
    "AdbcCursor",
    "AdbcDriver",
    "AdbcExceptionHandler",
    "get_adbc_statement_config",
)

logger = get_logger("adapters.adbc")

# Dialect name -> substrings looked for in the (lower-cased) ADBC vendor/driver names.
# AdbcDriver._get_dialect walks this mapping in insertion order and returns the first
# dialect with a matching substring, so the ordering below is significant.
DIALECT_PATTERNS = dict(
    postgres=["postgres", "postgresql"],
    bigquery=["bigquery"],
    sqlite=["sqlite", "flight", "flightsql"],
    duckdb=["duckdb"],
    mysql=["mysql"],
    snowflake=["snowflake"],
)

# Dialect name -> (default parameter style, list of supported parameter styles).
# get_adbc_statement_config falls back to QMARK-only for dialects missing here.
DIALECT_PARAMETER_STYLES = dict(
    postgres=(ParameterStyle.NUMERIC, [ParameterStyle.NUMERIC]),
    postgresql=(ParameterStyle.NUMERIC, [ParameterStyle.NUMERIC]),
    bigquery=(ParameterStyle.NAMED_AT, [ParameterStyle.NAMED_AT]),
    sqlite=(ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NAMED_COLON]),
    duckdb=(ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR]),
    mysql=(ParameterStyle.POSITIONAL_PYFORMAT, [ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT]),
    snowflake=(ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC]),
)
|
|
37
56
|
|
|
38
57
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
SyncAdapterCacheMixin,
|
|
42
|
-
SQLTranslatorMixin,
|
|
43
|
-
TypeCoercionMixin,
|
|
44
|
-
SyncStorageMixin,
|
|
45
|
-
SyncPipelinedExecutionMixin,
|
|
46
|
-
ToSchemaMixin,
|
|
47
|
-
):
|
|
48
|
-
"""ADBC Sync Driver Adapter with modern architecture.
|
|
58
|
+
def _adbc_ast_transformer(expression: Any, parameters: Any) -> tuple[Any, Any]:
|
|
59
|
+
"""ADBC-specific AST transformer for NULL parameter handling.
|
|
49
60
|
|
|
50
|
-
|
|
51
|
-
to
|
|
61
|
+
For PostgreSQL, this transformer replaces NULL parameter placeholders with NULL literals
|
|
62
|
+
in the AST to prevent Arrow from inferring 'na' types which cause binding errors.
|
|
52
63
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
64
|
+
The transformer:
|
|
65
|
+
1. Detects None parameters in the parameter list
|
|
66
|
+
2. Replaces corresponding placeholders in the AST with NULL literals
|
|
67
|
+
3. Removes the None parameters from the list
|
|
68
|
+
4. Renumbers remaining placeholders to maintain correct mapping
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
expression: SQLGlot AST expression
|
|
72
|
+
parameters: Parameter values that may contain None
|
|
73
|
+
|
|
74
|
+
Returns:
|
|
75
|
+
Tuple of (modified_expression, cleaned_parameters)
|
|
59
76
|
"""
|
|
77
|
+
if not parameters:
|
|
78
|
+
return expression, parameters
|
|
79
|
+
|
|
80
|
+
# Detect NULL parameter positions
|
|
81
|
+
null_positions = set()
|
|
82
|
+
if isinstance(parameters, (list, tuple)):
|
|
83
|
+
for i, param in enumerate(parameters):
|
|
84
|
+
if param is None:
|
|
85
|
+
null_positions.add(i)
|
|
86
|
+
elif isinstance(parameters, dict):
|
|
87
|
+
for key, param in parameters.items():
|
|
88
|
+
if param is None:
|
|
89
|
+
try:
|
|
90
|
+
if isinstance(key, str) and key.lstrip("$").isdigit():
|
|
91
|
+
param_num = int(key.lstrip("$"))
|
|
92
|
+
null_positions.add(param_num - 1)
|
|
93
|
+
except ValueError:
|
|
94
|
+
pass
|
|
95
|
+
|
|
96
|
+
if not null_positions:
|
|
97
|
+
return expression, parameters
|
|
98
|
+
|
|
99
|
+
# Track position for QMARK-style placeholders
|
|
100
|
+
qmark_position = [0]
|
|
101
|
+
|
|
102
|
+
def transform_node(node: Any) -> Any:
|
|
103
|
+
"""Transform parameter nodes to NULL literals and renumber remaining ones."""
|
|
104
|
+
# Handle QMARK-style placeholders (?, ?, ?)
|
|
105
|
+
if isinstance(node, exp.Placeholder) and (not hasattr(node, "this") or node.this is None):
|
|
106
|
+
current_pos = qmark_position[0]
|
|
107
|
+
qmark_position[0] += 1
|
|
108
|
+
|
|
109
|
+
if current_pos in null_positions:
|
|
110
|
+
return exp.Null()
|
|
111
|
+
# Don't renumber QMARK placeholders - they stay as ?
|
|
112
|
+
return node
|
|
113
|
+
|
|
114
|
+
# Handle PostgreSQL-style placeholders ($1, $2, etc.)
|
|
115
|
+
if isinstance(node, exp.Placeholder) and hasattr(node, "this") and node.this is not None:
|
|
116
|
+
try:
|
|
117
|
+
param_str = str(node.this).lstrip("$")
|
|
118
|
+
param_num = int(param_str)
|
|
119
|
+
param_index = param_num - 1 # Convert to 0-based
|
|
120
|
+
|
|
121
|
+
if param_index in null_positions:
|
|
122
|
+
return exp.Null()
|
|
123
|
+
# Renumber placeholder to account for removed NULLs
|
|
124
|
+
nulls_before = sum(1 for idx in null_positions if idx < param_index)
|
|
125
|
+
new_param_num = param_num - nulls_before
|
|
126
|
+
return exp.Placeholder(this=f"${new_param_num}")
|
|
127
|
+
except (ValueError, AttributeError):
|
|
128
|
+
pass
|
|
129
|
+
|
|
130
|
+
# Handle generic parameter nodes
|
|
131
|
+
if isinstance(node, exp.Parameter) and hasattr(node, "this"):
|
|
132
|
+
try:
|
|
133
|
+
param_str = str(node.this)
|
|
134
|
+
param_num = int(param_str)
|
|
135
|
+
param_index = param_num - 1 # Convert to 0-based
|
|
136
|
+
|
|
137
|
+
if param_index in null_positions:
|
|
138
|
+
return exp.Null()
|
|
139
|
+
# Renumber parameter to account for removed NULLs
|
|
140
|
+
nulls_before = sum(1 for idx in null_positions if idx < param_index)
|
|
141
|
+
new_param_num = param_num - nulls_before
|
|
142
|
+
return exp.Parameter(this=str(new_param_num))
|
|
143
|
+
except (ValueError, AttributeError):
|
|
144
|
+
pass
|
|
145
|
+
|
|
146
|
+
return node
|
|
147
|
+
|
|
148
|
+
# Transform the AST
|
|
149
|
+
modified_expression = expression.transform(transform_node)
|
|
150
|
+
|
|
151
|
+
# Remove NULL parameters from the parameter list
|
|
152
|
+
cleaned_params: Any
|
|
153
|
+
if isinstance(parameters, (list, tuple)):
|
|
154
|
+
cleaned_params = [p for i, p in enumerate(parameters) if i not in null_positions]
|
|
155
|
+
elif isinstance(parameters, dict):
|
|
156
|
+
cleaned_params_dict = {}
|
|
157
|
+
new_num = 1
|
|
158
|
+
for val in parameters.values():
|
|
159
|
+
if val is not None:
|
|
160
|
+
cleaned_params_dict[str(new_num)] = val
|
|
161
|
+
new_num += 1
|
|
162
|
+
cleaned_params = cleaned_params_dict
|
|
163
|
+
else:
|
|
164
|
+
cleaned_params = parameters
|
|
165
|
+
|
|
166
|
+
return modified_expression, cleaned_params
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
    """Create ADBC statement configuration for the specified dialect.

    The (default, supported) parameter styles come from ``DIALECT_PARAMETER_STYLES``;
    unknown dialects get QMARK only. The dialect's type-coercion map is attached, and the
    NULL-placeholder AST transformer is enabled only for ``postgres``/``postgresql``.
    ``"postgresql"`` is normalized to the sqlglot dialect name ``"postgres"``.

    Args:
        detected_dialect: Dialect name, e.g. the value returned by ``AdbcDriver._get_dialect``.

    Returns:
        A StatementConfig with parsing, validation, caching and parameter type wrapping enabled.
    """
    qmark_only = (ParameterStyle.QMARK, [ParameterStyle.QMARK])
    default_style, supported_styles = DIALECT_PARAMETER_STYLES.get(detected_dialect, qmark_only)

    coercions = get_type_coercion_map(detected_dialect)
    is_postgres_family = detected_dialect in {"postgres", "postgresql"}
    sqlglot_name = {"postgresql": "postgres"}.get(detected_dialect, detected_dialect)

    style_config = ParameterStyleConfig(
        default_parameter_style=default_style,
        supported_parameter_styles=set(supported_styles),
        default_execution_parameter_style=default_style,
        supported_execution_parameter_styles=set(supported_styles),
        type_coercion_map=coercions,
        has_native_list_expansion=True,
        needs_static_script_compilation=False,
        preserve_parameter_format=True,
        ast_transformer=_adbc_ast_transformer if is_postgres_family else None,
    )

    return StatementConfig(
        dialect=sqlglot_name,
        parameter_config=style_config,
        enable_parsing=True,
        enable_validation=True,
        enable_caching=True,
        enable_parameter_type_wrapping=True,
    )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _convert_array_for_postgres_adbc(value: Any) -> Any:
|
|
202
|
+
"""Convert array values for PostgreSQL ADBC compatibility."""
|
|
203
|
+
if isinstance(value, tuple):
|
|
204
|
+
return list(value)
|
|
205
|
+
return value
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def get_type_coercion_map(dialect: str) -> "dict[type, Any]":
    """Get type coercion map for Arrow/ADBC type handling.

    Most types pass through unchanged; ``Decimal`` becomes ``float`` and tuples become
    lists. For ``postgres``/``postgresql`` dicts are serialized with ``to_json`` (None stays
    None). Key order is kept stable because related types (bool/int, datetime/date) appear
    as separate entries.
    """

    def _passthrough(x: Any) -> Any:
        return x

    converters: "dict[type, Any]" = dict.fromkeys((datetime.datetime, datetime.date, datetime.time), _passthrough)
    converters[decimal.Decimal] = float
    converters.update(dict.fromkeys((bool, int, float, str, bytes), _passthrough))
    converters[tuple] = _convert_array_for_postgres_adbc
    converters[list] = _convert_array_for_postgres_adbc
    converters[dict] = _passthrough

    if dialect in ("postgres", "postgresql"):
        # Reassigning an existing key keeps its original position in the dict.
        converters[dict] = lambda x: None if x is None else to_json(x)

    return converters
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class AdbcCursor:
    """Context manager for ADBC cursor management.

    Entering opens a cursor on the wrapped connection and returns it. Exiting attempts to
    close that cursor and ignores any ``Exception`` raised by ``close()``. Exceptions from
    the ``with`` body are not suppressed.
    """

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "AdbcConnection") -> None:
        self.connection = connection
        self.cursor: "Optional[Cursor]" = None

    def __enter__(self) -> "Cursor":
        opened = self.connection.cursor()
        self.cursor = opened
        return opened

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        del exc_type, exc_val, exc_tb  # exception details are not inspected
        opened = self.cursor
        if opened is None:
            return
        try:
            opened.close()  # type: ignore[no-untyped-call]
        except Exception:
            pass  # closing is best-effort
|
|
249
|
+
|
|
60
250
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
251
|
+
class AdbcExceptionHandler:
    """Custom sync context manager for handling ADBC database exceptions.

    On exit, exceptions raised inside the ``with`` block are translated as follows:

    - ``SQLSpecError`` and its subclasses propagate unchanged. They are already sqlspec
      exceptions, and wrapping them again would change their type and duplicate the
      message prefix.
    - ADBC ``IntegrityError`` -> ``SQLSpecError``.
    - ADBC ``ProgrammingError`` -> ``SQLParsingError`` when the message mentions
      "syntax"/"parse", otherwise ``SQLSpecError``.
    - ADBC ``OperationalError`` and other ``DatabaseError`` -> ``SQLSpecError``.
    - Any other ``Exception`` -> ``SQLParsingError`` when the message mentions
      "parse"/"syntax", otherwise ``SQLSpecError``.

    The ADBC-specific checks are skipped if ``adbc_driver_manager`` cannot be imported.
    Exceptions that do not derive from ``Exception`` propagate untouched. Translated errors
    chain the original exception via ``raise ... from``.
    """

    __slots__ = ()

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if exc_type is None:
            return
        if issubclass(exc_type, SQLSpecError):
            # Already a sqlspec error: let it propagate with its original type and message.
            return

        e = exc_val
        try:
            from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
        except ImportError:
            pass
        else:
            if issubclass(exc_type, IntegrityError):
                msg = f"ADBC integrity constraint violation: {e}"
                raise SQLSpecError(msg) from e
            if issubclass(exc_type, ProgrammingError):
                error_msg = str(e).lower()
                if "syntax" in error_msg or "parse" in error_msg:
                    msg = f"ADBC SQL syntax error: {e}"
                    raise SQLParsingError(msg) from e
                msg = f"ADBC programming error: {e}"
                raise SQLSpecError(msg) from e
            if issubclass(exc_type, OperationalError):
                msg = f"ADBC operational error: {e}"
                raise SQLSpecError(msg) from e
            if issubclass(exc_type, DatabaseError):
                msg = f"ADBC database error: {e}"
                raise SQLSpecError(msg) from e

        if issubclass(exc_type, Exception):
            error_msg = str(e).lower()
            if "parse" in error_msg or "syntax" in error_msg:
                msg = f"SQL parsing failed: {e}"
                raise SQLParsingError(msg) from e
            msg = f"Unexpected database operation error: {e}"
            raise SQLSpecError(msg) from e
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class AdbcDriver(SyncDriverAdapterBase):
|
|
299
|
+
"""ADBC driver for Arrow Database Connectivity.
|
|
300
|
+
|
|
301
|
+
Provides database connectivity through ADBC with support for:
|
|
302
|
+
- Multi-database dialect support with automatic detection
|
|
303
|
+
- Arrow-native data handling with type coercion
|
|
304
|
+
- Parameter style conversion for different backends
|
|
305
|
+
- Transaction management with proper error handling
|
|
306
|
+
"""
|
|
307
|
+
|
|
308
|
+
__slots__ = ("_detected_dialect", "dialect")
|
|
65
309
|
|
|
66
310
|
def __init__(
    self,
    connection: "AdbcConnection",
    statement_config: "Optional[StatementConfig]" = None,
    driver_features: "Optional[dict[str, Any]]" = None,
) -> None:
    """Initialize the ADBC driver.

    The dialect is detected from the connection first. If no ``statement_config`` is
    given, one is built for that dialect, with caching controlled by the global cache
    configuration's ``compiled_cache_enabled`` flag and parsing/validation enabled.

    Args:
        connection: ADBC connection used for dialect detection and execution.
        statement_config: Explicit statement configuration; built automatically when None.
        driver_features: Optional feature flags forwarded to the base driver.
    """
    self._detected_dialect = self._get_dialect(connection)

    if statement_config is None:
        cache_settings = get_cache_config()
        statement_config = get_adbc_statement_config(self._detected_dialect).replace(
            enable_caching=cache_settings.compiled_cache_enabled,
            enable_parsing=True,
            enable_validation=True,
        )

    super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
    self.dialect = statement_config.dialect
|
|
106
328
|
|
|
107
329
|
@staticmethod
def _ensure_pyarrow_installed() -> None:
    """Raise MissingDependencyError unless PyArrow is available for Arrow operations."""
    from sqlspec.typing import PYARROW_INSTALLED

    if PYARROW_INSTALLED:
        return
    raise MissingDependencyError(package="pyarrow", install_package="arrow")
|
|
113
336
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
"""
|
|
337
|
+
@staticmethod
def _get_dialect(connection: "AdbcConnection") -> str:
    """Detect database dialect from ADBC connection information.

    Reads ``vendor_name`` and ``driver_name`` from ``connection.adbc_get_info()``. It then
    returns the first ``DIALECT_PATTERNS`` key with a pattern found in either lower-cased
    name. Missing or ``None`` entries are treated as empty strings, so an absent field does
    not stop the other from being matched. Falls back to ``"postgres"`` (with a warning)
    when nothing matches or the driver info cannot be read.
    """
    try:
        driver_info = connection.adbc_get_info()
        # `or ""` guards against keys that are present but hold None.
        vendor_name = str(driver_info.get("vendor_name") or "").lower()
        driver_name = str(driver_info.get("driver_name") or "").lower()

        for dialect, patterns in DIALECT_PATTERNS.items():
            if any(pattern in vendor_name or pattern in driver_name for pattern in patterns):
                logger.debug("ADBC dialect detected: %s (from %s/%s)", dialect, vendor_name, driver_name)
                return dialect
    except Exception as e:
        logger.debug("ADBC dialect detection failed: %s", e)

    logger.warning("Could not reliably determine ADBC dialect from driver info. Defaulting to 'postgres'.")
    return "postgres"
|
|
139
354
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
"bigquery": ParameterStyle.NAMED_AT,
|
|
147
|
-
"sqlite": ParameterStyle.QMARK,
|
|
148
|
-
"duckdb": ParameterStyle.QMARK,
|
|
149
|
-
"mysql": ParameterStyle.POSITIONAL_PYFORMAT,
|
|
150
|
-
"snowflake": ParameterStyle.QMARK,
|
|
151
|
-
}
|
|
152
|
-
return dialect_style_map.get(dialect, ParameterStyle.QMARK)
|
|
355
|
+
def _handle_postgres_rollback(self, cursor: "Cursor") -> None:
|
|
356
|
+
"""Execute rollback for PostgreSQL after transaction failure."""
|
|
357
|
+
if self.dialect == "postgres":
|
|
358
|
+
with contextlib.suppress(Exception):
|
|
359
|
+
cursor.execute("ROLLBACK")
|
|
360
|
+
logger.debug("PostgreSQL rollback executed after ADBC transaction failure")
|
|
153
361
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
""
|
|
362
|
+
def _handle_postgres_empty_parameters(self, parameters: Any) -> Any:
|
|
363
|
+
"""Process empty parameters for PostgreSQL compatibility."""
|
|
364
|
+
if self.dialect == "postgres" and isinstance(parameters, dict) and not parameters:
|
|
365
|
+
return None
|
|
366
|
+
return parameters
|
|
367
|
+
|
|
368
|
+
def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor":
    """Return a new AdbcCursor context manager bound to ``connection``."""
    cursor_manager = AdbcCursor(connection)
    return cursor_manager
|
|
371
|
+
|
|
372
|
+
def handle_database_exceptions(self) -> "AbstractContextManager[None]":
    """Return a fresh AdbcExceptionHandler that translates database exceptions."""
    exception_handler = AdbcExceptionHandler()
    return exception_handler
|
|
157
375
|
|
|
158
|
-
|
|
376
|
+
def _try_special_handling(self, cursor: "Cursor", statement: SQL) -> "Optional[SQLResult]":
    """Hook for ADBC-specific operations; this driver defines none.

    Args:
        cursor: ADBC cursor object (unused)
        statement: SQL statement to analyze (unused)

    Returns:
        Always None, signalling that standard execution should be used.
    """
    del cursor, statement
    return None
|
|
388
|
+
|
|
389
|
+
def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
|
|
390
|
+
"""Execute SQL with multiple parameter sets using batch processing."""
|
|
391
|
+
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
170
392
|
|
|
171
|
-
@staticmethod
|
|
172
|
-
@contextmanager
|
|
173
|
-
def _get_cursor(connection: "AdbcConnection") -> Iterator["Cursor"]:
|
|
174
|
-
cursor = connection.cursor()
|
|
175
393
|
try:
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
394
|
+
if not prepared_parameters:
|
|
395
|
+
cursor._rowcount = 0
|
|
396
|
+
row_count = 0
|
|
397
|
+
elif isinstance(prepared_parameters, list) and prepared_parameters:
|
|
398
|
+
processed_params = []
|
|
399
|
+
for param_set in prepared_parameters:
|
|
400
|
+
postgres_compatible = self._handle_postgres_empty_parameters(param_set)
|
|
401
|
+
formatted_params = self.prepare_driver_parameters(
|
|
402
|
+
postgres_compatible, self.statement_config, is_many=False
|
|
403
|
+
)
|
|
404
|
+
processed_params.append(formatted_params)
|
|
180
405
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
) -> SQLResult[RowT]:
|
|
184
|
-
if statement.is_script:
|
|
185
|
-
sql, _ = self._get_compiled_sql(statement, ParameterStyle.STATIC)
|
|
186
|
-
return self._execute_script(sql, connection=connection, **kwargs)
|
|
187
|
-
|
|
188
|
-
detected_styles = {p.style for p in statement.parameter_info}
|
|
189
|
-
|
|
190
|
-
target_style = self.default_parameter_style
|
|
191
|
-
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
192
|
-
|
|
193
|
-
if unsupported_styles:
|
|
194
|
-
target_style = self.default_parameter_style
|
|
195
|
-
elif detected_styles:
|
|
196
|
-
for style in detected_styles:
|
|
197
|
-
if style in self.supported_parameter_styles:
|
|
198
|
-
target_style = style
|
|
199
|
-
break
|
|
200
|
-
|
|
201
|
-
sql, params = self._get_compiled_sql(statement, target_style)
|
|
202
|
-
params = self._process_parameters(params)
|
|
203
|
-
if statement.is_many:
|
|
204
|
-
return self._execute_many(sql, params, connection=connection, **kwargs)
|
|
205
|
-
|
|
206
|
-
return self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
207
|
-
|
|
208
|
-
def _execute(
|
|
209
|
-
self, sql: str, parameters: Any, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
210
|
-
) -> SQLResult[RowT]:
|
|
211
|
-
# Use provided connection or driver's default connection
|
|
212
|
-
conn = connection if connection is not None else self._connection(None)
|
|
213
|
-
|
|
214
|
-
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
215
|
-
converted_params = convert_parameter_sequence(parameters)
|
|
216
|
-
if converted_params is not None and not isinstance(converted_params, (list, tuple)):
|
|
217
|
-
cursor_params = [converted_params]
|
|
406
|
+
cursor.executemany(sql, processed_params)
|
|
407
|
+
row_count = cursor.rowcount if cursor.rowcount is not None else -1
|
|
218
408
|
else:
|
|
219
|
-
|
|
409
|
+
cursor.executemany(sql, prepared_parameters)
|
|
410
|
+
row_count = cursor.rowcount if cursor.rowcount is not None else -1
|
|
220
411
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
# can still cause "Can't map Arrow type 'na' to Postgres type" errors
|
|
226
|
-
cursor.execute(sql, cursor_params or [])
|
|
227
|
-
except Exception as e:
|
|
228
|
-
# Rollback transaction on error for PostgreSQL to avoid
|
|
229
|
-
# "current transaction is aborted" errors
|
|
230
|
-
if self.dialect == "postgres":
|
|
231
|
-
with contextlib.suppress(Exception):
|
|
232
|
-
cursor.execute("ROLLBACK")
|
|
233
|
-
raise e from e
|
|
234
|
-
|
|
235
|
-
if self.returns_rows(statement.expression):
|
|
236
|
-
fetched_data = cursor.fetchall()
|
|
237
|
-
column_names = [col[0] for col in cursor.description or []]
|
|
238
|
-
|
|
239
|
-
if fetched_data and isinstance(fetched_data[0], tuple):
|
|
240
|
-
dict_data: list[dict[Any, Any]] = [dict(zip(column_names, row)) for row in fetched_data]
|
|
241
|
-
else:
|
|
242
|
-
dict_data = fetched_data # type: ignore[assignment]
|
|
243
|
-
|
|
244
|
-
return SQLResult(
|
|
245
|
-
statement=statement,
|
|
246
|
-
data=cast("list[RowT]", dict_data),
|
|
247
|
-
column_names=column_names,
|
|
248
|
-
rows_affected=len(dict_data),
|
|
249
|
-
operation_type="SELECT",
|
|
250
|
-
)
|
|
251
|
-
|
|
252
|
-
operation_type = self._determine_operation_type(statement)
|
|
253
|
-
return SQLResult(
|
|
254
|
-
statement=statement,
|
|
255
|
-
data=cast("list[RowT]", []),
|
|
256
|
-
rows_affected=cursor.rowcount,
|
|
257
|
-
operation_type=operation_type,
|
|
258
|
-
metadata={"status_message": "OK"},
|
|
259
|
-
)
|
|
260
|
-
|
|
261
|
-
def _execute_many(
|
|
262
|
-
self, sql: str, param_list: Any, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
263
|
-
) -> SQLResult[RowT]:
|
|
264
|
-
# Use provided connection or driver's default connection
|
|
265
|
-
conn = connection if connection is not None else self._connection(None)
|
|
266
|
-
|
|
267
|
-
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
268
|
-
# Normalize parameter list using consolidated utility
|
|
269
|
-
converted_param_list = convert_parameter_sequence(param_list)
|
|
270
|
-
|
|
271
|
-
# Handle empty parameter list case for PostgreSQL
|
|
272
|
-
if not converted_param_list and self.dialect == "postgres":
|
|
273
|
-
# Return empty result without executing
|
|
274
|
-
return SQLResult(
|
|
275
|
-
statement=SQL(sql, _dialect=self.dialect),
|
|
276
|
-
data=[],
|
|
277
|
-
rows_affected=0,
|
|
278
|
-
operation_type="EXECUTE",
|
|
279
|
-
metadata={"status_message": "OK"},
|
|
280
|
-
)
|
|
281
|
-
|
|
282
|
-
with self._get_cursor(txn_conn) as cursor:
|
|
283
|
-
try:
|
|
284
|
-
cursor.executemany(sql, converted_param_list or [])
|
|
285
|
-
except Exception as e:
|
|
286
|
-
if self.dialect == "postgres":
|
|
287
|
-
with contextlib.suppress(Exception):
|
|
288
|
-
cursor.execute("ROLLBACK")
|
|
289
|
-
# Always re-raise the original exception
|
|
290
|
-
raise e from e
|
|
291
|
-
|
|
292
|
-
return SQLResult(
|
|
293
|
-
statement=SQL(sql, _dialect=self.dialect),
|
|
294
|
-
data=[],
|
|
295
|
-
rows_affected=cursor.rowcount,
|
|
296
|
-
operation_type="EXECUTE",
|
|
297
|
-
metadata={"status_message": "OK"},
|
|
298
|
-
)
|
|
299
|
-
|
|
300
|
-
def _execute_script(
|
|
301
|
-
self, script: str, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
302
|
-
) -> SQLResult[RowT]:
|
|
303
|
-
# Use provided connection or driver's default connection
|
|
304
|
-
conn = connection if connection is not None else self._connection(None)
|
|
305
|
-
|
|
306
|
-
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
307
|
-
# ADBC drivers don't support multiple statements in a single execute
|
|
308
|
-
statements = self._split_script_statements(script)
|
|
309
|
-
suppress_warnings = kwargs.get("_suppress_warnings", False)
|
|
310
|
-
|
|
311
|
-
executed_count = 0
|
|
312
|
-
total_rows = 0
|
|
313
|
-
with self._get_cursor(txn_conn) as cursor:
|
|
314
|
-
for statement in statements:
|
|
315
|
-
if statement.strip():
|
|
316
|
-
# Validate each statement unless warnings suppressed
|
|
317
|
-
if not suppress_warnings:
|
|
318
|
-
# Run validation through pipeline
|
|
319
|
-
temp_sql = SQL(statement, config=self.config)
|
|
320
|
-
temp_sql._ensure_processed()
|
|
321
|
-
# Validation errors are logged as warnings by default
|
|
322
|
-
|
|
323
|
-
rows = self._execute_single_script_statement(cursor, statement)
|
|
324
|
-
executed_count += 1
|
|
325
|
-
total_rows += rows
|
|
326
|
-
|
|
327
|
-
return SQLResult(
|
|
328
|
-
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
329
|
-
data=[],
|
|
330
|
-
rows_affected=total_rows,
|
|
331
|
-
operation_type="SCRIPT",
|
|
332
|
-
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
333
|
-
total_statements=executed_count,
|
|
334
|
-
successful_statements=executed_count,
|
|
335
|
-
)
|
|
412
|
+
except Exception:
|
|
413
|
+
self._handle_postgres_rollback(cursor)
|
|
414
|
+
logger.exception("ADBC executemany failed")
|
|
415
|
+
raise
|
|
336
416
|
|
|
337
|
-
|
|
338
|
-
"""Execute a single statement from a script and handle errors.
|
|
417
|
+
return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
|
|
339
418
|
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
419
|
+
def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
|
|
420
|
+
"""Execute single SQL statement with ADBC-specific data handling."""
|
|
421
|
+
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
343
422
|
|
|
344
|
-
Returns:
|
|
345
|
-
Number of rows affected
|
|
346
|
-
"""
|
|
347
423
|
try:
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
# Rollback transaction on error for PostgreSQL to avoid
|
|
351
|
-
# "current transaction is aborted" errors
|
|
352
|
-
if self.dialect == "postgres":
|
|
353
|
-
with contextlib.suppress(Exception):
|
|
354
|
-
cursor.execute("ROLLBACK")
|
|
355
|
-
raise e from e
|
|
356
|
-
else:
|
|
357
|
-
return cursor.rowcount or 0
|
|
424
|
+
postgres_compatible_params = self._handle_postgres_empty_parameters(prepared_parameters)
|
|
425
|
+
cursor.execute(sql, parameters=postgres_compatible_params)
|
|
358
426
|
|
|
359
|
-
|
|
360
|
-
|
|
427
|
+
except Exception:
|
|
428
|
+
self._handle_postgres_rollback(cursor)
|
|
429
|
+
raise
|
|
361
430
|
|
|
362
|
-
|
|
363
|
-
|
|
431
|
+
if statement.returns_rows():
|
|
432
|
+
fetched_data = cursor.fetchall()
|
|
433
|
+
column_names = [col[0] for col in cursor.description or []]
|
|
364
434
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
435
|
+
if fetched_data and isinstance(fetched_data[0], tuple):
|
|
436
|
+
dict_data: list[dict[Any, Any]] = [dict(zip(column_names, row)) for row in fetched_data]
|
|
437
|
+
else:
|
|
438
|
+
dict_data = fetched_data # type: ignore[assignment]
|
|
439
|
+
|
|
440
|
+
return self.create_execution_result(
|
|
441
|
+
cursor,
|
|
442
|
+
selected_data=cast("list[dict[str, Any]]", dict_data),
|
|
443
|
+
column_names=column_names,
|
|
444
|
+
data_row_count=len(dict_data),
|
|
445
|
+
is_select_result=True,
|
|
446
|
+
)
|
|
369
447
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
"""
|
|
373
|
-
self._ensure_pyarrow_installed()
|
|
374
|
-
conn = self._connection(connection)
|
|
448
|
+
row_count = cursor.rowcount if cursor.rowcount is not None else -1
|
|
449
|
+
return self.create_execution_result(cursor, rowcount_override=row_count)
|
|
375
450
|
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
return ArrowResult(statement=sql, data=arrow_table)
|
|
451
|
+
def _execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult":
|
|
452
|
+
"""Execute SQL script with ADBC-specific transaction handling."""
|
|
453
|
+
if statement.is_script:
|
|
454
|
+
sql = statement._raw_sql
|
|
455
|
+
prepared_parameters: list[Any] = []
|
|
456
|
+
else:
|
|
457
|
+
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
384
458
|
|
|
385
|
-
|
|
386
|
-
"""ADBC-optimized Arrow table ingestion using native bulk insert.
|
|
459
|
+
statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True)
|
|
387
460
|
|
|
388
|
-
|
|
389
|
-
|
|
461
|
+
successful_count = 0
|
|
462
|
+
last_rowcount = 0
|
|
390
463
|
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
464
|
+
try:
|
|
465
|
+
for stmt in statements:
|
|
466
|
+
if prepared_parameters:
|
|
467
|
+
postgres_compatible_params = self._handle_postgres_empty_parameters(prepared_parameters)
|
|
468
|
+
cursor.execute(stmt, parameters=postgres_compatible_params)
|
|
469
|
+
else:
|
|
470
|
+
cursor.execute(stmt)
|
|
471
|
+
successful_count += 1
|
|
472
|
+
if cursor.rowcount is not None:
|
|
473
|
+
last_rowcount = cursor.rowcount
|
|
474
|
+
except Exception:
|
|
475
|
+
self._handle_postgres_rollback(cursor)
|
|
476
|
+
logger.exception("ADBC script execution failed")
|
|
477
|
+
raise
|
|
478
|
+
|
|
479
|
+
return self.create_execution_result(
|
|
480
|
+
cursor,
|
|
481
|
+
statement_count=len(statements),
|
|
482
|
+
successful_statements=successful_count,
|
|
483
|
+
rowcount_override=last_rowcount,
|
|
484
|
+
is_script_result=True,
|
|
485
|
+
)
|
|
486
|
+
|
|
487
|
+
def begin(self) -> None:
|
|
488
|
+
"""Begin database transaction."""
|
|
489
|
+
try:
|
|
490
|
+
with self.with_cursor(self.connection) as cursor:
|
|
491
|
+
cursor.execute("BEGIN")
|
|
492
|
+
except Exception as e:
|
|
493
|
+
msg = f"Failed to begin ADBC transaction: {e}"
|
|
494
|
+
raise SQLSpecError(msg) from e
|
|
396
495
|
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
def _connection(self, connection: Optional["AdbcConnection"] = None) -> "AdbcConnection":
|
|
416
|
-
"""Get the connection to use for the operation."""
|
|
417
|
-
return connection or self.connection
|
|
496
|
+
def rollback(self) -> None:
|
|
497
|
+
"""Rollback database transaction."""
|
|
498
|
+
try:
|
|
499
|
+
with self.with_cursor(self.connection) as cursor:
|
|
500
|
+
cursor.execute("ROLLBACK")
|
|
501
|
+
except Exception as e:
|
|
502
|
+
msg = f"Failed to rollback ADBC transaction: {e}"
|
|
503
|
+
raise SQLSpecError(msg) from e
|
|
504
|
+
|
|
505
|
+
def commit(self) -> None:
|
|
506
|
+
"""Commit database transaction."""
|
|
507
|
+
try:
|
|
508
|
+
with self.with_cursor(self.connection) as cursor:
|
|
509
|
+
cursor.execute("COMMIT")
|
|
510
|
+
except Exception as e:
|
|
511
|
+
msg = f"Failed to commit ADBC transaction: {e}"
|
|
512
|
+
raise SQLSpecError(msg) from e
|