sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -621
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -431
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +218 -436
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +417 -487
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +600 -553
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +392 -406
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +548 -921
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -533
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +725 -0
- sqlspec/adapters/psycopg/driver.py +734 -694
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +242 -405
- sqlspec/base.py +220 -784
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
- sqlspec-0.12.1.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -331
- sqlspec/mixins.py +0 -305
- sqlspec/statement.py +0 -378
- sqlspec-0.11.1.dist-info/RECORD +0 -69
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/config.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union
|
|
4
|
+
|
|
5
|
+
from sqlspec.typing import ConnectionT, PoolT # pyright: ignore
|
|
6
|
+
from sqlspec.utils.logging import get_logger
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from collections.abc import Awaitable
|
|
10
|
+
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
11
|
+
|
|
12
|
+
from sqlglot.dialects.dialect import DialectType
|
|
13
|
+
|
|
14
|
+
from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
|
|
15
|
+
from sqlspec.statement.result import StatementResult
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Convenience alias for statement results, whether rows are dicts or any other row type.
StatementResultType = Union["StatementResult[dict[str, Any]]", "StatementResult[Any]"]


__all__ = (
    "AsyncConfigT",
    "AsyncDatabaseConfig",
    "ConfigT",
    "DatabaseConfigProtocol",
    "DriverT",
    "GenericPoolConfig",
    "NoPoolAsyncConfig",
    "NoPoolSyncConfig",
    "StatementResultType",
    "SyncConfigT",
    "SyncDatabaseConfig",
)

# Type variables bound to the config hierarchy defined later in this module. The bounds
# are strings because those classes do not exist yet when these lines execute.
# Pooled configs take (ConnectionT, PoolT, DriverT); pool-less configs take (ConnectionT, DriverT).
AsyncConfigT = TypeVar("AsyncConfigT", bound="Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]]")
SyncConfigT = TypeVar("SyncConfigT", bound="Union[SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]")
ConfigT = TypeVar(
    "ConfigT",
    bound="Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any], SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]",
)
# Both driver protocols are generic over (ConnectionT, RowT), so both parameters are supplied.
DriverT = TypeVar(
    "DriverT", bound="Union[SyncDriverAdapterProtocol[Any, Any], AsyncDriverAdapterProtocol[Any, Any]]"
)

logger = get_logger("config")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
    """Protocol defining the interface for database configurations.

    Subclasses describe how to create connections, sessions (driver instances)
    and, where supported, connection pools for one database adapter.
    """

    # Note: __slots__ cannot be used with dataclass fields in Python < 3.10.
    # The empty tuple only stops this class from adding a __dict__ of its own.
    # NOTE(review): ABC and Generic also declare empty __slots__, so a subclass that
    # declares __slots__ must also list the inherited dataclass fields that __init__
    # assigns (``pool_instance``, ``_dialect``), otherwise those assignments fail.
    __slots__ = ()

    # Capability flags. These are ClassVars, so dataclass does not treat them as fields.
    is_async: "ClassVar[bool]" = field(init=False, default=False)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=False)
    supports_native_arrow_import: "ClassVar[bool]" = field(init=False, default=False)
    supports_native_arrow_export: "ClassVar[bool]" = field(init=False, default=False)
    supports_native_parquet_import: "ClassVar[bool]" = field(init=False, default=False)
    supports_native_parquet_export: "ClassVar[bool]" = field(init=False, default=False)
    # Connection and driver classes; excluded from __init__, so subclasses must provide them.
    connection_type: "type[ConnectionT]" = field(init=False, repr=False, hash=False, compare=False)
    driver_type: "type[DriverT]" = field(init=False, repr=False, hash=False, compare=False)
    # Pool object, if one has been supplied by the caller or created by the config.
    pool_instance: "Optional[PoolT]" = field(default=None)
    default_row_type: "type[Any]" = field(init=False)
    # Cache backing the ``dialect`` property; resolved lazily on first access.
    _dialect: "DialectType" = field(default=None, init=False, repr=False, hash=False, compare=False)

    supported_parameter_styles: "ClassVar[tuple[str, ...]]" = ()
    """Parameter styles supported by this database adapter (e.g., ('qmark', 'named_colon'))."""

    preferred_parameter_style: "ClassVar[str]" = "none"
    """The preferred/native parameter style for this database."""

    def __hash__(self) -> int:
        # Identity-based hash. Defined in the class body so @dataclass keeps it here.
        # NOTE(review): @dataclass subclasses that do not redefine __hash__ get
        # ``__hash__ = None`` (eq=True), so they do not inherit this implementation.
        return id(self)

    @property
    def dialect(self) -> "DialectType":
        """Get the SQL dialect type lazily.

        This property allows dialect to be set either statically as a class attribute
        or dynamically via the _get_dialect() method. If a specific adapter needs
        dynamic dialect detection (e.g., ADBC which supports multiple databases),
        it can override _get_dialect() to provide custom logic.

        Returns:
            The SQL dialect type for this database.
        """
        if self._dialect is None:
            self._dialect = self._get_dialect()  # type: ignore[misc]
        return self._dialect

    def _get_dialect(self) -> "DialectType":
        """Get the dialect for this database configuration.

        This method should be overridden by configs that need dynamic dialect detection.
        By default, it looks for the dialect on the driver class.

        Returns:
            The SQL dialect type.
        """
        # Get dialect from driver_class (all drivers must have a dialect attribute)
        return self.driver_type.dialect

    @abstractmethod
    def create_connection(self) -> "Union[ConnectionT, Awaitable[ConnectionT]]":
        """Create and return a new database connection."""
        raise NotImplementedError

    @abstractmethod
    def provide_connection(
        self, *args: Any, **kwargs: Any
    ) -> "Union[AbstractContextManager[ConnectionT], AbstractAsyncContextManager[ConnectionT]]":
        """Provide a database connection context manager."""
        raise NotImplementedError

    @abstractmethod
    def provide_session(
        self, *args: Any, **kwargs: Any
    ) -> "Union[AbstractContextManager[DriverT], AbstractAsyncContextManager[DriverT]]":
        """Provide a database session context manager."""
        raise NotImplementedError

    @property
    @abstractmethod
    def connection_config_dict(self) -> "dict[str, Any]":
        """Return the connection configuration as a dict."""
        raise NotImplementedError

    @abstractmethod
    def create_pool(self) -> "Union[PoolT, Awaitable[PoolT]]":
        """Create and return connection pool."""
        raise NotImplementedError

    @abstractmethod
    def close_pool(self) -> "Optional[Awaitable[None]]":
        """Terminate the connection pool."""
        raise NotImplementedError

    @abstractmethod
    def provide_pool(
        self, *args: Any, **kwargs: Any
    ) -> "Union[PoolT, Awaitable[PoolT], AbstractContextManager[PoolT], AbstractAsyncContextManager[PoolT]]":
        """Provide pool instance."""
        raise NotImplementedError

    def get_signature_namespace(self) -> "dict[str, type[Any]]":
        """Get the signature namespace for this database configuration.

        This method returns a dictionary of type names to types that should be
        registered with Litestar's signature namespace to prevent serialization
        attempts on database-specific types.

        The driver class and the config class are always included (when set).

        The connection type is expanded as follows:

        - A union (e.g. ``Union[Connection, PoolConnectionProxy]`` or ``A | B``)
          contributes each member.
        - A parametrised generic (e.g. ``Connection[DictRow]``) contributes its
          generic class, not its type arguments.
        - ``NoneType`` is skipped.

        Returns:
            Dictionary mapping type names to types.
        """
        import types
        import typing

        namespace: dict[str, type[Any]] = {}
        # ``types.UnionType`` (PEP 604 ``A | B``) only exists on Python 3.10+.
        union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))

        def _register(candidate: Any) -> None:
            """Add ``candidate`` (or the classes it is composed of) to ``namespace``."""
            origin = typing.get_origin(candidate)
            if origin in union_origins:
                for member in typing.get_args(candidate):
                    _register(member)
            elif origin is not None:
                # Parametrised generic: register the generic class, not its arguments.
                _register(origin)
            elif isinstance(candidate, type) and candidate is not type(None):
                namespace[candidate.__name__] = candidate

        # Add the driver and config types
        driver_type = getattr(self, "driver_type", None)
        if driver_type:
            namespace[driver_type.__name__] = driver_type

        namespace[self.__class__.__name__] = self.__class__

        # Add connection type(s)
        connection_type = getattr(self, "connection_type", None)
        if connection_type:
            _register(connection_type)

        return namespace
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@dataclass
class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
    """Base class for sync database configurations that do not implement a pool."""

    __slots__ = ()

    is_async: "ClassVar[bool]" = field(init=False, default=False)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=False)
    # Pool-less configs never hold a pool.
    pool_instance: None = None

    def __hash__(self) -> int:
        # @dataclass (eq=True) assigns ``__hash__ = None`` to any decorated class that does
        # not define __hash__ in its own body. That would discard the identity hash from
        # DatabaseConfigProtocol and make instances unhashable, so it is restated here.
        return id(self)

    def create_connection(self) -> ConnectionT:
        """Create connection with instrumentation."""
        raise NotImplementedError

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[ConnectionT]":
        """Provide connection with instrumentation."""
        raise NotImplementedError

    def provide_session(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[DriverT]":
        """Provide session with instrumentation."""
        raise NotImplementedError

    def create_pool(self) -> None:
        """No-op: this configuration has no pool."""
        return None

    def close_pool(self) -> None:
        """No-op: this configuration has no pool."""
        return None

    def provide_pool(self, *args: Any, **kwargs: Any) -> None:
        """No-op: this configuration has no pool, so ``None`` is returned."""
        return None
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
@dataclass
class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
    """Base class for async database configurations that do not implement a pool."""

    __slots__ = ()

    is_async: "ClassVar[bool]" = field(init=False, default=True)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=False)
    # Pool-less configs never hold a pool.
    pool_instance: None = None

    def __hash__(self) -> int:
        # @dataclass (eq=True) assigns ``__hash__ = None`` to any decorated class that does
        # not define __hash__ in its own body. That would discard the identity hash from
        # DatabaseConfigProtocol and make instances unhashable, so it is restated here.
        return id(self)

    async def create_connection(self) -> ConnectionT:
        """Create connection with instrumentation."""
        raise NotImplementedError

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[ConnectionT]":
        """Provide connection with instrumentation."""
        raise NotImplementedError

    def provide_session(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[DriverT]":
        """Provide session with instrumentation."""
        raise NotImplementedError

    async def create_pool(self) -> None:
        """No-op: this configuration has no pool."""
        return None

    async def close_pool(self) -> None:
        """No-op: this configuration has no pool."""
        return None

    def provide_pool(self, *args: Any, **kwargs: Any) -> None:
        """No-op: this configuration has no pool, so ``None`` is returned."""
        return None
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@dataclass
class GenericPoolConfig:
    """Field-less base for adapter-specific pool settings.

    It declares no dataclass fields of its own. The empty ``__slots__`` means
    instances of this class carry no per-instance ``__dict__``.
    """

    __slots__ = ()
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
@dataclass
class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
    """Generic Sync Database Configuration.

    The pool produced by ``_create_pool`` is cached on ``pool_instance``. The cache
    is cleared by ``close_pool``, so a later ``create_pool``/``provide_pool`` builds
    a fresh pool instead of handing out the closed one.
    """

    __slots__ = ()

    is_async: "ClassVar[bool]" = field(init=False, default=False)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=True)

    def __hash__(self) -> int:
        # @dataclass (eq=True) assigns ``__hash__ = None`` to any decorated class that does
        # not define __hash__ in its own body. That would discard the identity hash from
        # DatabaseConfigProtocol and make instances unhashable, so it is restated here.
        return id(self)

    def create_pool(self) -> PoolT:
        """Create pool with instrumentation.

        Returns the cached ``pool_instance`` when one exists; otherwise calls
        ``_create_pool`` and caches the result.

        Returns:
            The created pool.
        """
        if self.pool_instance is None:
            self.pool_instance = self._create_pool()  # type: ignore[misc]
        return self.pool_instance

    def close_pool(self) -> None:
        """Close pool with instrumentation.

        The cached ``pool_instance`` is dropped only after ``_close_pool`` returns.
        A failed close therefore leaves the reference in place.
        """
        self._close_pool()
        self.pool_instance = None

    def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT:
        """Provide pool instance, creating it on first use."""
        return self.create_pool()

    def create_connection(self) -> ConnectionT:
        """Create connection with instrumentation."""
        raise NotImplementedError

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[ConnectionT]":
        """Provide connection with instrumentation."""
        raise NotImplementedError

    def provide_session(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[DriverT]":
        """Provide session with instrumentation."""
        raise NotImplementedError

    @abstractmethod
    def _create_pool(self) -> PoolT:
        """Actual pool creation implementation."""
        raise NotImplementedError

    @abstractmethod
    def _close_pool(self) -> None:
        """Actual pool destruction implementation."""
        raise NotImplementedError
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@dataclass
class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
    """Generic Async Database Configuration.

    The pool produced by ``_create_pool`` is cached on ``pool_instance``. The cache
    is cleared by ``close_pool``, so a later ``create_pool``/``provide_pool`` builds
    a fresh pool instead of handing out the closed one.
    """

    __slots__ = ()

    is_async: "ClassVar[bool]" = field(init=False, default=True)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=True)

    def __hash__(self) -> int:
        # @dataclass (eq=True) assigns ``__hash__ = None`` to any decorated class that does
        # not define __hash__ in its own body. That would discard the identity hash from
        # DatabaseConfigProtocol and make instances unhashable, so it is restated here.
        return id(self)

    async def create_pool(self) -> PoolT:
        """Create pool with instrumentation.

        Returns the cached ``pool_instance`` when one exists; otherwise awaits
        ``_create_pool`` and caches the result.

        Returns:
            The created pool.
        """
        if self.pool_instance is None:
            self.pool_instance = await self._create_pool()  # type: ignore[misc]
        return self.pool_instance

    async def close_pool(self) -> None:
        """Close pool with instrumentation.

        The cached ``pool_instance`` is dropped only after ``_close_pool`` completes.
        A failed close therefore leaves the reference in place.
        """
        await self._close_pool()
        self.pool_instance = None

    async def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT:
        """Provide pool instance, creating it on first use."""
        return await self.create_pool()

    async def create_connection(self) -> ConnectionT:
        """Create connection with instrumentation."""
        raise NotImplementedError

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[ConnectionT]":
        """Provide connection with instrumentation."""
        raise NotImplementedError

    def provide_session(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[DriverT]":
        """Provide session with instrumentation."""
        raise NotImplementedError

    @abstractmethod
    async def _create_pool(self) -> PoolT:
        """Actual async pool creation implementation."""
        raise NotImplementedError

    @abstractmethod
    async def _close_pool(self) -> None:
        """Actual async pool destruction implementation."""
        raise NotImplementedError
|
|
sqlspec/driver/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Driver protocols and base classes for database adapters."""
|
|
2
|
+
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from sqlspec.driver import mixins
|
|
6
|
+
from sqlspec.driver._async import AsyncDriverAdapterProtocol
|
|
7
|
+
from sqlspec.driver._common import CommonDriverAttributesMixin
|
|
8
|
+
from sqlspec.driver._sync import SyncDriverAdapterProtocol
|
|
9
|
+
from sqlspec.typing import ConnectionT, RowT
|
|
10
|
+
|
|
11
|
+
# Public API of the ``sqlspec.driver`` package.
__all__ = (
    "AsyncDriverAdapterProtocol",
    "CommonDriverAttributesMixin",
    "DriverAdapterProtocol",
    "SyncDriverAdapterProtocol",
    "mixins",
)

# Type alias for convenience: either the sync or the async driver protocol,
# parametrised by the connection type and the row type.
DriverAdapterProtocol = Union[
    SyncDriverAdapterProtocol[ConnectionT, RowT], AsyncDriverAdapterProtocol[ConnectionT, RowT]
]
|
sqlspec/driver/_async.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Asynchronous driver protocol implementation."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
|
|
5
|
+
|
|
6
|
+
from sqlspec.driver._common import CommonDriverAttributesMixin
|
|
7
|
+
from sqlspec.statement.builder import DeleteBuilder, InsertBuilder, QueryBuilder, SelectBuilder, UpdateBuilder
|
|
8
|
+
from sqlspec.statement.filters import StatementFilter
|
|
9
|
+
from sqlspec.statement.result import SQLResult
|
|
10
|
+
from sqlspec.statement.sql import SQL, SQLConfig, Statement
|
|
11
|
+
from sqlspec.typing import ConnectionT, DictRow, ModelDTOT, RowT, StatementParameters
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict
|
|
15
|
+
|
|
16
|
+
__all__ = ("AsyncDriverAdapterProtocol",)


# Module-level empty list of statement filters.
# NOTE(review): it is not referenced in this module, and it is a single mutable list
# shared by every importer. A caller that appends to it would affect all others.
# Confirm how it is used elsewhere before relying on it staying empty.
EMPTY_FILTERS: "list[StatementFilter]" = []
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT], ABC):
    """Base class for asynchronous database drivers.

    The public ``execute``, ``execute_many`` and ``execute_script`` coroutines normalize
    their input into :class:`SQL` objects. They then delegate to the abstract hooks
    ``_execute_statement``, ``_wrap_select_result`` and ``_wrap_execute_result``, which
    concrete drivers implement.
    """

    __slots__ = ()

    def __init__(
        self,
        connection: "ConnectionT",
        config: "Optional[SQLConfig]" = None,
        default_row_type: "type[DictRow]" = DictRow,
    ) -> None:
        """Initialize async driver adapter.

        Args:
            connection: The database connection
            config: SQL statement configuration
            default_row_type: Default row type for results (DictRow, TupleRow, etc.)
        """
        super().__init__(connection=connection, config=config, default_row_type=default_row_type)

    def _build_statement(
        self,
        statement: "Union[Statement, QueryBuilder[Any]]",
        *parameters: "Union[StatementParameters, StatementFilter]",
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQL":
        """Normalize ``statement`` into a :class:`SQL` object.

        Args:
            statement: Raw statement, an existing ``SQL`` object, or a query builder.
            *parameters: Parameters and/or filters used when a new ``SQL`` is built.
            _config: Statement configuration; defaults to ``self.config``.
            **kwargs: Extra keyword arguments forwarded to the ``SQL`` constructor.

        Returns:
            The statement as a ``SQL`` object. Query builders are rendered with
            ``to_statement`` and existing ``SQL`` objects are returned as-is. In both
            of those cases ``parameters`` and ``kwargs`` are not applied.
        """
        # Use driver's config if none provided
        _config = _config or self.config

        if isinstance(statement, QueryBuilder):
            return statement.to_statement(config=_config)
        # If statement is already a SQL object, return it as-is
        if isinstance(statement, SQL):
            return statement
        return SQL(statement, *parameters, _dialect=self.dialect, _config=_config, **kwargs)

    @abstractmethod
    async def _execute_statement(
        self, statement: "SQL", connection: "Optional[ConnectionT]" = None, **kwargs: Any
    ) -> "Union[SelectResultDict, DMLResultDict, ScriptResultDict]":
        """Actual execution implementation by concrete drivers, using the raw connection.

        Returns one of the standardized result dictionaries based on the statement type.
        """
        raise NotImplementedError

    @abstractmethod
    async def _wrap_select_result(
        self,
        statement: "SQL",
        result: "SelectResultDict",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        **kwargs: Any,
    ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
        """Convert a row-returning result dict into a ``SQLResult`` (optionally typed by ``schema_type``)."""
        raise NotImplementedError

    @abstractmethod
    async def _wrap_execute_result(
        self, statement: "SQL", result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
    ) -> "SQLResult[RowT]":
        """Convert a DML or script result dict into a ``SQLResult``."""
        raise NotImplementedError

    # Type-safe overloads based on the refactor plan pattern
    @overload
    async def execute(
        self,
        statement: "SelectBuilder",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[ModelDTOT]": ...

    @overload
    async def execute(
        self,
        statement: "SelectBuilder",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    @overload
    async def execute(
        self,
        statement: "Union[InsertBuilder, UpdateBuilder, DeleteBuilder]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    @overload
    async def execute(
        self,
        statement: "Union[str, SQL]",  # exp.Expression
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[ModelDTOT]": ...

    @overload
    async def execute(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    async def execute(
        self,
        statement: "Union[SQL, Statement, QueryBuilder[Any]]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
        """Execute a single statement and wrap its result.

        The result is wrapped as rows when ``returns_rows`` reports that the statement
        produces rows, and as a DML/script result otherwise.

        Args:
            statement: Raw statement, ``SQL`` object, or query builder.
            *parameters: Parameters and/or ``StatementFilter`` instances.
            schema_type: Optional type that rows are converted into.
            _connection: Connection to use instead of the driver's default one.
            _config: Statement configuration; defaults to ``self.config``.
            **kwargs: Forwarded to statement building, execution and result wrapping.

        Returns:
            The wrapped ``SQLResult``.
        """
        sql_statement = self._build_statement(statement, *parameters, _config=_config or self.config, **kwargs)
        result = await self._execute_statement(
            statement=sql_statement, connection=self._connection(_connection), **kwargs
        )

        if self.returns_rows(sql_statement.expression):
            return await self._wrap_select_result(
                sql_statement, cast("SelectResultDict", result), schema_type=schema_type, **kwargs
            )
        return await self._wrap_execute_result(
            sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
        )

    async def execute_many(
        self,
        statement: "Union[SQL, Statement, QueryBuilder[Any]]",  # QueryBuilder for DMLs will likely not return rows.
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]":
        """Execute ``statement`` once for each parameter set in a batch.

        Args:
            statement: Raw statement, ``SQL`` object, or query builder.
            *parameters: ``StatementFilter`` instances are applied to the statement, in
                the same way as ``execute`` and ``execute_script`` apply them. The first
                remaining value is used as the sequence of parameter sets. Any further
                non-filter values are ignored.
            _connection: Connection to use instead of the driver's default one.
            _config: Statement configuration; defaults to ``self.config``.
            **kwargs: Forwarded to statement building, execution and result wrapping.

        Returns:
            The wrapped DML result.
        """
        # Separate parameters from filters
        param_sequences: list[Any] = []
        filters: list[StatementFilter] = []
        for param in parameters:
            if isinstance(param, StatementFilter):
                filters.append(param)
            else:
                param_sequences.append(param)

        # Use first parameter as the sequence for execute_many
        param_sequence = param_sequences[0] if param_sequences else None
        # Convert tuple to list if needed
        if isinstance(param_sequence, tuple):
            param_sequence = list(param_sequence)
        # Ensure param_sequence is a list or None
        if param_sequence is not None and not isinstance(param_sequence, list):
            # NOTE(review): any other iterable is materialized here; a Mapping would become
            # a list of its keys. Confirm callers always pass a sequence of parameter sets.
            param_sequence = list(param_sequence) if hasattr(param_sequence, "__iter__") else None
        # Filters used to be collected and then discarded; apply them like ``execute`` does.
        sql_statement = self._build_statement(statement, *filters, _config=_config or self.config, **kwargs)
        sql_statement = sql_statement.as_many(param_sequence)
        result = await self._execute_statement(
            statement=sql_statement,
            connection=self._connection(_connection),
            parameters=param_sequence,
            is_many=True,
            **kwargs,
        )
        return await self._wrap_execute_result(
            sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
        )

    async def execute_script(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]":
        """Execute a multi-statement script.

        Validation is turned off for scripts: when the effective config has
        ``enable_validation`` set, a copy is made with validation and strict mode
        disabled.

        Args:
            statement: Script text or ``SQL`` object.
            *parameters: ``StatementFilter`` instances plus parameter values. Only the
                first non-filter value is used.
            _connection: Connection to use instead of the driver's default one.
            _config: Statement configuration; defaults to ``self.config``.
            **kwargs: Forwarded to statement building, execution and result wrapping.

        Returns:
            The wrapped script result. If the driver returns a plain string, a
            ``SCRIPT`` result with one successful statement is returned instead.
        """
        param_values = []
        filters = []
        for param in parameters:
            if isinstance(param, StatementFilter):
                filters.append(param)
            else:
                param_values.append(param)

        # Use first parameter as the primary parameter value, or None if no parameters
        primary_params = param_values[0] if param_values else None

        script_config = _config or self.config
        if script_config.enable_validation:
            # NOTE(review): this copies SQLConfig field by field. Any SQLConfig field not
            # listed here falls back to its default for scripts; keep the list in sync.
            script_config = SQLConfig(
                enable_parsing=script_config.enable_parsing,
                enable_validation=False,
                enable_transformations=script_config.enable_transformations,
                enable_analysis=script_config.enable_analysis,
                strict_mode=False,
                cache_parsed_expression=script_config.cache_parsed_expression,
                parameter_converter=script_config.parameter_converter,
                parameter_validator=script_config.parameter_validator,
                analysis_cache_size=script_config.analysis_cache_size,
                allowed_parameter_styles=script_config.allowed_parameter_styles,
                target_parameter_style=script_config.target_parameter_style,
                allow_mixed_parameter_styles=script_config.allow_mixed_parameter_styles,
            )
        sql_statement = SQL(statement, primary_params, *filters, _dialect=self.dialect, _config=script_config, **kwargs)
        sql_statement = sql_statement.as_script()
        script_output = await self._execute_statement(
            statement=sql_statement, connection=self._connection(_connection), is_script=True, **kwargs
        )
        if isinstance(script_output, str):
            result = SQLResult[RowT](statement=sql_statement, data=[], operation_type="SCRIPT")
            result.total_statements = 1
            result.successful_statements = 1
            return result
        # Wrap the ScriptResultDict using the driver's wrapper
        return await self._wrap_execute_result(sql_statement, cast("ScriptResultDict", script_output), **kwargs)
|