sqlspec 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +15 -0
- sqlspec/_serialization.py +16 -2
- sqlspec/_typing.py +1 -1
- sqlspec/adapters/adbc/__init__.py +7 -0
- sqlspec/adapters/adbc/config.py +160 -17
- sqlspec/adapters/adbc/driver.py +333 -0
- sqlspec/adapters/aiosqlite/__init__.py +6 -2
- sqlspec/adapters/aiosqlite/config.py +25 -7
- sqlspec/adapters/aiosqlite/driver.py +275 -0
- sqlspec/adapters/asyncmy/__init__.py +7 -2
- sqlspec/adapters/asyncmy/config.py +75 -14
- sqlspec/adapters/asyncmy/driver.py +255 -0
- sqlspec/adapters/asyncpg/__init__.py +9 -0
- sqlspec/adapters/asyncpg/config.py +99 -20
- sqlspec/adapters/asyncpg/driver.py +288 -0
- sqlspec/adapters/duckdb/__init__.py +6 -2
- sqlspec/adapters/duckdb/config.py +197 -15
- sqlspec/adapters/duckdb/driver.py +225 -0
- sqlspec/adapters/oracledb/__init__.py +11 -8
- sqlspec/adapters/oracledb/config/__init__.py +6 -6
- sqlspec/adapters/oracledb/config/_asyncio.py +98 -13
- sqlspec/adapters/oracledb/config/_common.py +1 -1
- sqlspec/adapters/oracledb/config/_sync.py +99 -14
- sqlspec/adapters/oracledb/driver.py +498 -0
- sqlspec/adapters/psycopg/__init__.py +11 -0
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +105 -13
- sqlspec/adapters/psycopg/config/_common.py +2 -2
- sqlspec/adapters/psycopg/config/_sync.py +105 -13
- sqlspec/adapters/psycopg/driver.py +616 -0
- sqlspec/adapters/sqlite/__init__.py +7 -0
- sqlspec/adapters/sqlite/config.py +25 -7
- sqlspec/adapters/sqlite/driver.py +303 -0
- sqlspec/base.py +416 -36
- sqlspec/extensions/litestar/__init__.py +19 -0
- sqlspec/extensions/litestar/_utils.py +56 -0
- sqlspec/extensions/litestar/config.py +81 -0
- sqlspec/extensions/litestar/handlers.py +188 -0
- sqlspec/extensions/litestar/plugin.py +103 -11
- sqlspec/typing.py +72 -17
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/deprecation.py +1 -1
- sqlspec/utils/fixtures.py +4 -5
- sqlspec/utils/sync_tools.py +335 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/METADATA +1 -1
- sqlspec-0.8.0.dist-info/RECORD +57 -0
- sqlspec-0.7.0.dist-info/RECORD +0 -46
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/base.py
CHANGED
|
@@ -1,8 +1,23 @@
|
|
|
1
|
+
# ruff: noqa: PLR6301
|
|
2
|
+
import re
|
|
1
3
|
from abc import ABC, abstractmethod
|
|
2
4
|
from collections.abc import AsyncGenerator, Awaitable, Generator
|
|
3
5
|
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
4
|
-
from dataclasses import dataclass
|
|
5
|
-
from typing import
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import (
|
|
8
|
+
Annotated,
|
|
9
|
+
Any,
|
|
10
|
+
ClassVar,
|
|
11
|
+
Generic,
|
|
12
|
+
Optional,
|
|
13
|
+
TypeVar,
|
|
14
|
+
Union,
|
|
15
|
+
cast,
|
|
16
|
+
overload,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from sqlspec.exceptions import NotFoundError
|
|
20
|
+
from sqlspec.typing import ModelDTOT, StatementParameterType
|
|
6
21
|
|
|
7
22
|
__all__ = (
|
|
8
23
|
"AsyncDatabaseConfig",
|
|
@@ -13,16 +28,34 @@ __all__ = (
|
|
|
13
28
|
"SyncDatabaseConfig",
|
|
14
29
|
)
|
|
15
30
|
|
|
31
|
+
# Generic type variables shared by the configuration and driver classes below.
T = TypeVar("T")
ConnectionT = TypeVar("ConnectionT")
PoolT = TypeVar("PoolT")
PoolT_co = TypeVar("PoolT_co", covariant=True)

# Config/driver bounds are forward references (strings) because the bounding
# classes are defined later in this module.
AsyncConfigT = TypeVar("AsyncConfigT", bound="Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]]")
SyncConfigT = TypeVar("SyncConfigT", bound="Union[SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]")
ConfigT = TypeVar(
    "ConfigT",
    bound="Union[Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]], SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]",
)
DriverT = TypeVar("DriverT", bound="Union[SyncDriverAdapterProtocol[Any], AsyncDriverAdapterProtocol[Any]]")

# Scanner for ``:name`` placeholders. The alternatives are tried left to right
# at each position:
#   * ``dquote`` - a double-quoted span ("" escapes a quote inside it),
#   * ``squote`` - a single-quoted span ('' escapes a quote inside it),
#   * ``lead`` + ``var_name`` - one non-colon character followed by ``:name``.
# Quoted spans are matched so callers can skip placeholders inside them. The
# mandatory non-colon ``lead`` leaves ``::type`` casts alone; it also means a
# placeholder at the very start of the string is not recognised.
PARAM_REGEX = re.compile(
    r"(?P<dquote>\"(?:[^\"]|\"\")*\")|"
    r"(?P<squote>'(?:[^']|'')*')|"
    r"(?P<lead>[^:]):(?P<var_name>[a-zA-Z_][a-zA-Z0-9_]*)"
)
|
|
20
50
|
|
|
21
51
|
|
|
22
52
|
@dataclass
|
|
23
|
-
class DatabaseConfigProtocol(Generic[ConnectionT, PoolT
|
|
53
|
+
class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
24
54
|
"""Protocol defining the interface for database configurations."""
|
|
25
55
|
|
|
56
|
+
connection_type: "type[ConnectionT]" = field(init=False)
|
|
57
|
+
driver_type: "type[DriverT]" = field(init=False)
|
|
58
|
+
pool_instance: "Optional[PoolT]" = field(default=None)
|
|
26
59
|
__is_async__: ClassVar[bool] = False
|
|
27
60
|
__supports_connection_pooling__: ClassVar[bool] = False
|
|
28
61
|
|
|
@@ -59,6 +92,11 @@ class DatabaseConfigProtocol(Generic[ConnectionT, PoolT], ABC):
|
|
|
59
92
|
"""Create and return connection pool."""
|
|
60
93
|
raise NotImplementedError
|
|
61
94
|
|
|
95
|
+
@abstractmethod
|
|
96
|
+
def close_pool(self) -> Optional[Awaitable[None]]:
|
|
97
|
+
"""Terminate the connection pool."""
|
|
98
|
+
raise NotImplementedError
|
|
99
|
+
|
|
62
100
|
@abstractmethod
|
|
63
101
|
def provide_pool(
|
|
64
102
|
self,
|
|
@@ -79,31 +117,39 @@ class DatabaseConfigProtocol(Generic[ConnectionT, PoolT], ABC):
|
|
|
79
117
|
return self.__supports_connection_pooling__
|
|
80
118
|
|
|
81
119
|
|
|
82
|
-
class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None]):
|
|
120
|
+
class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
    """Base class for synchronous database configurations without connection pooling.

    Every pool-related hook is a no-op that returns ``None``.
    """

    __is_async__ = False
    __supports_connection_pooling__ = False
    # Pool-less configurations never hold a pool instance.
    pool_instance: None = None

    def create_pool(self) -> None:
        """No-op: this backend does not implement pooling."""
        return None

    def close_pool(self) -> None:
        """No-op: there is no pool to close."""
        return None

    def provide_pool(self, *args: Any, **kwargs: Any) -> None:
        """No-op: this backend does not implement pooling; always returns ``None``."""
        return None
|
|
95
137
|
|
|
96
138
|
|
|
97
|
-
class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None]):
|
|
139
|
+
class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
    """Base class for asynchronous database configurations without connection pooling.

    ``create_pool`` and ``close_pool`` are coroutines that do nothing;
    ``provide_pool`` is a plain method that returns ``None``.
    """

    __is_async__ = True
    __supports_connection_pooling__ = False
    # Pool-less configurations never hold a pool instance.
    pool_instance: None = None

    async def create_pool(self) -> None:
        """No-op coroutine: this backend does not implement pooling."""
        return None

    async def close_pool(self) -> None:
        """No-op coroutine: there is no pool to close."""
        return None

    def provide_pool(self, *args: Any, **kwargs: Any) -> None:
        """No-op: this backend does not implement pooling; always returns ``None``."""
        return None
|
|
@@ -115,7 +161,7 @@ class GenericPoolConfig:
|
|
|
115
161
|
|
|
116
162
|
|
|
117
163
|
@dataclass
|
|
118
|
-
class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT]):
|
|
164
|
+
class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
|
|
119
165
|
"""Generic Sync Database Configuration."""
|
|
120
166
|
|
|
121
167
|
__is_async__ = False
|
|
@@ -123,18 +169,20 @@ class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT]):
|
|
|
123
169
|
|
|
124
170
|
|
|
125
171
|
@dataclass
class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
    """Generic asynchronous database configuration that supports connection pooling."""

    __supports_connection_pooling__ = True
    __is_async__ = True
|
|
131
177
|
|
|
132
178
|
|
|
133
|
-
class
|
|
134
|
-
"""Type-safe configuration manager
|
|
179
|
+
class SQLSpec:
|
|
180
|
+
"""Type-safe configuration manager and registry for database connections and pools."""
|
|
181
|
+
|
|
182
|
+
__slots__ = ("_configs",)
|
|
135
183
|
|
|
136
184
|
def __init__(self) -> None:
|
|
137
|
-
self._configs: dict[Any, DatabaseConfigProtocol[Any, Any]] = {}
|
|
185
|
+
self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
|
|
138
186
|
|
|
139
187
|
@overload
|
|
140
188
|
def add_config(self, config: SyncConfigT) -> type[SyncConfigT]: ...
|
|
@@ -149,7 +197,11 @@ class ConfigManager:
|
|
|
149
197
|
AsyncConfigT,
|
|
150
198
|
],
|
|
151
199
|
) -> Union[Annotated[type[SyncConfigT], int], Annotated[type[AsyncConfigT], int]]: # pyright: ignore[reportInvalidTypeVarUse]
|
|
152
|
-
"""Add a new configuration to the manager.
|
|
200
|
+
"""Add a new configuration to the manager.
|
|
201
|
+
|
|
202
|
+
Returns:
|
|
203
|
+
A unique type key that can be used to retrieve the configuration later.
|
|
204
|
+
"""
|
|
153
205
|
key = Annotated[type(config), id(config)] # type: ignore[valid-type]
|
|
154
206
|
self._configs[key] = config
|
|
155
207
|
return key # type: ignore[return-value] # pyright: ignore[reportReturnType]
|
|
@@ -162,9 +214,16 @@ class ConfigManager:
|
|
|
162
214
|
|
|
163
215
|
def get_config(
|
|
164
216
|
self,
|
|
165
|
-
name: Union[type[DatabaseConfigProtocol[ConnectionT, PoolT]], Any],
|
|
166
|
-
) -> DatabaseConfigProtocol[ConnectionT, PoolT]:
|
|
167
|
-
"""Retrieve a configuration by its type.
|
|
217
|
+
name: Union[type[DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]], Any],
|
|
218
|
+
) -> DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]:
|
|
219
|
+
"""Retrieve a configuration by its type.
|
|
220
|
+
|
|
221
|
+
Returns:
|
|
222
|
+
DatabaseConfigProtocol: The configuration instance for the given type.
|
|
223
|
+
|
|
224
|
+
Raises:
|
|
225
|
+
KeyError: If no configuration is found for the given type.
|
|
226
|
+
"""
|
|
168
227
|
config = self._configs.get(name)
|
|
169
228
|
if not config:
|
|
170
229
|
msg = f"No configuration found for {name}"
|
|
@@ -175,8 +234,8 @@ class ConfigManager:
|
|
|
175
234
|
def get_connection(
|
|
176
235
|
self,
|
|
177
236
|
name: Union[
|
|
178
|
-
type[NoPoolSyncConfig[ConnectionT]],
|
|
179
|
-
type[SyncDatabaseConfig[ConnectionT, PoolT]], # pyright: ignore[reportInvalidTypeVarUse]
|
|
237
|
+
type[NoPoolSyncConfig[ConnectionT, DriverT]],
|
|
238
|
+
type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]], # pyright: ignore[reportInvalidTypeVarUse]
|
|
180
239
|
],
|
|
181
240
|
) -> ConnectionT: ...
|
|
182
241
|
|
|
@@ -184,44 +243,365 @@ class ConfigManager:
|
|
|
184
243
|
def get_connection(
|
|
185
244
|
self,
|
|
186
245
|
name: Union[
|
|
187
|
-
type[NoPoolAsyncConfig[ConnectionT]],
|
|
188
|
-
type[AsyncDatabaseConfig[ConnectionT, PoolT]], # pyright: ignore[reportInvalidTypeVarUse]
|
|
246
|
+
type[NoPoolAsyncConfig[ConnectionT, DriverT]],
|
|
247
|
+
type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]], # pyright: ignore[reportInvalidTypeVarUse]
|
|
189
248
|
],
|
|
190
249
|
) -> Awaitable[ConnectionT]: ...
|
|
191
250
|
|
|
192
251
|
def get_connection(
|
|
193
252
|
self,
|
|
194
253
|
name: Union[
|
|
195
|
-
type[NoPoolSyncConfig[ConnectionT]],
|
|
196
|
-
type[NoPoolAsyncConfig[ConnectionT]],
|
|
197
|
-
type[SyncDatabaseConfig[ConnectionT, PoolT]],
|
|
198
|
-
type[AsyncDatabaseConfig[ConnectionT, PoolT]],
|
|
254
|
+
type[NoPoolSyncConfig[ConnectionT, DriverT]],
|
|
255
|
+
type[NoPoolAsyncConfig[ConnectionT, DriverT]],
|
|
256
|
+
type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
|
|
257
|
+
type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
|
|
199
258
|
],
|
|
200
259
|
) -> Union[ConnectionT, Awaitable[ConnectionT]]:
|
|
201
|
-
"""Create and return a connection from the specified configuration.
|
|
260
|
+
"""Create and return a connection from the specified configuration.
|
|
261
|
+
|
|
262
|
+
Args:
|
|
263
|
+
name: The configuration type to use for creating the connection.
|
|
264
|
+
|
|
265
|
+
Returns:
|
|
266
|
+
Either a connection instance or an awaitable that resolves to a connection,
|
|
267
|
+
depending on whether the configuration is sync or async.
|
|
268
|
+
"""
|
|
202
269
|
config = self.get_config(name)
|
|
203
270
|
return config.create_connection()
|
|
204
271
|
|
|
205
272
|
@overload
|
|
206
|
-
def get_pool(
|
|
273
|
+
def get_pool(
|
|
274
|
+
self, name: type[Union[NoPoolSyncConfig[ConnectionT, DriverT], NoPoolAsyncConfig[ConnectionT, DriverT]]]
|
|
275
|
+
) -> None: ... # pyright: ignore[reportInvalidTypeVarUse]
|
|
207
276
|
|
|
208
277
|
@overload
|
|
209
|
-
def get_pool(self, name: type[SyncDatabaseConfig[ConnectionT, PoolT]]) -> type[PoolT]: ... # pyright: ignore[reportInvalidTypeVarUse]
|
|
278
|
+
def get_pool(self, name: type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]) -> type[PoolT]: ... # pyright: ignore[reportInvalidTypeVarUse]
|
|
210
279
|
|
|
211
280
|
@overload
|
|
212
|
-
def get_pool(self, name: type[AsyncDatabaseConfig[ConnectionT, PoolT]]) -> Awaitable[type[PoolT]]: ... # pyright: ignore[reportInvalidTypeVarUse]
|
|
281
|
+
def get_pool(self, name: type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]) -> Awaitable[type[PoolT]]: ... # pyright: ignore[reportInvalidTypeVarUse]
|
|
213
282
|
|
|
214
283
|
def get_pool(
|
|
215
284
|
self,
|
|
216
285
|
name: Union[
|
|
217
|
-
type[NoPoolSyncConfig[ConnectionT]],
|
|
218
|
-
type[NoPoolAsyncConfig[ConnectionT]],
|
|
219
|
-
type[SyncDatabaseConfig[ConnectionT, PoolT]],
|
|
220
|
-
type[AsyncDatabaseConfig[ConnectionT, PoolT]],
|
|
286
|
+
type[NoPoolSyncConfig[ConnectionT, DriverT]],
|
|
287
|
+
type[NoPoolAsyncConfig[ConnectionT, DriverT]],
|
|
288
|
+
type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
|
|
289
|
+
type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
|
|
221
290
|
],
|
|
222
291
|
) -> Union[type[PoolT], Awaitable[type[PoolT]], None]:
|
|
223
|
-
"""Create and return a connection pool from the specified configuration.
|
|
292
|
+
"""Create and return a connection pool from the specified configuration.
|
|
293
|
+
|
|
294
|
+
Args:
|
|
295
|
+
name: The configuration type to use for creating the pool.
|
|
296
|
+
|
|
297
|
+
Returns:
|
|
298
|
+
Either a pool instance, an awaitable that resolves to a pool instance, or None
|
|
299
|
+
if the configuration does not support connection pooling.
|
|
300
|
+
"""
|
|
224
301
|
config = self.get_config(name)
|
|
225
|
-
if
|
|
226
|
-
return
|
|
227
|
-
return
|
|
302
|
+
if config.support_connection_pooling:
|
|
303
|
+
return cast("Union[type[PoolT], Awaitable[type[PoolT]]]", config.create_pool())
|
|
304
|
+
return None
|
|
305
|
+
|
|
306
|
+
def close_pool(
|
|
307
|
+
self,
|
|
308
|
+
name: Union[
|
|
309
|
+
type[NoPoolSyncConfig[ConnectionT, DriverT]],
|
|
310
|
+
type[NoPoolAsyncConfig[ConnectionT, DriverT]],
|
|
311
|
+
type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
|
|
312
|
+
type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
|
|
313
|
+
],
|
|
314
|
+
) -> Optional[Awaitable[None]]:
|
|
315
|
+
"""Close the connection pool for the specified configuration.
|
|
316
|
+
|
|
317
|
+
Args:
|
|
318
|
+
name: The configuration type whose pool to close.
|
|
319
|
+
|
|
320
|
+
Returns:
|
|
321
|
+
An awaitable if the configuration is async, otherwise None.
|
|
322
|
+
"""
|
|
323
|
+
config = self.get_config(name)
|
|
324
|
+
if config.support_connection_pooling:
|
|
325
|
+
return config.close_pool()
|
|
326
|
+
return None
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class CommonDriverAttributes(Generic[ConnectionT]):
    """Common attributes and methods for driver adapters."""

    # Positional placeholder that ``_process_sql_params`` substitutes for each
    # ``:name`` placeholder (e.g. ``"?"`` or ``"%s"``); adapters override it.
    param_style: str = "?"
    # The connection to the underlying database.
    connection: ConnectionT

    def _connection(self, connection: "Optional[ConnectionT]" = None) -> "ConnectionT":
        """Return ``connection`` when given, otherwise ``self.connection``."""
        return connection if connection is not None else self.connection

    @staticmethod
    def check_not_found(item_or_none: Optional[T] = None) -> T:
        """Raise :exc:`sqlspec.exceptions.NotFoundError` if ``item_or_none`` is ``None``.

        Args:
            item_or_none: Item to be tested for existence.

        Raises:
            NotFoundError: If ``item_or_none`` is ``None``.

        Returns:
            The item, if it exists.
        """
        if item_or_none is None:
            msg = "No result found when one was expected"
            raise NotFoundError(msg)
        return item_or_none

    def _process_sql_statement(self, sql: str) -> str:
        """Hook for driver-specific preprocessing of the SQL string.

        The default implementation returns ``sql`` unchanged.

        Args:
            sql: The SQL query string.

        Returns:
            The processed SQL query string.
        """
        return sql

    def _process_sql_params(
        self, sql: str, parameters: "Optional[StatementParameterType]" = None
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Process a SQL query and its parameters for DB-API execution.

        When ``parameters`` is a non-empty dict, each ``:name`` placeholder is
        replaced by ``self.param_style``. The matching values are collected, in
        placeholder order, into a tuple. Placeholders inside single- or
        double-quoted spans are left untouched. A placeholder must follow a
        non-``:`` character (see ``PARAM_REGEX``), so ``::type`` casts are
        ignored. Any other ``parameters`` value (``None``, a tuple/list, or an
        empty dict) is passed through unchanged.

        Args:
            sql: The SQL query string.
            parameters: The parameters for the query (dict, tuple, list, or None).

        Returns:
            A ``(sql, parameters)`` pair. For a non-empty dict input, the
            parameters are a tuple; otherwise they are the original object.
            The SQL always passes through ``_process_sql_statement``.

        Raises:
            ValueError: If a named parameter in the SQL is not found in the
                dictionary, or if a dictionary key is not used in the SQL.
        """
        if not isinstance(parameters, dict) or not parameters:
            # Positional or absent parameters are handed to the driver as-is.
            return self._process_sql_statement(sql), parameters

        sql_parts: list[str] = []
        ordered_values: list[Any] = []
        used_names: set[str] = set()
        last_end = 0

        for match in PARAM_REGEX.finditer(sql):
            if match.group("dquote") is not None or match.group("squote") is not None:
                # Quoted text: keep it verbatim, never substitute inside it.
                continue

            var_name = match.group("var_name")
            if var_name is None:  # Safeguard; the regex always sets it here.
                continue

            if var_name not in parameters:
                msg = f"Named parameter ':{var_name}' found in SQL but not provided in parameters dictionary."
                raise ValueError(msg)

            # Copy everything up to and including the lead character, but NOT
            # the ':' that follows it. Slicing to match.start("var_name") would
            # keep the colon and emit e.g. ":?" instead of "?".
            sql_parts.append(sql[last_end : match.end("lead")])
            sql_parts.append(self.param_style)
            ordered_values.append(parameters[var_name])
            used_names.add(var_name)
            last_end = match.end("var_name")

        # Append whatever follows the final placeholder.
        sql_parts.append(sql[last_end:])

        unused_params = set(parameters) - used_names
        if unused_params:
            # Strict by design: an unused key usually signals a typo in the SQL.
            msg = f"Parameters provided but not found in SQL: {unused_params}"
            raise ValueError(msg)

        return self._process_sql_statement("".join(sql_parts)), tuple(ordered_values)
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
class SyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]):
    """Abstract interface for synchronous driver adapters.

    In every query method, ``sql`` and ``parameters`` are positional-only.
    ``connection`` is an optional per-call connection, presumably used instead
    of ``self.connection`` (see ``_connection``) — confirm per adapter.
    """

    connection: ConnectionT

    def __init__(self, connection: ConnectionT) -> None:
        self.connection = connection

    @abstractmethod
    def select(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Return a list of rows, each a ``schema_type`` instance or a dict."""

    @abstractmethod
    def select_one(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Return a single row; unlike ``select_one_or_none`` the result is non-optional."""

    @abstractmethod
    def select_one_or_none(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Return a single row, or ``None``."""

    @abstractmethod
    def select_value(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[T]] = None,
    ) -> "Union[Any, T]":
        """Return a single scalar value, optionally typed via ``schema_type``."""

    @abstractmethod
    def select_value_or_none(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[T]] = None,
    ) -> "Optional[Union[Any, T]]":
        """Return a single scalar value, or ``None``."""

    @abstractmethod
    def insert_update_delete(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
    ) -> int:
        """Execute a data-modifying statement; returns an ``int`` (presumably rows affected)."""

    @abstractmethod
    def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a data-modifying statement and return one row (or ``None``)."""

    @abstractmethod
    def execute_script(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
    ) -> str:
        """Execute a SQL script; the returned ``str`` is adapter-defined (TODO confirm)."""
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
class AsyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]):
    """Abstract interface for asynchronous driver adapters.

    Mirrors ``SyncDriverAdapterProtocol`` with coroutine methods. ``sql`` and
    ``parameters`` are positional-only. ``connection`` is an optional per-call
    connection, presumably used instead of ``self.connection`` — confirm per
    adapter.
    """

    connection: ConnectionT

    def __init__(self, connection: ConnectionT) -> None:
        self.connection = connection

    @abstractmethod
    async def select(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Return a list of rows, each a ``schema_type`` instance or a dict."""

    @abstractmethod
    async def select_one(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Return a single row; unlike ``select_one_or_none`` the result is non-optional."""

    @abstractmethod
    async def select_one_or_none(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Return a single row, or ``None``."""

    @abstractmethod
    async def select_value(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[T]] = None,
    ) -> "Union[Any, T]":
        """Return a single scalar value, optionally typed via ``schema_type``."""

    @abstractmethod
    async def select_value_or_none(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[T]] = None,
    ) -> "Optional[Union[Any, T]]":
        """Return a single scalar value, or ``None``."""

    @abstractmethod
    async def insert_update_delete(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
    ) -> int:
        """Execute a data-modifying statement; returns an ``int`` (presumably rows affected)."""

    @abstractmethod
    async def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a data-modifying statement and return one row (or ``None``)."""

    @abstractmethod
    async def execute_script(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
    ) -> str:
        """Execute a SQL script; the returned ``str`` is adapter-defined (TODO confirm)."""
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
# Either flavour of driver adapter, parameterised by its connection type.
DriverAdapterProtocol = Union[
    SyncDriverAdapterProtocol[ConnectionT],
    AsyncDriverAdapterProtocol[ConnectionT],
]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from sqlspec.extensions.litestar.config import DatabaseConfig
|
|
2
|
+
from sqlspec.extensions.litestar.handlers import (
|
|
3
|
+
autocommit_handler_maker,
|
|
4
|
+
connection_provider_maker,
|
|
5
|
+
lifespan_handler_maker,
|
|
6
|
+
manual_handler_maker,
|
|
7
|
+
pool_provider_maker,
|
|
8
|
+
)
|
|
9
|
+
from sqlspec.extensions.litestar.plugin import SQLSpec
|
|
10
|
+
|
|
11
|
+
# Public API of the Litestar extension package.
__all__ = (
    "DatabaseConfig",
    "SQLSpec",
    "autocommit_handler_maker",
    "connection_provider_maker",
    "lifespan_handler_maker",
    "manual_handler_maker",
    "pool_provider_maker",
)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
if TYPE_CHECKING:
|
|
4
|
+
from litestar.types import Scope
|
|
5
|
+
|
|
6
|
+
__all__ = ("delete_sqlspec_scope_state", "get_sqlspec_scope_state", "set_sqlspec_scope_state")

# Key in the connection ``scope`` under which sqlspec keeps its private state dict.
_SCOPE_NAMESPACE = "_sqlspec"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_sqlspec_scope_state(scope: "Scope", key: str, default: Any = None, pop: bool = False) -> Any:
    """Get an internal value from connection scope state.

    Values live in a private namespace dict stored in ``scope`` under
    ``_SCOPE_NAMESPACE``. That namespace is created (as an empty dict) on first
    access, even when only reading.

    Note:
        Lookup behaves like :meth:`dict.get`. If ``key`` is missing, ``default``
        is returned but is *not* stored in the namespace. With ``pop=True`` it
        behaves like :meth:`dict.pop` with a default: ``key`` is removed if
        present, and ``default`` is returned otherwise.

    Args:
        scope: The connection scope.
        key: Key to get from internal namespace in scope state.
        default: Value returned when ``key`` is absent.
        pop: Boolean flag dictating whether the value should be deleted from the state.

    Returns:
        The value mapped to ``key`` in the internal namespace, or ``default`` if absent.
    """
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    return namespace.pop(key, default) if pop else namespace.get(key, default)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def set_sqlspec_scope_state(scope: "Scope", key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in sqlspec's private namespace of ``scope``.

    The namespace dict is created in ``scope`` if it does not exist yet.

    Args:
        scope: The connection scope.
        key: Key to set under internal namespace in scope state.
        value: Value for key.
    """
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    namespace[key] = value
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def delete_sqlspec_scope_state(scope: "Scope", key: str) -> None:
    """Remove an internal value from connection scope state.

    The namespace dict is created in ``scope`` if it does not exist yet, so a
    missing namespace surfaces as a missing key rather than a missing namespace.

    Args:
        scope: The connection scope.
        key: Key to remove from the internal namespace in scope state.

    Raises:
        KeyError: If ``key`` is not present in the internal namespace.
    """
    del scope.setdefault(_SCOPE_NAMESPACE, {})[key]  # type: ignore[misc]
|