sqlspec 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in their public registries; it is provided for informational purposes only.
Note: this version of sqlspec has been flagged as a potentially problematic release.
- sqlspec/_serialization.py +223 -21
- sqlspec/_sql.py +12 -50
- sqlspec/_typing.py +9 -0
- sqlspec/adapters/adbc/config.py +8 -1
- sqlspec/adapters/adbc/data_dictionary.py +290 -0
- sqlspec/adapters/adbc/driver.py +127 -18
- sqlspec/adapters/adbc/type_converter.py +159 -0
- sqlspec/adapters/aiosqlite/config.py +3 -0
- sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
- sqlspec/adapters/aiosqlite/driver.py +17 -3
- sqlspec/adapters/asyncmy/_types.py +1 -1
- sqlspec/adapters/asyncmy/config.py +11 -8
- sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
- sqlspec/adapters/asyncmy/driver.py +31 -7
- sqlspec/adapters/asyncpg/config.py +3 -0
- sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
- sqlspec/adapters/asyncpg/driver.py +19 -4
- sqlspec/adapters/bigquery/config.py +3 -0
- sqlspec/adapters/bigquery/data_dictionary.py +109 -0
- sqlspec/adapters/bigquery/driver.py +21 -3
- sqlspec/adapters/bigquery/type_converter.py +93 -0
- sqlspec/adapters/duckdb/_types.py +1 -1
- sqlspec/adapters/duckdb/config.py +2 -0
- sqlspec/adapters/duckdb/data_dictionary.py +124 -0
- sqlspec/adapters/duckdb/driver.py +32 -5
- sqlspec/adapters/duckdb/pool.py +1 -1
- sqlspec/adapters/duckdb/type_converter.py +103 -0
- sqlspec/adapters/oracledb/config.py +6 -0
- sqlspec/adapters/oracledb/data_dictionary.py +442 -0
- sqlspec/adapters/oracledb/driver.py +63 -9
- sqlspec/adapters/oracledb/migrations.py +51 -67
- sqlspec/adapters/oracledb/type_converter.py +132 -0
- sqlspec/adapters/psqlpy/config.py +3 -0
- sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
- sqlspec/adapters/psqlpy/driver.py +23 -179
- sqlspec/adapters/psqlpy/type_converter.py +73 -0
- sqlspec/adapters/psycopg/config.py +6 -0
- sqlspec/adapters/psycopg/data_dictionary.py +257 -0
- sqlspec/adapters/psycopg/driver.py +40 -5
- sqlspec/adapters/sqlite/config.py +3 -0
- sqlspec/adapters/sqlite/data_dictionary.py +117 -0
- sqlspec/adapters/sqlite/driver.py +18 -3
- sqlspec/adapters/sqlite/pool.py +13 -4
- sqlspec/builder/_base.py +82 -42
- sqlspec/builder/_column.py +57 -24
- sqlspec/builder/_ddl.py +84 -34
- sqlspec/builder/_insert.py +30 -52
- sqlspec/builder/_parsing_utils.py +104 -8
- sqlspec/builder/_select.py +147 -2
- sqlspec/builder/mixins/_cte_and_set_ops.py +1 -2
- sqlspec/builder/mixins/_join_operations.py +14 -30
- sqlspec/builder/mixins/_merge_operations.py +167 -61
- sqlspec/builder/mixins/_order_limit_operations.py +3 -10
- sqlspec/builder/mixins/_select_operations.py +3 -9
- sqlspec/builder/mixins/_update_operations.py +3 -22
- sqlspec/builder/mixins/_where_clause.py +4 -10
- sqlspec/cli.py +246 -140
- sqlspec/config.py +33 -19
- sqlspec/core/cache.py +2 -2
- sqlspec/core/compiler.py +56 -1
- sqlspec/core/parameters.py +7 -3
- sqlspec/core/statement.py +5 -0
- sqlspec/core/type_conversion.py +234 -0
- sqlspec/driver/__init__.py +6 -3
- sqlspec/driver/_async.py +106 -3
- sqlspec/driver/_common.py +156 -4
- sqlspec/driver/_sync.py +106 -3
- sqlspec/exceptions.py +5 -0
- sqlspec/migrations/__init__.py +4 -3
- sqlspec/migrations/base.py +153 -14
- sqlspec/migrations/commands.py +34 -96
- sqlspec/migrations/context.py +145 -0
- sqlspec/migrations/loaders.py +25 -8
- sqlspec/migrations/runner.py +352 -82
- sqlspec/typing.py +2 -0
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/serializers.py +50 -2
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
- sqlspec-0.26.0.dist-info/RECORD +157 -0
- sqlspec-0.25.0.dist-info/RECORD +0 -139
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/adbc/data_dictionary.py
ADDED

```diff
@@ -0,0 +1,290 @@
+"""ADBC multi-dialect data dictionary for metadata queries."""
+
+import re
+from typing import TYPE_CHECKING, Optional, cast
+
+from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo
+from sqlspec.utils.logging import get_logger
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from sqlspec.adapters.adbc.driver import AdbcDriver
+
+logger = get_logger("adapters.adbc.data_dictionary")
+
+POSTGRES_VERSION_PATTERN = re.compile(r"PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?")
+SQLITE_VERSION_PATTERN = re.compile(r"(\d+)\.(\d+)\.(\d+)")
+DUCKDB_VERSION_PATTERN = re.compile(r"v?(\d+)\.(\d+)\.(\d+)")
+MYSQL_VERSION_PATTERN = re.compile(r"(\d+)\.(\d+)\.(\d+)")
+
+__all__ = ("AdbcDataDictionary",)
+
+
+class AdbcDataDictionary(SyncDataDictionaryBase):
+    """ADBC multi-dialect data dictionary.
+
+    Delegates to appropriate dialect-specific logic based on the driver's dialect.
+    """
+
+    def _get_dialect(self, driver: SyncDriverAdapterBase) -> str:
+        """Get dialect from ADBC driver.
+
+        Args:
+            driver: ADBC driver instance
+
+        Returns:
+            Dialect name
+        """
+        return str(cast("AdbcDriver", driver).dialect)
+
+    def get_version(self, driver: SyncDriverAdapterBase) -> "Optional[VersionInfo]":
+        """Get database version information based on detected dialect.
+
+        Args:
+            driver: ADBC driver instance
+
+        Returns:
+            Database version information or None if detection fails
+        """
+        dialect = self._get_dialect(driver)
+        adbc_driver = cast("AdbcDriver", driver)
+
+        try:
+            if dialect == "postgres":
+                version_str = adbc_driver.select_value("SELECT version()")
+                if version_str:
+                    match = POSTGRES_VERSION_PATTERN.search(str(version_str))
+                    if match:
+                        major = int(match.group(1))
+                        minor = int(match.group(2))
+                        patch = int(match.group(3)) if match.group(3) else 0
+                        return VersionInfo(major, minor, patch)
+
+            elif dialect == "sqlite":
+                version_str = adbc_driver.select_value("SELECT sqlite_version()")
+                if version_str:
+                    match = SQLITE_VERSION_PATTERN.match(str(version_str))
+                    if match:
+                        major, minor, patch = map(int, match.groups())
+                        return VersionInfo(major, minor, patch)
+
+            elif dialect == "duckdb":
+                version_str = adbc_driver.select_value("SELECT version()")
+                if version_str:
+                    match = DUCKDB_VERSION_PATTERN.search(str(version_str))
+                    if match:
+                        major, minor, patch = map(int, match.groups())
+                        return VersionInfo(major, minor, patch)
+
+            elif dialect == "mysql":
+                version_str = adbc_driver.select_value("SELECT VERSION()")
+                if version_str:
+                    match = MYSQL_VERSION_PATTERN.search(str(version_str))
+                    if match:
+                        major, minor, patch = map(int, match.groups())
+                        return VersionInfo(major, minor, patch)
+
+            elif dialect == "bigquery":
+                return VersionInfo(1, 0, 0)
+
+        except Exception:
+            logger.warning("Failed to get %s version", dialect)
+
+        return None
+
+    def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
+        """Check if database supports a specific feature based on detected dialect.
+
+        Args:
+            driver: ADBC driver instance
+            feature: Feature name to check
+
+        Returns:
+            True if feature is supported, False otherwise
+        """
+        dialect = self._get_dialect(driver)
+        version_info = self.get_version(driver)
+
+        if dialect == "postgres":
+            feature_checks: dict[str, Callable[..., bool]] = {
+                "supports_json": lambda v: v and v >= VersionInfo(9, 2, 0),
+                "supports_jsonb": lambda v: v and v >= VersionInfo(9, 4, 0),
+                "supports_uuid": lambda _: True,
+                "supports_arrays": lambda _: True,
+                "supports_returning": lambda v: v and v >= VersionInfo(8, 2, 0),
+                "supports_upsert": lambda v: v and v >= VersionInfo(9, 5, 0),
+                "supports_window_functions": lambda v: v and v >= VersionInfo(8, 4, 0),
+                "supports_cte": lambda v: v and v >= VersionInfo(8, 4, 0),
+                "supports_transactions": lambda _: True,
+                "supports_prepared_statements": lambda _: True,
+                "supports_schemas": lambda _: True,
+            }
+        elif dialect == "sqlite":
+            feature_checks = {
+                "supports_json": lambda v: v and v >= VersionInfo(3, 38, 0),
+                "supports_returning": lambda v: v and v >= VersionInfo(3, 35, 0),
+                "supports_upsert": lambda v: v and v >= VersionInfo(3, 24, 0),
+                "supports_window_functions": lambda v: v and v >= VersionInfo(3, 25, 0),
+                "supports_cte": lambda v: v and v >= VersionInfo(3, 8, 3),
+                "supports_transactions": lambda _: True,
+                "supports_prepared_statements": lambda _: True,
+                "supports_schemas": lambda _: False,
+                "supports_arrays": lambda _: False,
+                "supports_uuid": lambda _: False,
+            }
+        elif dialect == "duckdb":
+            feature_checks = {
+                "supports_json": lambda _: True,
+                "supports_arrays": lambda _: True,
+                "supports_uuid": lambda _: True,
+                "supports_returning": lambda v: v and v >= VersionInfo(0, 8, 0),
+                "supports_upsert": lambda v: v and v >= VersionInfo(0, 8, 0),
+                "supports_window_functions": lambda _: True,
+                "supports_cte": lambda _: True,
+                "supports_transactions": lambda _: True,
+                "supports_prepared_statements": lambda _: True,
+                "supports_schemas": lambda _: True,
+            }
+        elif dialect == "mysql":
+            feature_checks = {
+                "supports_json": lambda v: v and v >= VersionInfo(5, 7, 8),
+                "supports_cte": lambda v: v and v >= VersionInfo(8, 0, 1),
+                "supports_returning": lambda _: False,
+                "supports_upsert": lambda _: True,
+                "supports_window_functions": lambda v: v and v >= VersionInfo(8, 0, 2),
+                "supports_transactions": lambda _: True,
+                "supports_prepared_statements": lambda _: True,
+                "supports_schemas": lambda _: True,
+                "supports_uuid": lambda _: False,
+                "supports_arrays": lambda _: False,
+            }
+        elif dialect == "bigquery":
+            feature_checks = {
+                "supports_json": lambda _: True,
+                "supports_arrays": lambda _: True,
+                "supports_structs": lambda _: True,
+                "supports_returning": lambda _: False,
+                "supports_upsert": lambda _: True,
+                "supports_window_functions": lambda _: True,
+                "supports_cte": lambda _: True,
+                "supports_transactions": lambda _: False,
+                "supports_prepared_statements": lambda _: True,
+                "supports_schemas": lambda _: True,
+                "supports_uuid": lambda _: False,
+            }
+        else:
+            feature_checks = {
+                "supports_transactions": lambda _: True,
+                "supports_prepared_statements": lambda _: True,
+                "supports_window_functions": lambda _: True,
+                "supports_cte": lambda _: True,
+            }
+
+        if feature in feature_checks:
+            return bool(feature_checks[feature](version_info))
+
+        return False
+
+    def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> str:
+        """Get optimal database type for a category based on detected dialect.
+
+        Args:
+            driver: ADBC driver instance
+            type_category: Type category
+
+        Returns:
+            Database-specific type name
+        """
+        dialect = self._get_dialect(driver)
+        version_info = self.get_version(driver)
+
+        if dialect == "postgres":
+            if type_category == "json":
+                if version_info and version_info >= VersionInfo(9, 4, 0):
+                    return "JSONB"
+                if version_info and version_info >= VersionInfo(9, 2, 0):
+                    return "JSON"
+                return "TEXT"
+            type_map = {
+                "uuid": "UUID",
+                "boolean": "BOOLEAN",
+                "timestamp": "TIMESTAMP WITH TIME ZONE",
+                "text": "TEXT",
+                "blob": "BYTEA",
+                "array": "ARRAY",
+            }
+
+        elif dialect == "sqlite":
+            if type_category == "json":
+                if version_info and version_info >= VersionInfo(3, 38, 0):
+                    return "JSON"
+                return "TEXT"
+            type_map = {"uuid": "TEXT", "boolean": "INTEGER", "timestamp": "TIMESTAMP", "text": "TEXT", "blob": "BLOB"}
+
+        elif dialect == "duckdb":
+            type_map = {
+                "json": "JSON",
+                "uuid": "UUID",
+                "boolean": "BOOLEAN",
+                "timestamp": "TIMESTAMP",
+                "text": "TEXT",
+                "blob": "BLOB",
+                "array": "LIST",
+            }
+
+        elif dialect == "mysql":
+            if type_category == "json":
+                if version_info and version_info >= VersionInfo(5, 7, 8):
+                    return "JSON"
+                return "TEXT"
+            type_map = {
+                "uuid": "VARCHAR(36)",
+                "boolean": "TINYINT(1)",
+                "timestamp": "TIMESTAMP",
+                "text": "TEXT",
+                "blob": "BLOB",
+            }
+
+        elif dialect == "bigquery":
+            type_map = {
+                "json": "JSON",
+                "uuid": "STRING",
+                "boolean": "BOOL",
+                "timestamp": "TIMESTAMP",
+                "text": "STRING",
+                "blob": "BYTES",
+                "array": "ARRAY",
+            }
+        else:
+            type_map = {
+                "json": "TEXT",
+                "uuid": "VARCHAR(36)",
+                "boolean": "INTEGER",
+                "timestamp": "TIMESTAMP",
+                "text": "TEXT",
+                "blob": "BLOB",
+            }
+
+        return type_map.get(type_category, "TEXT")
+
+    def list_available_features(self) -> "list[str]":
+        """List available feature flags across all supported dialects.
+
+        Returns:
+            List of supported feature names
+        """
+        return [
+            "supports_json",
+            "supports_jsonb",
+            "supports_uuid",
+            "supports_arrays",
+            "supports_structs",
+            "supports_returning",
+            "supports_upsert",
+            "supports_window_functions",
+            "supports_cte",
+            "supports_transactions",
+            "supports_prepared_statements",
+            "supports_schemas",
+        ]
```
sqlspec/adapters/adbc/driver.py
CHANGED
```diff
@@ -12,13 +12,15 @@ from typing import TYPE_CHECKING, Any, Optional, cast
 from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
 from sqlglot import exp
 
+from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
+from sqlspec.adapters.adbc.type_converter import ADBCTypeConverter
 from sqlspec.core.cache import get_cache_config
 from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
 from sqlspec.core.statement import SQL, StatementConfig
 from sqlspec.driver import SyncDriverAdapterBase
 from sqlspec.exceptions import MissingDependencyError, SQLParsingError, SQLSpecError
+from sqlspec.typing import Empty
 from sqlspec.utils.logging import get_logger
-from sqlspec.utils.serializers import to_json
 
 if TYPE_CHECKING:
     from contextlib import AbstractContextManager
@@ -28,6 +30,7 @@ if TYPE_CHECKING:
     from sqlspec.adapters.adbc._types import AdbcConnection
     from sqlspec.core.result import SQLResult
     from sqlspec.driver import ExecutionResult
+    from sqlspec.driver._sync import SyncDataDictionaryBase
 
 __all__ = ("AdbcCursor", "AdbcDriver", "AdbcExceptionHandler", "get_adbc_statement_config")
 
@@ -293,7 +296,7 @@ def _convert_array_for_postgres_adbc(value: Any) -> Any:
 
 
 def get_type_coercion_map(dialect: str) -> "dict[type, Any]":
-    """Get type coercion map for Arrow type handling.
+    """Get type coercion map for Arrow type handling with dialect-aware type conversion.
 
     Args:
         dialect: Database dialect name
@@ -301,7 +304,9 @@ def get_type_coercion_map(dialect: str) -> "dict[type, Any]":
     Returns:
         Mapping of Python types to conversion functions
     """
-    type_map = {
+    tc = ADBCTypeConverter(dialect)
+
+    return {
         datetime.datetime: lambda x: x,
         datetime.date: lambda x: x,
         datetime.time: lambda x: x,
@@ -309,18 +314,13 @@ def get_type_coercion_map(dialect: str) -> "dict[type, Any]":
         bool: lambda x: x,
         int: lambda x: x,
         float: lambda x: x,
-        str: lambda x: x,
+        str: tc.convert_if_detected,
         bytes: lambda x: x,
         tuple: _convert_array_for_postgres_adbc,
         list: _convert_array_for_postgres_adbc,
         dict: lambda x: x,
     }
 
-    if dialect in {"postgres", "postgresql"}:
-        type_map[dict] = lambda x: to_json(x) if x is not None else None
-
-    return type_map
-
 
 class AdbcCursor:
     """Context manager for cursor management."""
@@ -335,8 +335,7 @@ class AdbcCursor:
         self.cursor = self.connection.cursor()
         return self.cursor
 
-    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
-        _ = (exc_type, exc_val, exc_tb)
+    def __exit__(self, *_: Any) -> None:
         if self.cursor is not None:
             with contextlib.suppress(Exception):
                 self.cursor.close()  # type: ignore[no-untyped-call]
@@ -394,7 +393,7 @@ class AdbcDriver(SyncDriverAdapterBase):
     database dialects, parameter style conversion, and transaction management.
     """
 
-    __slots__ = ("_detected_dialect", "dialect")
+    __slots__ = ("_data_dictionary", "_detected_dialect", "dialect")
 
     def __init__(
         self,
@@ -413,6 +412,7 @@ class AdbcDriver(SyncDriverAdapterBase):
 
         super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
         self.dialect = statement_config.dialect
+        self._data_dictionary: Optional[SyncDataDictionaryBase] = None
 
     @staticmethod
     def _ensure_pyarrow_installed() -> None:
@@ -475,6 +475,87 @@ class AdbcDriver(SyncDriverAdapterBase):
             return None
         return parameters
 
+    def prepare_driver_parameters(
+        self,
+        parameters: Any,
+        statement_config: "StatementConfig",
+        is_many: bool = False,
+        prepared_statement: Optional[Any] = None,
+    ) -> Any:
+        """Prepare parameters with cast-aware type coercion for ADBC.
+
+        For PostgreSQL, applies cast-aware parameter processing using metadata from the compiled statement.
+        This allows proper handling of JSONB casts and other type conversions.
+
+        Args:
+            parameters: Parameters in any format
+            statement_config: Statement configuration
+            is_many: Whether this is for execute_many operation
+            prepared_statement: Prepared statement containing the original SQL statement
+
+        Returns:
+            Parameters with cast-aware type coercion applied
+        """
+        if prepared_statement and self.dialect in {"postgres", "postgresql"} and not is_many:
+            parameter_casts = self._get_parameter_casts(prepared_statement)
+            postgres_compatible = self._handle_postgres_empty_parameters(parameters)
+            return self._prepare_parameters_with_casts(postgres_compatible, parameter_casts, statement_config)
+
+        return super().prepare_driver_parameters(parameters, statement_config, is_many, prepared_statement)
+
+    def _get_parameter_casts(self, statement: SQL) -> "dict[int, str]":
+        """Get parameter cast metadata from compiled statement.
+
+        Args:
+            statement: SQL statement with compiled metadata
+
+        Returns:
+            Dict mapping parameter positions to cast types
+        """
+
+        processed_state = statement.get_processed_state()
+        if processed_state is not Empty:
+            return processed_state.parameter_casts or {}
+        return {}
+
+    def _prepare_parameters_with_casts(
+        self, parameters: Any, parameter_casts: "dict[int, str]", statement_config: "StatementConfig"
+    ) -> Any:
+        """Prepare parameters with cast-aware type coercion.
+
+        Uses type coercion map for non-dict types and dialect-aware dict handling.
+
+        Args:
+            parameters: Parameter values (list, tuple, or scalar)
+            parameter_casts: Mapping of parameter positions to cast types
+            statement_config: Statement configuration for type coercion
+
+        Returns:
+            Parameters with cast-aware type coercion applied
+        """
+        from sqlspec._serialization import encode_json
+
+        if isinstance(parameters, (list, tuple)):
+            result: list[Any] = []
+            for idx, param in enumerate(parameters, start=1):  # pyright: ignore
+                cast_type = parameter_casts.get(idx, "").upper()
+                if cast_type in {"JSON", "JSONB", "TYPE.JSON", "TYPE.JSONB"}:
+                    if isinstance(param, dict):
+                        result.append(encode_json(param))
+                    else:
+                        result.append(param)
+                elif isinstance(param, dict):
+                    result.append(ADBCTypeConverter(self.dialect).convert_dict(param))  # type: ignore[arg-type]
+                else:
+                    if statement_config.parameter_config.type_coercion_map:
+                        for type_check, converter in statement_config.parameter_config.type_coercion_map.items():
+                            if type_check is not dict and isinstance(param, type_check):
+                                param = converter(param)
+                                break
+                    result.append(param)
+            return tuple(result) if isinstance(parameters, tuple) else result
+        return parameters
+
     def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor":
         """Create context manager for cursor.
 
@@ -519,17 +600,26 @@ class AdbcDriver(SyncDriverAdapterBase):
         """
         sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
 
+        parameter_casts = self._get_parameter_casts(statement)
+
         try:
             if not prepared_parameters:
                 cursor._rowcount = 0  # pyright: ignore[reportPrivateUsage]
                 row_count = 0
-            elif isinstance(prepared_parameters, list) and prepared_parameters:
+            elif isinstance(prepared_parameters, (list, tuple)) and prepared_parameters:
                 processed_params = []
                 for param_set in prepared_parameters:
                     postgres_compatible = self._handle_postgres_empty_parameters(param_set)
-                    formatted_params = self.prepare_driver_parameters(
-                        postgres_compatible, self.statement_config, is_many=False
-                    )
+
+                    if self.dialect in {"postgres", "postgresql"}:
+                        # For postgres, always use cast-aware parameter preparation
+                        formatted_params = self._prepare_parameters_with_casts(
+                            postgres_compatible, parameter_casts, self.statement_config
+                        )
+                    else:
+                        formatted_params = self.prepare_driver_parameters(
+                            postgres_compatible, self.statement_config, is_many=False
+                        )
                     processed_params.append(formatted_params)
 
                 cursor.executemany(sql, processed_params)
@@ -540,7 +630,6 @@ class AdbcDriver(SyncDriverAdapterBase):
 
         except Exception:
             self._handle_postgres_rollback(cursor)
-            logger.exception("Executemany failed")
             raise
 
         return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
@@ -557,9 +646,18 @@ class AdbcDriver(SyncDriverAdapterBase):
         """
         sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
 
+        parameter_casts = self._get_parameter_casts(statement)
+
        try:
             postgres_compatible_params = self._handle_postgres_empty_parameters(prepared_parameters)
-            cursor.execute(sql, parameters=postgres_compatible_params)
+
+            if self.dialect in {"postgres", "postgresql"}:
+                formatted_params = self._prepare_parameters_with_casts(
+                    postgres_compatible_params, parameter_casts, self.statement_config
+                )
+                cursor.execute(sql, parameters=formatted_params)
+            else:
+                cursor.execute(sql, parameters=postgres_compatible_params)
 
         except Exception:
             self._handle_postgres_rollback(cursor)
@@ -655,3 +753,14 @@ class AdbcDriver(SyncDriverAdapterBase):
         except Exception as e:
             msg = f"Failed to commit transaction: {e}"
             raise SQLSpecError(msg) from e
+
+    @property
+    def data_dictionary(self) -> "SyncDataDictionaryBase":
+        """Get the data dictionary for this driver.
+
+        Returns:
+            Data dictionary instance for metadata queries
+        """
+        if self._data_dictionary is None:
+            self._data_dictionary = AdbcDataDictionary()
+        return self._data_dictionary
```