sqlspec 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -621
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -431
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +218 -436
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +417 -487
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +600 -553
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +392 -406
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +548 -921
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -533
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +734 -694
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +242 -405
- sqlspec/base.py +220 -784
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/METADATA +97 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -331
- sqlspec/mixins.py +0 -305
- sqlspec/statement.py +0 -378
- sqlspec-0.11.1.dist-info/RECORD +0 -69
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/adbc/driver.py
CHANGED
|
@@ -1,22 +1,29 @@
|
|
|
1
1
|
import contextlib
|
|
2
2
|
import logging
|
|
3
|
-
import
|
|
4
|
-
from collections.abc import Generator, Mapping, Sequence
|
|
3
|
+
from collections.abc import Iterator
|
|
5
4
|
from contextlib import contextmanager
|
|
6
|
-
from
|
|
5
|
+
from decimal import Decimal
|
|
6
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
7
7
|
|
|
8
8
|
from adbc_driver_manager.dbapi import Connection, Cursor
|
|
9
|
-
from sqlglot import exp as sqlglot_exp
|
|
10
9
|
|
|
11
|
-
from sqlspec.
|
|
12
|
-
from sqlspec.
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
10
|
+
from sqlspec.driver import SyncDriverAdapterProtocol
|
|
11
|
+
from sqlspec.driver.mixins import (
|
|
12
|
+
SQLTranslatorMixin,
|
|
13
|
+
SyncPipelinedExecutionMixin,
|
|
14
|
+
SyncStorageMixin,
|
|
15
|
+
ToSchemaMixin,
|
|
16
|
+
TypeCoercionMixin,
|
|
17
|
+
)
|
|
18
|
+
from sqlspec.exceptions import wrap_exceptions
|
|
19
|
+
from sqlspec.statement.parameters import ParameterStyle
|
|
20
|
+
from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
|
|
21
|
+
from sqlspec.statement.sql import SQL, SQLConfig
|
|
22
|
+
from sqlspec.typing import DictRow, ModelDTOT, RowT, is_dict_with_field
|
|
23
|
+
from sqlspec.utils.serializers import to_json
|
|
17
24
|
|
|
18
25
|
if TYPE_CHECKING:
|
|
19
|
-
from
|
|
26
|
+
from sqlglot.dialects.dialect import DialectType
|
|
20
27
|
|
|
21
28
|
__all__ = ("AdbcConnection", "AdbcDriver")
|
|
22
29
|
|
|
@@ -24,30 +31,67 @@ logger = logging.getLogger("sqlspec")
|
|
|
24
31
|
|
|
25
32
|
AdbcConnection = Connection
|
|
26
33
|
|
|
27
|
-
# SQLite named parameter pattern - simple pattern to find parameter references
|
|
28
|
-
SQLITE_PARAM_PATTERN = re.compile(r"(?::|\$|@)([a-zA-Z0-9_]+)")
|
|
29
|
-
|
|
30
|
-
# Patterns to identify comments and string literals
|
|
31
|
-
SQL_COMMENT_PATTERN = re.compile(r"--[^\n]*|/\*.*?\*/", re.DOTALL)
|
|
32
|
-
SQL_STRING_PATTERN = re.compile(r"'[^']*'|\"[^\"]*\"")
|
|
33
|
-
|
|
34
34
|
|
|
35
35
|
class AdbcDriver(
|
|
36
|
-
|
|
37
|
-
SQLTranslatorMixin
|
|
38
|
-
|
|
39
|
-
|
|
36
|
+
SyncDriverAdapterProtocol["AdbcConnection", RowT],
|
|
37
|
+
SQLTranslatorMixin,
|
|
38
|
+
TypeCoercionMixin,
|
|
39
|
+
SyncStorageMixin,
|
|
40
|
+
SyncPipelinedExecutionMixin,
|
|
41
|
+
ToSchemaMixin,
|
|
40
42
|
):
|
|
41
|
-
"""ADBC Sync Driver Adapter.
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
43
|
+
"""ADBC Sync Driver Adapter with modern architecture.
|
|
44
|
+
|
|
45
|
+
ADBC (Arrow Database Connectivity) provides a universal interface for connecting
|
|
46
|
+
to multiple database systems with high-performance Arrow-native data transfer.
|
|
47
|
+
|
|
48
|
+
This driver provides:
|
|
49
|
+
- Universal connectivity across database backends (PostgreSQL, SQLite, DuckDB, etc.)
|
|
50
|
+
- High-performance Arrow data streaming and bulk operations
|
|
51
|
+
- Intelligent dialect detection and parameter style handling
|
|
52
|
+
- Seamless integration with cloud databases (BigQuery, Snowflake)
|
|
53
|
+
- Driver manager abstraction for easy multi-database support
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
supports_native_arrow_import: ClassVar[bool] = True
|
|
57
|
+
supports_native_arrow_export: ClassVar[bool] = True
|
|
58
|
+
supports_native_parquet_export: ClassVar[bool] = False # Not implemented yet
|
|
59
|
+
supports_native_parquet_import: ClassVar[bool] = True
|
|
60
|
+
__slots__ = ("default_parameter_style", "dialect", "supported_parameter_styles")
|
|
61
|
+
|
|
62
|
+
def __init__(
|
|
63
|
+
self,
|
|
64
|
+
connection: "AdbcConnection",
|
|
65
|
+
config: "Optional[SQLConfig]" = None,
|
|
66
|
+
default_row_type: "type[DictRow]" = DictRow,
|
|
67
|
+
) -> None:
|
|
68
|
+
super().__init__(connection=connection, config=config, default_row_type=default_row_type)
|
|
69
|
+
self.dialect: DialectType = self._get_dialect(connection)
|
|
70
|
+
self.default_parameter_style = self._get_parameter_style_for_dialect(self.dialect)
|
|
71
|
+
# Override supported parameter styles based on actual dialect capabilities
|
|
72
|
+
self.supported_parameter_styles = self._get_supported_parameter_styles_for_dialect(self.dialect)
|
|
73
|
+
|
|
74
|
+
def _coerce_boolean(self, value: Any) -> Any:
|
|
75
|
+
"""ADBC boolean handling varies by underlying driver."""
|
|
76
|
+
return value
|
|
77
|
+
|
|
78
|
+
def _coerce_decimal(self, value: Any) -> Any:
|
|
79
|
+
"""ADBC decimal handling varies by underlying driver."""
|
|
80
|
+
if isinstance(value, str):
|
|
81
|
+
return Decimal(value)
|
|
82
|
+
return value
|
|
83
|
+
|
|
84
|
+
def _coerce_json(self, value: Any) -> Any:
|
|
85
|
+
"""ADBC JSON handling varies by underlying driver."""
|
|
86
|
+
if self.dialect == "sqlite" and isinstance(value, (dict, list)):
|
|
87
|
+
return to_json(value)
|
|
88
|
+
return value
|
|
89
|
+
|
|
90
|
+
def _coerce_array(self, value: Any) -> Any:
|
|
91
|
+
"""ADBC array handling varies by underlying driver."""
|
|
92
|
+
if self.dialect == "sqlite" and isinstance(value, (list, tuple)):
|
|
93
|
+
return to_json(list(value))
|
|
94
|
+
return value
|
|
51
95
|
|
|
52
96
|
@staticmethod
|
|
53
97
|
def _get_dialect(connection: "AdbcConnection") -> str:
|
|
@@ -59,621 +103,286 @@ class AdbcDriver(
|
|
|
59
103
|
Returns:
|
|
60
104
|
The database dialect.
|
|
61
105
|
"""
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
106
|
+
try:
|
|
107
|
+
driver_info = connection.adbc_get_info()
|
|
108
|
+
vendor_name = driver_info.get("vendor_name", "").lower()
|
|
109
|
+
driver_name = driver_info.get("driver_name", "").lower()
|
|
110
|
+
|
|
111
|
+
if "postgres" in vendor_name or "postgresql" in driver_name:
|
|
112
|
+
return "postgres"
|
|
113
|
+
if "bigquery" in vendor_name or "bigquery" in driver_name:
|
|
114
|
+
return "bigquery"
|
|
115
|
+
if "sqlite" in vendor_name or "sqlite" in driver_name:
|
|
116
|
+
return "sqlite"
|
|
117
|
+
if "duckdb" in vendor_name or "duckdb" in driver_name:
|
|
118
|
+
return "duckdb"
|
|
119
|
+
if "mysql" in vendor_name or "mysql" in driver_name:
|
|
120
|
+
return "mysql"
|
|
121
|
+
if "snowflake" in vendor_name or "snowflake" in driver_name:
|
|
122
|
+
return "snowflake"
|
|
123
|
+
if "flight" in driver_name or "flightsql" in driver_name:
|
|
124
|
+
return "sqlite"
|
|
125
|
+
except Exception:
|
|
126
|
+
logger.warning("Could not reliably determine ADBC dialect from driver info. Defaulting to 'postgres'.")
|
|
127
|
+
return "postgres"
|
|
128
|
+
|
|
129
|
+
@staticmethod
|
|
130
|
+
def _get_parameter_style_for_dialect(dialect: str) -> ParameterStyle:
|
|
131
|
+
"""Get the parameter style for a given dialect."""
|
|
132
|
+
dialect_style_map = {
|
|
133
|
+
"postgres": ParameterStyle.NUMERIC,
|
|
134
|
+
"postgresql": ParameterStyle.NUMERIC,
|
|
135
|
+
"bigquery": ParameterStyle.NAMED_AT,
|
|
136
|
+
"sqlite": ParameterStyle.QMARK,
|
|
137
|
+
"duckdb": ParameterStyle.QMARK,
|
|
138
|
+
"mysql": ParameterStyle.POSITIONAL_PYFORMAT,
|
|
139
|
+
"snowflake": ParameterStyle.QMARK,
|
|
140
|
+
}
|
|
141
|
+
return dialect_style_map.get(dialect, ParameterStyle.QMARK)
|
|
76
142
|
|
|
77
143
|
@staticmethod
|
|
78
|
-
def
|
|
79
|
-
|
|
144
|
+
def _get_supported_parameter_styles_for_dialect(dialect: str) -> "tuple[ParameterStyle, ...]":
|
|
145
|
+
"""Get the supported parameter styles for a given dialect.
|
|
146
|
+
|
|
147
|
+
Each ADBC driver supports different parameter styles based on the underlying database.
|
|
148
|
+
"""
|
|
149
|
+
dialect_supported_styles_map = {
|
|
150
|
+
"postgres": (ParameterStyle.NUMERIC,), # PostgreSQL only supports $1, $2, $3
|
|
151
|
+
"postgresql": (ParameterStyle.NUMERIC,),
|
|
152
|
+
"bigquery": (ParameterStyle.NAMED_AT,), # BigQuery only supports @param
|
|
153
|
+
"sqlite": (ParameterStyle.QMARK,), # ADBC SQLite only supports ? (not :param)
|
|
154
|
+
"duckdb": (ParameterStyle.QMARK, ParameterStyle.NUMERIC), # DuckDB supports ? and $1
|
|
155
|
+
"mysql": (ParameterStyle.POSITIONAL_PYFORMAT,), # MySQL only supports %s
|
|
156
|
+
"snowflake": (ParameterStyle.QMARK, ParameterStyle.NUMERIC), # Snowflake supports ? and :1
|
|
157
|
+
}
|
|
158
|
+
return dialect_supported_styles_map.get(dialect, (ParameterStyle.QMARK,))
|
|
80
159
|
|
|
160
|
+
@staticmethod
|
|
81
161
|
@contextmanager
|
|
82
|
-
def
|
|
83
|
-
cursor =
|
|
162
|
+
def _get_cursor(connection: "AdbcConnection") -> Iterator["Cursor"]:
|
|
163
|
+
cursor = connection.cursor()
|
|
84
164
|
try:
|
|
85
165
|
yield cursor
|
|
86
166
|
finally:
|
|
87
167
|
with contextlib.suppress(Exception):
|
|
88
168
|
cursor.close() # type: ignore[no-untyped-call]
|
|
89
169
|
|
|
90
|
-
def
|
|
91
|
-
self,
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
170
|
+
def _execute_statement(
|
|
171
|
+
self, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
172
|
+
) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
|
|
173
|
+
if statement.is_script:
|
|
174
|
+
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
175
|
+
return self._execute_script(sql, connection=connection, **kwargs)
|
|
176
|
+
|
|
177
|
+
# Determine if we need to convert parameter style
|
|
178
|
+
detected_styles = {p.style for p in statement.parameter_info}
|
|
179
|
+
target_style = self.default_parameter_style
|
|
180
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
181
|
+
|
|
182
|
+
if unsupported_styles:
|
|
183
|
+
target_style = self.default_parameter_style
|
|
184
|
+
elif detected_styles:
|
|
185
|
+
for style in detected_styles:
|
|
186
|
+
if style in self.supported_parameter_styles:
|
|
187
|
+
target_style = style
|
|
188
|
+
break
|
|
189
|
+
|
|
190
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
191
|
+
params = self._process_parameters(params)
|
|
192
|
+
if statement.is_many:
|
|
193
|
+
return self._execute_many(sql, params, connection=connection, **kwargs)
|
|
194
|
+
|
|
195
|
+
return self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
196
|
+
|
|
197
|
+
def _execute(
|
|
198
|
+
self, sql: str, parameters: Any, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
199
|
+
) -> Union[SelectResultDict, DMLResultDict]:
|
|
200
|
+
conn = self._connection(connection)
|
|
201
|
+
with self._get_cursor(conn) as cursor:
|
|
202
|
+
# ADBC expects parameters as a list for most drivers
|
|
203
|
+
if parameters is not None and not isinstance(parameters, (list, tuple)):
|
|
204
|
+
cursor_params = [parameters]
|
|
122
205
|
else:
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
param_spans = [] # Store (start, end) of each parameter
|
|
183
|
-
|
|
184
|
-
for match in SQLITE_PARAM_PATTERN.finditer(masked_sql):
|
|
185
|
-
param_name = match.group(1)
|
|
186
|
-
if param_name in processed_params:
|
|
187
|
-
param_order.append(param_name)
|
|
188
|
-
param_spans.append((match.start(), match.end()))
|
|
189
|
-
|
|
190
|
-
if param_order:
|
|
191
|
-
# Replace parameters with ? placeholders in reverse order to preserve positions
|
|
192
|
-
result_sql = processed_sql
|
|
193
|
-
for i, (start, end) in enumerate(reversed(param_spans)): # noqa: B007
|
|
194
|
-
# Replace :param with ?
|
|
195
|
-
result_sql = result_sql[:start] + "?" + result_sql[start + 1 + len(param_order[-(i + 1)]) :]
|
|
196
|
-
|
|
197
|
-
return result_sql, tuple(processed_params[name] for name in param_order)
|
|
198
|
-
|
|
199
|
-
if processed_params is None:
|
|
200
|
-
return processed_sql, ()
|
|
201
|
-
if (
|
|
202
|
-
isinstance(processed_params, (tuple, list))
|
|
203
|
-
or (processed_params is not None and not isinstance(processed_params, dict))
|
|
204
|
-
) and parsed_expr is not None:
|
|
205
|
-
# Find all named placeholders
|
|
206
|
-
named_param_nodes = [
|
|
207
|
-
node
|
|
208
|
-
for node in parsed_expr.find_all(sqlglot_exp.Parameter, sqlglot_exp.Placeholder)
|
|
209
|
-
if (isinstance(node, sqlglot_exp.Parameter) and node.name and not node.name.isdigit())
|
|
210
|
-
or (
|
|
211
|
-
isinstance(node, sqlglot_exp.Placeholder)
|
|
212
|
-
and node.this
|
|
213
|
-
and not isinstance(node.this, (sqlglot_exp.Identifier, sqlglot_exp.Literal))
|
|
214
|
-
and not str(node.this).isdigit()
|
|
215
|
-
)
|
|
216
|
-
]
|
|
217
|
-
|
|
218
|
-
# If we found named parameters, transform to question marks
|
|
219
|
-
if named_param_nodes:
|
|
220
|
-
|
|
221
|
-
def convert_to_qmark(node: sqlglot_exp.Expression) -> sqlglot_exp.Expression:
|
|
222
|
-
if (isinstance(node, sqlglot_exp.Parameter) and node.name and not node.name.isdigit()) or (
|
|
223
|
-
isinstance(node, sqlglot_exp.Placeholder)
|
|
224
|
-
and node.this
|
|
225
|
-
and not isinstance(node.this, (sqlglot_exp.Identifier, sqlglot_exp.Literal))
|
|
226
|
-
and not str(node.this).isdigit()
|
|
227
|
-
):
|
|
228
|
-
return sqlglot_exp.Placeholder()
|
|
229
|
-
return node
|
|
230
|
-
|
|
231
|
-
# Transform the SQL
|
|
232
|
-
processed_sql = parsed_expr.transform(convert_to_qmark, copy=True).sql(dialect=self.dialect)
|
|
233
|
-
|
|
234
|
-
# If it's a scalar parameter, ensure it's wrapped in a tuple
|
|
235
|
-
if not isinstance(processed_params, (tuple, list)):
|
|
236
|
-
processed_params = (processed_params,) # type: ignore[unreachable]
|
|
237
|
-
|
|
238
|
-
# 6. Handle dictionary parameters
|
|
239
|
-
if is_dict(processed_params):
|
|
240
|
-
# Skip conversion if there's no parsed expression to work with
|
|
241
|
-
if parsed_expr is None:
|
|
242
|
-
msg = f"ADBC ({self.dialect}): Failed to parse SQL with dictionary parameters. Cannot determine parameter order."
|
|
243
|
-
raise SQLParsingError(msg)
|
|
244
|
-
|
|
245
|
-
# Collect named parameters in the order they appear in the SQL
|
|
246
|
-
named_params = []
|
|
247
|
-
for node in parsed_expr.find_all(sqlglot_exp.Parameter, sqlglot_exp.Placeholder):
|
|
248
|
-
if isinstance(node, sqlglot_exp.Parameter) and node.name and node.name in processed_params:
|
|
249
|
-
named_params.append(node.name) # type: ignore[arg-type]
|
|
250
|
-
elif (
|
|
251
|
-
isinstance(node, sqlglot_exp.Placeholder)
|
|
252
|
-
and isinstance(node.this, str)
|
|
253
|
-
and node.this in processed_params
|
|
254
|
-
):
|
|
255
|
-
named_params.append(node.this) # type: ignore[arg-type]
|
|
256
|
-
|
|
257
|
-
# If we found named parameters, convert them to ? placeholders
|
|
258
|
-
if named_params:
|
|
259
|
-
# Transform SQL to use ? placeholders
|
|
260
|
-
def convert_to_qmark(node: sqlglot_exp.Expression) -> sqlglot_exp.Expression:
|
|
261
|
-
if isinstance(node, sqlglot_exp.Parameter) and node.name and node.name in processed_params:
|
|
262
|
-
return sqlglot_exp.Placeholder() # Anonymous ? placeholder
|
|
263
|
-
if (
|
|
264
|
-
isinstance(node, sqlglot_exp.Placeholder)
|
|
265
|
-
and isinstance(node.this, str)
|
|
266
|
-
and node.this in processed_params
|
|
267
|
-
):
|
|
268
|
-
return sqlglot_exp.Placeholder() # Anonymous ? placeholder
|
|
269
|
-
return node
|
|
270
|
-
|
|
271
|
-
return parsed_expr.transform(convert_to_qmark, copy=True).sql(dialect=self.dialect), tuple(
|
|
272
|
-
processed_params[name] # type: ignore[index]
|
|
273
|
-
for name in named_params
|
|
274
|
-
)
|
|
275
|
-
return processed_sql, tuple(processed_params.values())
|
|
276
|
-
if isinstance(processed_params, (list, tuple)):
|
|
277
|
-
return processed_sql, tuple(processed_params)
|
|
278
|
-
return processed_sql, (processed_params,)
|
|
279
|
-
|
|
280
|
-
@overload
|
|
281
|
-
def select(
|
|
282
|
-
self,
|
|
283
|
-
sql: str,
|
|
284
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
285
|
-
*filters: "StatementFilter",
|
|
286
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
287
|
-
schema_type: None = None,
|
|
288
|
-
**kwargs: Any,
|
|
289
|
-
) -> "Sequence[dict[str, Any]]": ...
|
|
290
|
-
@overload
|
|
291
|
-
def select(
|
|
292
|
-
self,
|
|
293
|
-
sql: str,
|
|
294
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
295
|
-
*filters: "StatementFilter",
|
|
296
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
297
|
-
schema_type: "type[ModelDTOT]",
|
|
298
|
-
**kwargs: Any,
|
|
299
|
-
) -> "Sequence[ModelDTOT]": ...
|
|
300
|
-
def select(
|
|
301
|
-
self,
|
|
302
|
-
sql: str,
|
|
303
|
-
parameters: Optional["StatementParameterType"] = None,
|
|
304
|
-
*filters: "StatementFilter",
|
|
305
|
-
connection: Optional["AdbcConnection"] = None,
|
|
306
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
307
|
-
**kwargs: Any,
|
|
308
|
-
) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]":
|
|
309
|
-
"""Fetch data from the database.
|
|
310
|
-
|
|
311
|
-
Args:
|
|
312
|
-
sql: The SQL query string.
|
|
313
|
-
parameters: The parameters for the query (dict, tuple, list, or None).
|
|
314
|
-
*filters: Statement filters to apply.
|
|
315
|
-
connection: Optional connection override.
|
|
316
|
-
schema_type: Optional schema class for the result.
|
|
317
|
-
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
|
|
318
|
-
|
|
319
|
-
Returns:
|
|
320
|
-
List of row data as either model instances or dictionaries.
|
|
321
|
-
"""
|
|
322
|
-
connection = self._connection(connection)
|
|
323
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
324
|
-
|
|
325
|
-
with self._with_cursor(connection) as cursor:
|
|
326
|
-
cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
|
|
327
|
-
results = cursor.fetchall() # pyright: ignore
|
|
328
|
-
if not results:
|
|
329
|
-
return []
|
|
330
|
-
column_names = [column[0] for column in cursor.description or []]
|
|
331
|
-
|
|
332
|
-
return self.to_schema([dict(zip(column_names, row)) for row in results], schema_type=schema_type)
|
|
333
|
-
|
|
334
|
-
@overload
|
|
335
|
-
def select_one(
|
|
336
|
-
self,
|
|
337
|
-
sql: str,
|
|
338
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
339
|
-
*filters: "StatementFilter",
|
|
340
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
341
|
-
schema_type: None = None,
|
|
342
|
-
**kwargs: Any,
|
|
343
|
-
) -> "dict[str, Any]": ...
|
|
344
|
-
@overload
|
|
345
|
-
def select_one(
|
|
346
|
-
self,
|
|
347
|
-
sql: str,
|
|
348
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
349
|
-
*filters: "StatementFilter",
|
|
350
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
351
|
-
schema_type: "type[ModelDTOT]",
|
|
352
|
-
**kwargs: Any,
|
|
353
|
-
) -> "ModelDTOT": ...
|
|
354
|
-
def select_one(
|
|
355
|
-
self,
|
|
356
|
-
sql: str,
|
|
357
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
358
|
-
*filters: "StatementFilter",
|
|
359
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
360
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
361
|
-
**kwargs: Any,
|
|
362
|
-
) -> "Union[ModelDTOT, dict[str, Any]]":
|
|
363
|
-
"""Fetch one row from the database.
|
|
364
|
-
|
|
365
|
-
Args:
|
|
366
|
-
sql: The SQL query string.
|
|
367
|
-
parameters: The parameters for the query (dict, tuple, list, or None).
|
|
368
|
-
*filters: Statement filters to apply.
|
|
369
|
-
connection: Optional connection override.
|
|
370
|
-
schema_type: Optional schema class for the result.
|
|
371
|
-
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
|
|
372
|
-
|
|
373
|
-
Returns:
|
|
374
|
-
The first row of the query results.
|
|
375
|
-
"""
|
|
376
|
-
connection = self._connection(connection)
|
|
377
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
378
|
-
|
|
379
|
-
with self._with_cursor(connection) as cursor:
|
|
380
|
-
cursor.execute(sql, parameters)
|
|
381
|
-
result = cursor.fetchone()
|
|
382
|
-
result = self.check_not_found(result)
|
|
383
|
-
column_names = [column[0] for column in cursor.description or []]
|
|
384
|
-
return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type)
|
|
385
|
-
|
|
386
|
-
@overload
|
|
387
|
-
def select_one_or_none(
|
|
388
|
-
self,
|
|
389
|
-
sql: str,
|
|
390
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
391
|
-
*filters: "StatementFilter",
|
|
392
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
393
|
-
schema_type: None = None,
|
|
394
|
-
**kwargs: Any,
|
|
395
|
-
) -> "Optional[dict[str, Any]]": ...
|
|
396
|
-
@overload
|
|
397
|
-
def select_one_or_none(
|
|
398
|
-
self,
|
|
399
|
-
sql: str,
|
|
400
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
401
|
-
*filters: "StatementFilter",
|
|
402
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
403
|
-
schema_type: "type[ModelDTOT]",
|
|
404
|
-
**kwargs: Any,
|
|
405
|
-
) -> "Optional[ModelDTOT]": ...
|
|
406
|
-
def select_one_or_none(
|
|
407
|
-
self,
|
|
408
|
-
sql: str,
|
|
409
|
-
parameters: Optional["StatementParameterType"] = None,
|
|
410
|
-
*filters: "StatementFilter",
|
|
411
|
-
connection: Optional["AdbcConnection"] = None,
|
|
412
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
413
|
-
**kwargs: Any,
|
|
414
|
-
) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
|
|
415
|
-
"""Fetch one row from the database or return None if no rows found.
|
|
416
|
-
|
|
417
|
-
Args:
|
|
418
|
-
sql: The SQL query string.
|
|
419
|
-
parameters: The parameters for the query (dict, tuple, list, or None).
|
|
420
|
-
*filters: Statement filters to apply.
|
|
421
|
-
connection: Optional connection override.
|
|
422
|
-
schema_type: Optional schema class for the result.
|
|
423
|
-
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
|
|
424
|
-
|
|
425
|
-
Returns:
|
|
426
|
-
The first row of the query results, or None if no results found.
|
|
427
|
-
"""
|
|
428
|
-
connection = self._connection(connection)
|
|
429
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
430
|
-
|
|
431
|
-
with self._with_cursor(connection) as cursor:
|
|
432
|
-
cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
|
|
433
|
-
result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
|
434
|
-
if result is None:
|
|
435
|
-
return None
|
|
436
|
-
column_names = [column[0] for column in cursor.description or []]
|
|
437
|
-
return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type)
|
|
438
|
-
|
|
439
|
-
@overload
|
|
440
|
-
def select_value(
|
|
441
|
-
self,
|
|
442
|
-
sql: str,
|
|
443
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
444
|
-
*filters: "StatementFilter",
|
|
445
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
446
|
-
schema_type: None = None,
|
|
447
|
-
**kwargs: Any,
|
|
448
|
-
) -> "Any": ...
|
|
449
|
-
@overload
|
|
450
|
-
def select_value(
|
|
451
|
-
self,
|
|
452
|
-
sql: str,
|
|
453
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
454
|
-
*filters: "StatementFilter",
|
|
455
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
456
|
-
schema_type: "type[T]",
|
|
457
|
-
**kwargs: Any,
|
|
458
|
-
) -> "T": ...
|
|
459
|
-
def select_value(
|
|
460
|
-
self,
|
|
461
|
-
sql: str,
|
|
462
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
463
|
-
*filters: "StatementFilter",
|
|
464
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
465
|
-
schema_type: "Optional[type[T]]" = None,
|
|
466
|
-
**kwargs: Any,
|
|
467
|
-
) -> "Union[T, Any]":
|
|
468
|
-
"""Fetch a single value from the database.
|
|
469
|
-
|
|
470
|
-
Args:
|
|
471
|
-
sql: The SQL query string.
|
|
472
|
-
parameters: The parameters for the query (dict, tuple, list, or None).
|
|
473
|
-
*filters: Statement filters to apply.
|
|
474
|
-
connection: Optional connection override.
|
|
475
|
-
schema_type: Optional type to convert the result to.
|
|
476
|
-
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
|
|
477
|
-
|
|
478
|
-
Returns:
|
|
479
|
-
The first value of the first row of the query results.
|
|
480
|
-
"""
|
|
481
|
-
connection = self._connection(connection)
|
|
482
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
483
|
-
|
|
484
|
-
with self._with_cursor(connection) as cursor:
|
|
485
|
-
cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
|
|
486
|
-
result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
|
487
|
-
result = self.check_not_found(result) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
|
|
488
|
-
if schema_type is None:
|
|
489
|
-
return result[0] # pyright: ignore[reportUnknownVariableType]
|
|
490
|
-
return schema_type(result[0]) # type: ignore[call-arg]
|
|
491
|
-
|
|
492
|
-
@overload
|
|
493
|
-
def select_value_or_none(
|
|
494
|
-
self,
|
|
495
|
-
sql: str,
|
|
496
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
497
|
-
*filters: "StatementFilter",
|
|
498
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
499
|
-
schema_type: None = None,
|
|
500
|
-
**kwargs: Any,
|
|
501
|
-
) -> "Optional[Any]": ...
|
|
502
|
-
@overload
|
|
503
|
-
def select_value_or_none(
|
|
504
|
-
self,
|
|
505
|
-
sql: str,
|
|
506
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
507
|
-
*filters: "StatementFilter",
|
|
508
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
509
|
-
schema_type: "type[T]",
|
|
510
|
-
**kwargs: Any,
|
|
511
|
-
) -> "Optional[T]": ...
|
|
512
|
-
def select_value_or_none(
|
|
513
|
-
self,
|
|
514
|
-
sql: str,
|
|
515
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
516
|
-
*filters: "StatementFilter",
|
|
517
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
518
|
-
schema_type: "Optional[type[T]]" = None,
|
|
519
|
-
**kwargs: Any,
|
|
520
|
-
) -> "Optional[Union[T, Any]]":
|
|
521
|
-
"""Fetch a single value or None if not found.
|
|
206
|
+
cursor_params = parameters # type: ignore[assignment]
|
|
207
|
+
|
|
208
|
+
try:
|
|
209
|
+
cursor.execute(sql, cursor_params or [])
|
|
210
|
+
except Exception as e:
|
|
211
|
+
# Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
|
|
212
|
+
if self.dialect == "postgres":
|
|
213
|
+
with contextlib.suppress(Exception):
|
|
214
|
+
cursor.execute("ROLLBACK")
|
|
215
|
+
raise e from e
|
|
216
|
+
|
|
217
|
+
if self.returns_rows(statement.expression):
|
|
218
|
+
fetched_data = cursor.fetchall()
|
|
219
|
+
column_names = [col[0] for col in cursor.description or []]
|
|
220
|
+
result: SelectResultDict = {
|
|
221
|
+
"data": fetched_data,
|
|
222
|
+
"column_names": column_names,
|
|
223
|
+
"rows_affected": len(fetched_data),
|
|
224
|
+
}
|
|
225
|
+
return result
|
|
226
|
+
|
|
227
|
+
dml_result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
|
|
228
|
+
return dml_result
|
|
229
|
+
|
|
230
|
+
def _execute_many(
|
|
231
|
+
self, sql: str, param_list: Any, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
232
|
+
) -> DMLResultDict:
|
|
233
|
+
conn = self._connection(connection)
|
|
234
|
+
with self._get_cursor(conn) as cursor:
|
|
235
|
+
try:
|
|
236
|
+
cursor.executemany(sql, param_list or [])
|
|
237
|
+
except Exception as e:
|
|
238
|
+
if self.dialect == "postgres":
|
|
239
|
+
with contextlib.suppress(Exception):
|
|
240
|
+
cursor.execute("ROLLBACK")
|
|
241
|
+
# Always re-raise the original exception
|
|
242
|
+
raise e from e
|
|
243
|
+
|
|
244
|
+
result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
|
|
245
|
+
return result
|
|
246
|
+
|
|
247
|
+
def _execute_script(
|
|
248
|
+
self, script: str, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
249
|
+
) -> ScriptResultDict:
|
|
250
|
+
conn = self._connection(connection)
|
|
251
|
+
# ADBC drivers don't support multiple statements in a single execute
|
|
252
|
+
# Use the shared implementation to split the script
|
|
253
|
+
statements = self._split_script_statements(script)
|
|
254
|
+
|
|
255
|
+
executed_count = 0
|
|
256
|
+
with self._get_cursor(conn) as cursor:
|
|
257
|
+
for statement in statements:
|
|
258
|
+
executed_count += self._execute_single_script_statement(cursor, statement)
|
|
259
|
+
|
|
260
|
+
result: ScriptResultDict = {"statements_executed": executed_count, "status_message": "SCRIPT EXECUTED"}
|
|
261
|
+
return result
|
|
262
|
+
|
|
263
|
+
def _execute_single_script_statement(self, cursor: "Cursor", statement: str) -> int:
|
|
264
|
+
"""Execute a single statement from a script and handle errors.
|
|
522
265
|
|
|
523
266
|
Args:
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
*filters: Statement filters to apply.
|
|
527
|
-
connection: Optional connection override.
|
|
528
|
-
schema_type: Optional type to convert the result to.
|
|
529
|
-
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
|
|
267
|
+
cursor: The database cursor
|
|
268
|
+
statement: The SQL statement to execute
|
|
530
269
|
|
|
531
270
|
Returns:
|
|
532
|
-
|
|
271
|
+
1 if successful, 0 if failed
|
|
533
272
|
"""
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
Args:
|
|
605
|
-
sql: The SQL statement string.
|
|
606
|
-
parameters: The parameters for the statement (dict, tuple, list, or None).
|
|
607
|
-
*filters: Statement filters to apply.
|
|
608
|
-
connection: Optional connection override.
|
|
609
|
-
schema_type: Optional schema class for the result.
|
|
610
|
-
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
|
|
611
|
-
|
|
612
|
-
Returns:
|
|
613
|
-
The returned row data, or None if no row returned.
|
|
614
|
-
"""
|
|
615
|
-
connection = self._connection(connection)
|
|
616
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
617
|
-
|
|
618
|
-
with self._with_cursor(connection) as cursor:
|
|
619
|
-
cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
|
|
620
|
-
result = cursor.fetchall() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
|
621
|
-
if not result:
|
|
622
|
-
return None
|
|
623
|
-
column_names = [c[0] for c in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
|
624
|
-
return self.to_schema(dict(zip(column_names, result[0])), schema_type=schema_type)
|
|
625
|
-
|
|
626
|
-
def execute_script(
|
|
627
|
-
self,
|
|
628
|
-
sql: str,
|
|
629
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
630
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
631
|
-
**kwargs: Any,
|
|
632
|
-
) -> str:
|
|
633
|
-
"""Execute a SQL script.
|
|
273
|
+
try:
|
|
274
|
+
cursor.execute(statement)
|
|
275
|
+
except Exception as e:
|
|
276
|
+
# Rollback transaction on error for PostgreSQL
|
|
277
|
+
if self.dialect == "postgres":
|
|
278
|
+
with contextlib.suppress(Exception):
|
|
279
|
+
cursor.execute("ROLLBACK")
|
|
280
|
+
raise e from e
|
|
281
|
+
else:
|
|
282
|
+
return 1
|
|
283
|
+
|
|
284
|
+
def _wrap_select_result(
|
|
285
|
+
self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
|
|
286
|
+
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
287
|
+
# result must be a dict with keys: data, column_names, rows_affected
|
|
288
|
+
|
|
289
|
+
rows_as_dicts = [dict(zip(result["column_names"], row)) for row in result["data"]]
|
|
290
|
+
|
|
291
|
+
if schema_type:
|
|
292
|
+
return SQLResult[ModelDTOT](
|
|
293
|
+
statement=statement,
|
|
294
|
+
data=list(self.to_schema(data=rows_as_dicts, schema_type=schema_type)),
|
|
295
|
+
column_names=result["column_names"],
|
|
296
|
+
rows_affected=result["rows_affected"],
|
|
297
|
+
operation_type="SELECT",
|
|
298
|
+
)
|
|
299
|
+
return SQLResult[RowT](
|
|
300
|
+
statement=statement,
|
|
301
|
+
data=rows_as_dicts,
|
|
302
|
+
column_names=result["column_names"],
|
|
303
|
+
rows_affected=result["rows_affected"],
|
|
304
|
+
operation_type="SELECT",
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
def _wrap_execute_result(
|
|
308
|
+
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
309
|
+
) -> SQLResult[RowT]:
|
|
310
|
+
operation_type = (
|
|
311
|
+
str(statement.expression.key).upper()
|
|
312
|
+
if statement.expression and hasattr(statement.expression, "key")
|
|
313
|
+
else "UNKNOWN"
|
|
314
|
+
)
|
|
315
|
+
|
|
316
|
+
# Handle TypedDict results
|
|
317
|
+
if is_dict_with_field(result, "statements_executed"):
|
|
318
|
+
return SQLResult[RowT](
|
|
319
|
+
statement=statement,
|
|
320
|
+
data=[],
|
|
321
|
+
rows_affected=0,
|
|
322
|
+
total_statements=result["statements_executed"],
|
|
323
|
+
operation_type="SCRIPT", # Scripts always have operation_type SCRIPT
|
|
324
|
+
metadata={"status_message": result["status_message"]},
|
|
325
|
+
)
|
|
326
|
+
if is_dict_with_field(result, "rows_affected"):
|
|
327
|
+
return SQLResult[RowT](
|
|
328
|
+
statement=statement,
|
|
329
|
+
data=[],
|
|
330
|
+
rows_affected=result["rows_affected"],
|
|
331
|
+
operation_type=operation_type,
|
|
332
|
+
metadata={"status_message": result["status_message"]},
|
|
333
|
+
)
|
|
334
|
+
msg = f"Unexpected result type: {type(result)}"
|
|
335
|
+
raise ValueError(msg)
|
|
336
|
+
|
|
337
|
+
def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
|
|
338
|
+
"""ADBC native Arrow table fetching.
|
|
339
|
+
|
|
340
|
+
ADBC has excellent native Arrow support through cursor.fetch_arrow_table()
|
|
341
|
+
This provides zero-copy data transfer for optimal performance.
|
|
634
342
|
|
|
635
343
|
Args:
|
|
636
|
-
sql:
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
connection: Optional connection override.
|
|
640
|
-
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
|
|
344
|
+
sql: Processed SQL object
|
|
345
|
+
connection: Optional connection override
|
|
346
|
+
**kwargs: Additional options (e.g., batch_size for streaming)
|
|
641
347
|
|
|
642
348
|
Returns:
|
|
643
|
-
|
|
349
|
+
ArrowResult with native Arrow table
|
|
644
350
|
"""
|
|
645
|
-
|
|
646
|
-
|
|
351
|
+
self._ensure_pyarrow_installed()
|
|
352
|
+
conn = self._connection(connection)
|
|
647
353
|
|
|
648
|
-
with self.
|
|
649
|
-
|
|
650
|
-
|
|
354
|
+
with wrap_exceptions(), self._get_cursor(conn) as cursor:
|
|
355
|
+
# Execute the query
|
|
356
|
+
params = sql.get_parameters(style=self.default_parameter_style)
|
|
357
|
+
# ADBC expects parameters as a list for most drivers
|
|
358
|
+
cursor_params = [params] if params is not None and not isinstance(params, (list, tuple)) else params
|
|
359
|
+
cursor.execute(sql.to_sql(placeholder_style=self.default_parameter_style), cursor_params or [])
|
|
360
|
+
arrow_table = cursor.fetch_arrow_table()
|
|
361
|
+
return ArrowResult(statement=sql, data=arrow_table)
|
|
651
362
|
|
|
652
|
-
|
|
363
|
+
def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
|
|
364
|
+
"""ADBC-optimized Arrow table ingestion using native bulk insert.
|
|
653
365
|
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
sql: str,
|
|
657
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
658
|
-
*filters: "StatementFilter",
|
|
659
|
-
connection: "Optional[AdbcConnection]" = None,
|
|
660
|
-
**kwargs: Any,
|
|
661
|
-
) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType]
|
|
662
|
-
"""Execute a SQL query and return results as an Apache Arrow Table.
|
|
366
|
+
ADBC drivers often support native Arrow table ingestion for high-performance
|
|
367
|
+
bulk loading operations.
|
|
663
368
|
|
|
664
369
|
Args:
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
|
|
370
|
+
table: Arrow table to ingest
|
|
371
|
+
table_name: Target database table name
|
|
372
|
+
mode: Ingestion mode ('append', 'replace', 'create')
|
|
373
|
+
**options: Additional ADBC-specific options
|
|
670
374
|
|
|
671
375
|
Returns:
|
|
672
|
-
|
|
376
|
+
Number of rows ingested
|
|
673
377
|
"""
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
with self.
|
|
678
|
-
|
|
679
|
-
|
|
378
|
+
self._ensure_pyarrow_installed()
|
|
379
|
+
|
|
380
|
+
conn = self._connection(None)
|
|
381
|
+
with self._get_cursor(conn) as cursor:
|
|
382
|
+
# Handle different modes
|
|
383
|
+
if mode == "replace":
|
|
384
|
+
cursor.execute(SQL(f"TRUNCATE TABLE {table_name}").to_sql(placeholder_style=ParameterStyle.STATIC))
|
|
385
|
+
elif mode == "create":
|
|
386
|
+
msg = "'create' mode is not supported for ADBC ingestion"
|
|
387
|
+
raise NotImplementedError(msg)
|
|
388
|
+
return cursor.adbc_ingest(table_name, table, mode=mode, **options) # type: ignore[arg-type]
|