sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the contents of publicly released versions of the package as they appear in their public registries; it is provided for informational purposes only.
Potentially problematic release: this version of sqlspec might be problematic.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -644
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -462
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +217 -451
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +418 -498
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +592 -634
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +393 -436
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +549 -942
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -550
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +732 -733
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +243 -426
- sqlspec/base.py +220 -825
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -330
- sqlspec/mixins.py +0 -306
- sqlspec/statement.py +0 -378
- sqlspec-0.11.0.dist-info/RECORD +0 -69
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
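The changed-file list reflects a broad reorganization between 0.11.0 and 0.12.0: the monolithic sqlspec/base.py, sqlspec/statement.py, sqlspec/filters.py, and sqlspec/mixins.py, together with the per-adapter config/ sub-packages, give way to dedicated sqlspec/driver, sqlspec/statement, and sqlspec/storage packages. As a rough orientation sketch, the import paths below are only the ones that appear in the adapter diff further down; they illustrate the new layout and are not a complete or authoritative 0.12.0 API:

    # Illustrative only: import paths taken from the adbc driver diff below.
    from sqlspec.driver import SyncDriverAdapterProtocol
    from sqlspec.driver.mixins import SQLTranslatorMixin, SyncStorageMixin
    from sqlspec.statement.parameters import ParameterStyle
    from sqlspec.statement.result import ArrowResult, SQLResult
    from sqlspec.statement.sql import SQL, SQLConfig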
sqlspec/adapters/adbc/driver.py
CHANGED
@@ -1,22 +1,29 @@
 import contextlib
 import logging
-import …
-from collections.abc import Generator, Sequence
+from collections.abc import Iterator
 from contextlib import contextmanager
-from …
+from decimal import Decimal
+from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

 from adbc_driver_manager.dbapi import Connection, Cursor
-from sqlglot import exp as sqlglot_exp

-from sqlspec.…
-from sqlspec.…
-[… old lines 13-16 not shown in the diff viewer]
+from sqlspec.driver import SyncDriverAdapterProtocol
+from sqlspec.driver.mixins import (
+    SQLTranslatorMixin,
+    SyncPipelinedExecutionMixin,
+    SyncStorageMixin,
+    ToSchemaMixin,
+    TypeCoercionMixin,
+)
+from sqlspec.exceptions import wrap_exceptions
+from sqlspec.statement.parameters import ParameterStyle
+from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
+from sqlspec.statement.sql import SQL, SQLConfig
+from sqlspec.typing import DictRow, ModelDTOT, RowT, is_dict_with_field
+from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
-    from …
+    from sqlglot.dialects.dialect import DialectType

 __all__ = ("AdbcConnection", "AdbcDriver")

@@ -24,30 +31,67 @@ logger = logging.getLogger("sqlspec")

 AdbcConnection = Connection

-# SQLite named parameter pattern - simple pattern to find parameter references
-SQLITE_PARAM_PATTERN = re.compile(r"(?::|\$|@)([a-zA-Z0-9_]+)")
-
-# Patterns to identify comments and string literals
-SQL_COMMENT_PATTERN = re.compile(r"--[^\n]*|/\*.*?\*/", re.DOTALL)
-SQL_STRING_PATTERN = re.compile(r"'[^']*'|\"[^\"]*\"")
-

 class AdbcDriver(
-[… old lines 36-39 (previous base classes; only "SQLTranslatorMixin" is visible) not shown in full]
+    SyncDriverAdapterProtocol["AdbcConnection", RowT],
+    SQLTranslatorMixin,
+    TypeCoercionMixin,
+    SyncStorageMixin,
+    SyncPipelinedExecutionMixin,
+    ToSchemaMixin,
 ):
-    """ADBC Sync Driver Adapter.
-[… old lines 42-50 not shown]
+    """ADBC Sync Driver Adapter with modern architecture.
+
+    ADBC (Arrow Database Connectivity) provides a universal interface for connecting
+    to multiple database systems with high-performance Arrow-native data transfer.
+
+    This driver provides:
+    - Universal connectivity across database backends (PostgreSQL, SQLite, DuckDB, etc.)
+    - High-performance Arrow data streaming and bulk operations
+    - Intelligent dialect detection and parameter style handling
+    - Seamless integration with cloud databases (BigQuery, Snowflake)
+    - Driver manager abstraction for easy multi-database support
+    """
+
+    supports_native_arrow_import: ClassVar[bool] = True
+    supports_native_arrow_export: ClassVar[bool] = True
+    supports_native_parquet_export: ClassVar[bool] = False  # Not implemented yet
+    supports_native_parquet_import: ClassVar[bool] = True
+    __slots__ = ("default_parameter_style", "dialect", "supported_parameter_styles")
+
+    def __init__(
+        self,
+        connection: "AdbcConnection",
+        config: "Optional[SQLConfig]" = None,
+        default_row_type: "type[DictRow]" = DictRow,
+    ) -> None:
+        super().__init__(connection=connection, config=config, default_row_type=default_row_type)
+        self.dialect: DialectType = self._get_dialect(connection)
+        self.default_parameter_style = self._get_parameter_style_for_dialect(self.dialect)
+        # Override supported parameter styles based on actual dialect capabilities
+        self.supported_parameter_styles = self._get_supported_parameter_styles_for_dialect(self.dialect)
+
+    def _coerce_boolean(self, value: Any) -> Any:
+        """ADBC boolean handling varies by underlying driver."""
+        return value
+
+    def _coerce_decimal(self, value: Any) -> Any:
+        """ADBC decimal handling varies by underlying driver."""
+        if isinstance(value, str):
+            return Decimal(value)
+        return value
+
+    def _coerce_json(self, value: Any) -> Any:
+        """ADBC JSON handling varies by underlying driver."""
+        if self.dialect == "sqlite" and isinstance(value, (dict, list)):
+            return to_json(value)
+        return value
+
+    def _coerce_array(self, value: Any) -> Any:
+        """ADBC array handling varies by underlying driver."""
+        if self.dialect == "sqlite" and isinstance(value, (list, tuple)):
+            return to_json(list(value))
+        return value

     @staticmethod
     def _get_dialect(connection: "AdbcConnection") -> str:
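Based on the constructor shown in the hunk above, wiring the driver to an ADBC connection might look like the sketch below. This is a hedged example: only AdbcDriver.__init__, SQLConfig, and the dialect/parameter-style attributes come from this diff; the adbc_driver_sqlite package and a zero-argument SQLConfig() are assumptions.

    # Hypothetical construction sketch, not taken from sqlspec documentation.
    import adbc_driver_sqlite.dbapi  # assumed ADBC driver package

    from sqlspec.adapters.adbc.driver import AdbcDriver
    from sqlspec.statement.sql import SQLConfig

    conn = adbc_driver_sqlite.dbapi.connect(":memory:")
    driver = AdbcDriver(connection=conn, config=SQLConfig())  # SQLConfig() assumed to default-construct

    # __init__ detects the dialect from the connection and derives parameter styles from it;
    # per the dialect maps in the following hunk, "sqlite" maps to ParameterStyle.QMARK.
    print(driver.dialect, driver.default_parameter_style, driver.supported_parameter_styles)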
@@ -59,644 +103,286 @@ class AdbcDriver(
         Returns:
             The database dialect.
         """
-[… old lines 62-75 (previous _get_dialect body) not shown]
+        try:
+            driver_info = connection.adbc_get_info()
+            vendor_name = driver_info.get("vendor_name", "").lower()
+            driver_name = driver_info.get("driver_name", "").lower()
+
+            if "postgres" in vendor_name or "postgresql" in driver_name:
+                return "postgres"
+            if "bigquery" in vendor_name or "bigquery" in driver_name:
+                return "bigquery"
+            if "sqlite" in vendor_name or "sqlite" in driver_name:
+                return "sqlite"
+            if "duckdb" in vendor_name or "duckdb" in driver_name:
+                return "duckdb"
+            if "mysql" in vendor_name or "mysql" in driver_name:
+                return "mysql"
+            if "snowflake" in vendor_name or "snowflake" in driver_name:
+                return "snowflake"
+            if "flight" in driver_name or "flightsql" in driver_name:
+                return "sqlite"
+        except Exception:
+            logger.warning("Could not reliably determine ADBC dialect from driver info. Defaulting to 'postgres'.")
+        return "postgres"
+
+    @staticmethod
+    def _get_parameter_style_for_dialect(dialect: str) -> ParameterStyle:
+        """Get the parameter style for a given dialect."""
+        dialect_style_map = {
+            "postgres": ParameterStyle.NUMERIC,
+            "postgresql": ParameterStyle.NUMERIC,
+            "bigquery": ParameterStyle.NAMED_AT,
+            "sqlite": ParameterStyle.QMARK,
+            "duckdb": ParameterStyle.QMARK,
+            "mysql": ParameterStyle.POSITIONAL_PYFORMAT,
+            "snowflake": ParameterStyle.QMARK,
+        }
+        return dialect_style_map.get(dialect, ParameterStyle.QMARK)

     @staticmethod
-    def …
-…
+    def _get_supported_parameter_styles_for_dialect(dialect: str) -> "tuple[ParameterStyle, ...]":
+        """Get the supported parameter styles for a given dialect.
+
+        Each ADBC driver supports different parameter styles based on the underlying database.
+        """
+        dialect_supported_styles_map = {
+            "postgres": (ParameterStyle.NUMERIC,),  # PostgreSQL only supports $1, $2, $3
+            "postgresql": (ParameterStyle.NUMERIC,),
+            "bigquery": (ParameterStyle.NAMED_AT,),  # BigQuery only supports @param
+            "sqlite": (ParameterStyle.QMARK,),  # ADBC SQLite only supports ? (not :param)
+            "duckdb": (ParameterStyle.QMARK, ParameterStyle.NUMERIC),  # DuckDB supports ? and $1
+            "mysql": (ParameterStyle.POSITIONAL_PYFORMAT,),  # MySQL only supports %s
+            "snowflake": (ParameterStyle.QMARK, ParameterStyle.NUMERIC),  # Snowflake supports ? and :1
+        }
+        return dialect_supported_styles_map.get(dialect, (ParameterStyle.QMARK,))

+    @staticmethod
     @contextmanager
-    def …
-        cursor = …
+    def _get_cursor(connection: "AdbcConnection") -> Iterator["Cursor"]:
+        cursor = connection.cursor()
         try:
             yield cursor
         finally:
             with contextlib.suppress(Exception):
                 cursor.close()  # type: ignore[no-untyped-call]

-    def …
-        self,
-[… old lines 92-150 (previous parameter pre-processing: comment/string masking, SQLITE_PARAM_PATTERN rewriting to "?", keyword/positional merging, ParameterStyleMismatchError) not shown in full]
+    def _execute_statement(
+        self, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
+    ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
+        if statement.is_script:
+            sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
+            return self._execute_script(sql, connection=connection, **kwargs)
+
+        # Determine if we need to convert parameter style
+        detected_styles = {p.style for p in statement.parameter_info}
+        target_style = self.default_parameter_style
+        unsupported_styles = detected_styles - set(self.supported_parameter_styles)
+
+        if unsupported_styles:
+            target_style = self.default_parameter_style
+        elif detected_styles:
+            for style in detected_styles:
+                if style in self.supported_parameter_styles:
+                    target_style = style
+                    break
+
+        sql, params = statement.compile(placeholder_style=target_style)
+        params = self._process_parameters(params)
+        if statement.is_many:
+            return self._execute_many(sql, params, connection=connection, **kwargs)
+
+        return self._execute(sql, params, statement, connection=connection, **kwargs)
+
+    def _execute(
+        self, sql: str, parameters: Any, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
+    ) -> Union[SelectResultDict, DMLResultDict]:
+        conn = self._connection(connection)
+        with self._get_cursor(conn) as cursor:
+            # ADBC expects parameters as a list for most drivers
+            if parameters is not None and not isinstance(parameters, (list, tuple)):
+                cursor_params = [parameters]
             else:
-[… old lines 152-281 (sqlglot-based rewriting of named Parameter/Placeholder nodes to "?" placeholders and dict-parameter ordering; raised SQLParsingError when the SQL could not be parsed) not shown in full]
-[… old lines 282-572 (removed public methods with @overload stubs and docstrings: select, select_one, select_one_or_none, select_value, select_value_or_none, and the start of insert_update_delete) not shown in full]
+                cursor_params = parameters  # type: ignore[assignment]
+
+            try:
+                cursor.execute(sql, cursor_params or [])
+            except Exception as e:
+                # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
+                if self.dialect == "postgres":
+                    with contextlib.suppress(Exception):
+                        cursor.execute("ROLLBACK")
+                raise e from e
+
+            if self.returns_rows(statement.expression):
+                fetched_data = cursor.fetchall()
+                column_names = [col[0] for col in cursor.description or []]
+                result: SelectResultDict = {
+                    "data": fetched_data,
+                    "column_names": column_names,
+                    "rows_affected": len(fetched_data),
+                }
+                return result
+
+            dml_result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
+            return dml_result
+
+    def _execute_many(
+        self, sql: str, param_list: Any, connection: Optional["AdbcConnection"] = None, **kwargs: Any
+    ) -> DMLResultDict:
+        conn = self._connection(connection)
+        with self._get_cursor(conn) as cursor:
+            try:
+                cursor.executemany(sql, param_list or [])
+            except Exception as e:
+                if self.dialect == "postgres":
+                    with contextlib.suppress(Exception):
+                        cursor.execute("ROLLBACK")
+                # Always re-raise the original exception
+                raise e from e
+
+            result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
+            return result
+
+    def _execute_script(
+        self, script: str, connection: Optional["AdbcConnection"] = None, **kwargs: Any
+    ) -> ScriptResultDict:
+        conn = self._connection(connection)
+        # ADBC drivers don't support multiple statements in a single execute
+        # Use the shared implementation to split the script
+        statements = self._split_script_statements(script)
+
+        executed_count = 0
+        with self._get_cursor(conn) as cursor:
+            for statement in statements:
+                executed_count += self._execute_single_script_statement(cursor, statement)
+
+        result: ScriptResultDict = {"statements_executed": executed_count, "status_message": "SCRIPT EXECUTED"}
+        return result
+
+    def _execute_single_script_statement(self, cursor: "Cursor", statement: str) -> int:
+        """Execute a single statement from a script and handle errors.

         Args:
-…
-            *filters: Statement filters to apply.
-            connection: Optional connection override.
-            **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
+            cursor: The database cursor
+            statement: The SQL statement to execute

         Returns:
-…
+            1 if successful, 0 if failed
         """
-[… old lines 584-652 not shown]
-        **kwargs: Any,
-    ) -> str:
-        """Execute a SQL script.
+        try:
+            cursor.execute(statement)
+        except Exception as e:
+            # Rollback transaction on error for PostgreSQL
+            if self.dialect == "postgres":
+                with contextlib.suppress(Exception):
+                    cursor.execute("ROLLBACK")
+            raise e from e
+        else:
+            return 1
+
+    def _wrap_select_result(
+        self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
+    ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
+        # result must be a dict with keys: data, column_names, rows_affected
+
+        rows_as_dicts = [dict(zip(result["column_names"], row)) for row in result["data"]]
+
+        if schema_type:
+            return SQLResult[ModelDTOT](
+                statement=statement,
+                data=list(self.to_schema(data=rows_as_dicts, schema_type=schema_type)),
+                column_names=result["column_names"],
+                rows_affected=result["rows_affected"],
+                operation_type="SELECT",
+            )
+        return SQLResult[RowT](
+            statement=statement,
+            data=rows_as_dicts,
+            column_names=result["column_names"],
+            rows_affected=result["rows_affected"],
+            operation_type="SELECT",
+        )
+
+    def _wrap_execute_result(
+        self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
+    ) -> SQLResult[RowT]:
+        operation_type = (
+            str(statement.expression.key).upper()
+            if statement.expression and hasattr(statement.expression, "key")
+            else "UNKNOWN"
+        )
+
+        # Handle TypedDict results
+        if is_dict_with_field(result, "statements_executed"):
+            return SQLResult[RowT](
+                statement=statement,
+                data=[],
+                rows_affected=0,
+                total_statements=result["statements_executed"],
+                operation_type="SCRIPT",  # Scripts always have operation_type SCRIPT
+                metadata={"status_message": result["status_message"]},
+            )
+        if is_dict_with_field(result, "rows_affected"):
+            return SQLResult[RowT](
+                statement=statement,
+                data=[],
+                rows_affected=result["rows_affected"],
+                operation_type=operation_type,
+                metadata={"status_message": result["status_message"]},
+            )
+        msg = f"Unexpected result type: {type(result)}"
+        raise ValueError(msg)
+
+    def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
+        """ADBC native Arrow table fetching.
+
+        ADBC has excellent native Arrow support through cursor.fetch_arrow_table()
+        This provides zero-copy data transfer for optimal performance.

         Args:
-            sql: …
-…
-            connection: Optional connection override.
-            **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
+            sql: Processed SQL object
+            connection: Optional connection override
+            **kwargs: Additional options (e.g., batch_size for streaming)

         Returns:
-…
+            ArrowResult with native Arrow table
         """
-…
+        self._ensure_pyarrow_installed()
+        conn = self._connection(connection)

-        with self.…
-…
+        with wrap_exceptions(), self._get_cursor(conn) as cursor:
+            # Execute the query
+            params = sql.get_parameters(style=self.default_parameter_style)
+            # ADBC expects parameters as a list for most drivers
+            cursor_params = [params] if params is not None and not isinstance(params, (list, tuple)) else params
+            cursor.execute(sql.to_sql(placeholder_style=self.default_parameter_style), cursor_params or [])
+            arrow_table = cursor.fetch_arrow_table()
+            return ArrowResult(statement=sql, data=arrow_table)

-…
+    def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
+        """ADBC-optimized Arrow table ingestion using native bulk insert.

-…
-        sql: str,
-        parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *filters: StatementFilter,
-        connection: "Optional[AdbcConnection]" = None,
-        **kwargs: Any,
-    ) -> "ArrowTable":  # pyright: ignore[reportUnknownVariableType]
-        """Execute a SQL query and return results as an Apache Arrow Table.
+        ADBC drivers often support native Arrow table ingestion for high-performance
+        bulk loading operations.

         Args:
-…
-            **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
+            table: Arrow table to ingest
+            table_name: Target database table name
+            mode: Ingestion mode ('append', 'replace', 'create')
+            **options: Additional ADBC-specific options

         Returns:
-…
+            Number of rows ingested
         """
-…
-        with self.…
-…
+        self._ensure_pyarrow_installed()
+
+        conn = self._connection(None)
+        with self._get_cursor(conn) as cursor:
+            # Handle different modes
+            if mode == "replace":
+                cursor.execute(SQL(f"TRUNCATE TABLE {table_name}").to_sql(placeholder_style=ParameterStyle.STATIC))
+            elif mode == "create":
+                msg = "'create' mode is not supported for ADBC ingestion"
+                raise NotImplementedError(msg)
+            return cursor.adbc_ingest(table_name, table, mode=mode, **options)  # type: ignore[arg-type]
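The dialect maps above determine which placeholder style _execute_statement targets when it calls statement.compile(placeholder_style=...). A minimal sketch of that idea follows; only SQL, ParameterStyle, and compile(placeholder_style=...) appear in this diff, while the parameters= keyword and the exact compiled output are assumptions:

    # Sketch of the parameter-style conversion the driver performs internally.
    from sqlspec.statement.parameters import ParameterStyle
    from sqlspec.statement.sql import SQL

    stmt = SQL("SELECT * FROM users WHERE id = :id", parameters={"id": 1})  # parameters= is assumed

    # A PostgreSQL-backed ADBC connection targets NUMERIC ($1); SQLite targets QMARK (?).
    pg_sql, pg_params = stmt.compile(placeholder_style=ParameterStyle.NUMERIC)
    sqlite_sql, sqlite_params = stmt.compile(placeholder_style=ParameterStyle.QMARK)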
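_fetch_arrow_table and _ingest_arrow_table lean directly on ADBC's native Arrow path (cursor.fetch_arrow_table() and cursor.adbc_ingest()). A hedged round-trip sketch using the driver instance from the earlier construction note; these are underscore-prefixed internals, and the public entry points provided by SyncStorageMixin are not shown in this diff:

    # Hypothetical round trip built only from calls visible in this diff.
    import pyarrow as pa

    from sqlspec.statement.sql import SQL

    arrow_table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

    # Bulk-load via the native ADBC ingest path, then read the rows back as Arrow.
    rows_ingested = driver._ingest_arrow_table(arrow_table, "users", mode="append")
    result = driver._fetch_arrow_table(SQL("SELECT id, name FROM users"))
    print(rows_ingested, result.data.num_rows)  # ArrowResult.data holds the pyarrow Table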