sqlspec 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic; see the package's registry page for more details.
- sqlspec/__init__.py +15 -0
- sqlspec/_serialization.py +16 -2
- sqlspec/_typing.py +1 -1
- sqlspec/adapters/adbc/__init__.py +7 -0
- sqlspec/adapters/adbc/config.py +160 -17
- sqlspec/adapters/adbc/driver.py +333 -0
- sqlspec/adapters/aiosqlite/__init__.py +6 -2
- sqlspec/adapters/aiosqlite/config.py +25 -7
- sqlspec/adapters/aiosqlite/driver.py +275 -0
- sqlspec/adapters/asyncmy/__init__.py +7 -2
- sqlspec/adapters/asyncmy/config.py +75 -14
- sqlspec/adapters/asyncmy/driver.py +255 -0
- sqlspec/adapters/asyncpg/__init__.py +9 -0
- sqlspec/adapters/asyncpg/config.py +99 -20
- sqlspec/adapters/asyncpg/driver.py +288 -0
- sqlspec/adapters/duckdb/__init__.py +6 -2
- sqlspec/adapters/duckdb/config.py +197 -15
- sqlspec/adapters/duckdb/driver.py +225 -0
- sqlspec/adapters/oracledb/__init__.py +11 -8
- sqlspec/adapters/oracledb/config/__init__.py +6 -6
- sqlspec/adapters/oracledb/config/_asyncio.py +98 -13
- sqlspec/adapters/oracledb/config/_common.py +1 -1
- sqlspec/adapters/oracledb/config/_sync.py +99 -14
- sqlspec/adapters/oracledb/driver.py +498 -0
- sqlspec/adapters/psycopg/__init__.py +11 -0
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +105 -13
- sqlspec/adapters/psycopg/config/_common.py +2 -2
- sqlspec/adapters/psycopg/config/_sync.py +105 -13
- sqlspec/adapters/psycopg/driver.py +616 -0
- sqlspec/adapters/sqlite/__init__.py +7 -0
- sqlspec/adapters/sqlite/config.py +25 -7
- sqlspec/adapters/sqlite/driver.py +303 -0
- sqlspec/base.py +416 -36
- sqlspec/extensions/litestar/__init__.py +19 -0
- sqlspec/extensions/litestar/_utils.py +56 -0
- sqlspec/extensions/litestar/config.py +81 -0
- sqlspec/extensions/litestar/handlers.py +188 -0
- sqlspec/extensions/litestar/plugin.py +103 -11
- sqlspec/typing.py +72 -17
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/deprecation.py +1 -1
- sqlspec/utils/fixtures.py +4 -5
- sqlspec/utils/sync_tools.py +335 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/METADATA +1 -1
- sqlspec-0.8.0.dist-info/RECORD +57 -0
- sqlspec-0.7.0.dist-info/RECORD +0 -46
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
2
|
+
|
|
3
|
+
from asyncpg import Connection
|
|
4
|
+
from typing_extensions import TypeAlias
|
|
5
|
+
|
|
6
|
+
from sqlspec.base import AsyncDriverAdapterProtocol, T
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from asyncpg.connection import Connection
|
|
10
|
+
from asyncpg.pool import PoolConnectionProxy
|
|
11
|
+
|
|
12
|
+
from sqlspec.typing import ModelDTOT, StatementParameterType
|
|
13
|
+
|
|
14
|
+
__all__ = ("AsyncpgDriver",)


# Either a plain asyncpg connection or a connection proxy checked out from an asyncpg pool.
PgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]"  # pyright: ignore[reportMissingTypeArgument]


class AsyncpgDriver(AsyncDriverAdapterProtocol["PgConnection"]):
    """AsyncPG Postgres Driver Adapter.

    Every query method forwards the processed parameters positionally
    (``*parameters``) to the underlying asyncpg connection method.
    """

    connection: "PgConnection"

    def __init__(self, connection: "PgConnection") -> None:
        self.connection = connection

    def _process_sql_params(
        self, sql: str, parameters: "Optional[StatementParameterType]" = None
    ) -> "tuple[str, Union[tuple[Any, ...], list[Any], dict[str, Any]]]":
        """Run the base-class processing and substitute ``()`` for missing parameters.

        Args:
            sql: SQL statement.
            parameters: Query parameters.

        Returns:
            The processed SQL and a parameter container that is never ``None``,
            so callers can always unpack it with ``*``.
        """
        sql, parameters = super()._process_sql_params(sql, parameters)
        # NOTE(review): callers spread the result with ``*``; a dict here would be
        # spread as its keys. Named parameters presumably must be converted to a
        # positional sequence upstream — confirm against the base class.
        return sql, parameters if parameters is not None else ()

    async def select(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch data from the database.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            List of row data as either model instances or dictionaries
            (an empty list when the query returns no rows).
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        results = await connection.fetch(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        if not results:
            return []
        if schema_type is None:
            return [dict(row.items()) for row in results]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        return [cast("ModelDTOT", schema_type(**dict(row.items()))) for row in results]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]

    async def select_one(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Fetch one row from the database.

        A missing row is handed to ``check_not_found``, which is expected to raise.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            The first row of the query results.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        result = await connection.fetchrow(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        result = self.check_not_found(result)

        if schema_type is None:
            # Always return as dictionary
            return dict(result.items())  # type: ignore[attr-defined]
        return cast("ModelDTOT", schema_type(**dict(result.items())))  # type: ignore[attr-defined]

    async def select_one_or_none(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch one row from the database, or ``None`` when nothing matches.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            The first row of the query results, or ``None`` if there are no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        result = await connection.fetchrow(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        # The "_or_none" contract: a missing row is a normal outcome, not an error,
        # so it must not be routed through check_not_found.
        if result is None:
            return None
        if schema_type is None:
            # Always return as dictionary
            return dict(result.items())  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        return cast("ModelDTOT", schema_type(**dict(result.items())))  # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType]

    async def select_value(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[PgConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Union[T, Any]":
        """Fetch a single value from the database.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional callable/type used to convert the value.

        Returns:
            The first column of the first row, optionally converted with ``schema_type``.
            A ``None`` result (no rows, or a SQL NULL) is passed to ``check_not_found``,
            which is expected to raise.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        # asyncpg's fetchval already returns the first column of the first row
        # (a scalar, not a record), so the value must not be indexed again.
        result = await connection.fetchval(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        result = self.check_not_found(result)
        if schema_type is None:
            return result
        return schema_type(result)  # type: ignore[call-arg]

    async def select_value_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[PgConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Optional[Union[T, Any]]":
        """Fetch a single value from the database, or ``None``.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional callable/type used to convert the value.

        Returns:
            The first column of the first row (optionally converted with
            ``schema_type``), or ``None`` if there are no rows or the value is NULL.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        # fetchval returns the scalar value itself; do not index into it.
        result = await connection.fetchval(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        if result is None:
            return None
        if schema_type is None:
            return result
        return schema_type(result)  # type: ignore[call-arg]

    async def insert_update_delete(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
    ) -> int:
        """Insert, update, or delete data from the database.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.

        Returns:
            Row count affected by the operation, or -1 if the status string
            returned by asyncpg cannot be parsed.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        status = await connection.execute(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        # AsyncPG returns a string like "INSERT 0 1" where the last number is the affected rows
        try:
            return int(status.split()[-1])  # pyright: ignore[reportUnknownMemberType]
        except (ValueError, IndexError, AttributeError):
            return -1  # Fallback if we can't parse the status

    async def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Insert, update, or delete data from the database and return result.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            The first row of results, or ``None`` if the statement returned no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        result = await connection.fetchrow(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        if result is None:
            return None
        if schema_type is None:
            # Always return as dictionary
            return dict(result.items())  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        return cast("ModelDTOT", schema_type(**dict(result.items())))  # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType]

    async def execute_script(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
    ) -> str:
        """Execute a script.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.

        Returns:
            Status message for the operation.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        return await connection.execute(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]

    async def execute_script_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a script and return result.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            The first row of results, or ``None`` if the script returned no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        result = await connection.fetchrow(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        if result is None:
            return None
        if schema_type is None:
            # Always return as dictionary
            return dict(result.items())  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        return cast("ModelDTOT", schema_type(**dict(result.items())))  # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType]
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from contextlib import contextmanager
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Union, cast
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
|
|
4
4
|
|
|
5
5
|
from duckdb import DuckDBPyConnection
|
|
6
|
-
from typing_extensions import NotRequired, TypedDict
|
|
6
|
+
from typing_extensions import Literal, NotRequired, TypedDict
|
|
7
7
|
|
|
8
|
+
from sqlspec.adapters.duckdb.driver import DuckDBDriver
|
|
8
9
|
from sqlspec.base import NoPoolSyncConfig
|
|
9
10
|
from sqlspec.exceptions import ImproperConfigurationError
|
|
10
11
|
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
|
|
@@ -13,7 +14,7 @@ if TYPE_CHECKING:
|
|
|
13
14
|
from collections.abc import Generator, Sequence
|
|
14
15
|
|
|
15
16
|
|
|
16
|
-
# Public API of the DuckDB config module. SecretConfig is exported because it is the
# element type of the public ``DuckDB.secrets`` field and users build instances of it.
__all__ = ("DuckDB", "ExtensionConfig", "SecretConfig")
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class ExtensionConfig(TypedDict):
|
|
@@ -29,6 +30,8 @@ class ExtensionConfig(TypedDict):
|
|
|
29
30
|
"""The name of the extension to install"""
|
|
30
31
|
config: "NotRequired[dict[str, Any]]"
|
|
31
32
|
"""Optional configuration settings to apply after installation"""
|
|
33
|
+
install_if_missing: "NotRequired[bool]"
|
|
34
|
+
"""Whether to install if missing"""
|
|
32
35
|
force_install: "NotRequired[bool]"
|
|
33
36
|
"""Whether to force reinstall if already present"""
|
|
34
37
|
repository: "NotRequired[str]"
|
|
@@ -39,8 +42,34 @@ class ExtensionConfig(TypedDict):
|
|
|
39
42
|
"""Optional version of the extension to install"""
|
|
40
43
|
|
|
41
44
|
|
|
45
|
+
class SecretConfig(TypedDict):
|
|
46
|
+
"""Configuration for a secret to store in a connection.
|
|
47
|
+
|
|
48
|
+
This class provides configuration options for storing a secret in a connection for later retrieval.
|
|
49
|
+
|
|
50
|
+
For details see: https://duckdb.org/docs/stable/configuration/secrets_manager
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
secret_type: Union[
|
|
54
|
+
Literal[
|
|
55
|
+
"azure", "gcs", "s3", "r2", "huggingface", "http", "mysql", "postgres", "bigquery", "openai", "open_prompt" # noqa: PYI051
|
|
56
|
+
],
|
|
57
|
+
str,
|
|
58
|
+
]
|
|
59
|
+
provider: NotRequired[str]
|
|
60
|
+
"""The provider of the secret"""
|
|
61
|
+
name: str
|
|
62
|
+
"""The name of the secret to store"""
|
|
63
|
+
value: dict[str, Any]
|
|
64
|
+
"""The secret value to store"""
|
|
65
|
+
persist: NotRequired[bool]
|
|
66
|
+
"""Whether to persist the secret"""
|
|
67
|
+
replace_if_exists: NotRequired[bool]
|
|
68
|
+
"""Whether to replace the secret if it already exists"""
|
|
69
|
+
|
|
70
|
+
|
|
42
71
|
@dataclass
|
|
43
|
-
class
|
|
72
|
+
class DuckDB(NoPoolSyncConfig["DuckDBPyConnection", "DuckDBDriver"]):
|
|
44
73
|
"""Configuration for DuckDB database connections.
|
|
45
74
|
|
|
46
75
|
This class provides configuration options for DuckDB database connections, wrapping all parameters
|
|
@@ -49,7 +78,7 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
49
78
|
For details see: https://duckdb.org/docs/api/python/overview#connection-options
|
|
50
79
|
"""
|
|
51
80
|
|
|
52
|
-
database: "Union[str, EmptyType]" =
|
|
81
|
+
database: "Union[str, EmptyType]" = field(default=":memory:")
|
|
53
82
|
"""The path to the database file to be opened. Pass ":memory:" to open a connection to a database that resides in RAM instead of on disk. If not specified, an in-memory database will be created."""
|
|
54
83
|
|
|
55
84
|
read_only: "Union[bool, EmptyType]" = Empty
|
|
@@ -63,6 +92,18 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
63
92
|
|
|
64
93
|
extensions: "Union[Sequence[ExtensionConfig], ExtensionConfig, EmptyType]" = Empty
|
|
65
94
|
"""A sequence of extension configurations to install and configure upon connection creation."""
|
|
95
|
+
secrets: "Union[Sequence[SecretConfig], SecretConfig , EmptyType]" = Empty
|
|
96
|
+
"""A dictionary of secrets to store in the connection for later retrieval."""
|
|
97
|
+
auto_update_extensions: "bool" = False
|
|
98
|
+
"""Whether to automatically update on connection creation"""
|
|
99
|
+
on_connection_create: "Optional[Callable[[DuckDBPyConnection], Optional[DuckDBPyConnection]]]" = None
|
|
100
|
+
"""A callable to be called after the connection is created."""
|
|
101
|
+
connection_type: "type[DuckDBPyConnection]" = field(init=False, default_factory=lambda: DuckDBPyConnection)
|
|
102
|
+
"""The type of connection to create. Defaults to DuckDBPyConnection."""
|
|
103
|
+
driver_type: "type[DuckDBDriver]" = field(init=False, default_factory=lambda: DuckDBDriver) # type: ignore[type-abstract,unused-ignore]
|
|
104
|
+
"""The type of driver to use. Defaults to DuckDBDriver."""
|
|
105
|
+
pool_instance: "None" = field(init=False, default=None)
|
|
106
|
+
"""The pool instance to use. Defaults to None."""
|
|
66
107
|
|
|
67
108
|
def __post_init__(self) -> None:
|
|
68
109
|
"""Post-initialization validation and processing.
|
|
@@ -73,9 +114,10 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
73
114
|
"""
|
|
74
115
|
if self.config is Empty:
|
|
75
116
|
self.config = {}
|
|
76
|
-
|
|
77
117
|
if self.extensions is Empty:
|
|
78
118
|
self.extensions = []
|
|
119
|
+
if self.secrets is Empty:
|
|
120
|
+
self.secrets = []
|
|
79
121
|
if isinstance(self.extensions, dict):
|
|
80
122
|
self.extensions = [self.extensions]
|
|
81
123
|
# this is purely for mypy
|
|
@@ -103,8 +145,8 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
103
145
|
Args:
|
|
104
146
|
connection: The DuckDB connection to configure.
|
|
105
147
|
"""
|
|
106
|
-
for
|
|
107
|
-
connection.execute(
|
|
148
|
+
for key, value in cast("dict[str,Any]", self.config).items():
|
|
149
|
+
connection.execute(f"SET {key}='{value}'")
|
|
108
150
|
|
|
109
151
|
def _configure_extensions(self, connection: "DuckDBPyConnection") -> None:
|
|
110
152
|
"""Configure extensions for the connection.
|
|
@@ -119,9 +161,104 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
119
161
|
|
|
120
162
|
for extension in cast("list[ExtensionConfig]", self.extensions):
|
|
121
163
|
self._configure_extension(connection, extension)
|
|
164
|
+
if self.auto_update_extensions:
|
|
165
|
+
connection.execute("update extensions")
|
|
122
166
|
|
|
123
167
|
@staticmethod
def _secret_exists(connection: "DuckDBPyConnection", name: "str") -> bool:
    """Report whether a secret called ``name`` is registered on ``connection``.

    Args:
        connection: The DuckDB connection to inspect.
        name: The secret name to look up in ``duckdb_secrets()``.

    Returns:
        bool: True when ``duckdb_secrets()`` yields a row with that name, False otherwise.
    """
    lookup_sql = "select 1 from duckdb_secrets() where name=?"
    first_row = connection.execute(lookup_sql, [name]).fetchone()  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
    found = first_row is not None
    return found
|
|
180
|
+
|
|
181
|
+
@classmethod
def _is_community_extension(cls, connection: "DuckDBPyConnection", name: "str") -> bool:
    """Decide whether ``name`` should be treated as a community extension.

    The heuristic is: any extension that DuckDB does not list in
    ``duckdb_extensions()`` is assumed to come from the community repository.

    Args:
        connection: The DuckDB connection used for the lookup.
        name: The extension name to look up.

    Returns:
        bool: True if ``duckdb_extensions()`` has no row for ``name``, False otherwise.
    """
    lookup_sql = "select 1 from duckdb_extensions() where extension_name=?"
    cursor = connection.execute(lookup_sql, [name])  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
    listed_by_duckdb = cursor.fetchone() is not None
    return not listed_by_duckdb
|
|
196
|
+
|
|
197
|
+
@classmethod
def _extension_installed(cls, connection: "DuckDBPyConnection", name: "str") -> bool:
    """Check if an extension is installed for the connection.

    Args:
        connection: The DuckDB connection to check for the extension.
        name: The name of the extension to check for.

    Returns:
        bool: True if ``duckdb_extensions()`` reports the extension as installed, False otherwise.
    """
    results = connection.execute(  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
        "select 1 from duckdb_extensions() where extension_name=? and installed=true", [name]
    ).fetchone()
    return results is not None
|
|
212
|
+
|
|
213
|
+
@classmethod
def _extension_loaded(cls, connection: "DuckDBPyConnection", name: "str") -> bool:
    """Check whether an extension is currently loaded in the connection.

    Args:
        connection: The DuckDB connection to inspect.
        name: The extension name to look up in ``duckdb_extensions()``.

    Returns:
        bool: True when ``duckdb_extensions()`` reports the extension as loaded, False otherwise.
    """
    lookup_sql = "select 1 from duckdb_extensions() where extension_name=? and loaded=true"
    cursor = connection.execute(lookup_sql, [name])  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
    return cursor.fetchone() is not None
|
|
228
|
+
|
|
229
|
+
@classmethod
def _configure_secrets(
    cls,
    connection: "DuckDBPyConnection",
    secrets: "Sequence[SecretConfig]",
) -> None:
    """Create the configured secrets on the connection.

    A secret is created when no secret with its name exists yet, or recreated when its
    ``replace_if_exists`` flag is set; otherwise it is left untouched. Secrets flagged
    ``persist`` are created as ``persistent`` secrets.

    Args:
        connection: The DuckDB connection to configure secrets for.
        secrets: The list of secrets to store in the connection.

    Raises:
        ImproperConfigurationError: If a secret could not be stored in the connection.
    """
    try:
        for secret in secrets:
            secret_exists = cls._secret_exists(connection, secret["name"])
            if secret_exists and not secret.get("replace_if_exists", False):
                continue

            # Collect the clauses first and join them once, so no dangling comma is
            # produced when the provider or the value mapping is absent.
            options = [f"type {secret['secret_type']}"]
            if secret.get("provider"):
                options.append(f"provider {secret['provider']}")
            for key, value in secret["value"].items():
                # Values become SQL string literals; double embedded single quotes so
                # values such as passwords containing "'" cannot break the statement.
                escaped = str(value).replace("'", "''")
                options.append(f"{key} '{escaped}'")

            persistence = "persistent " if secret.get("persist", False) else ""
            connection.execute(f"create or replace {persistence}secret {secret['name']} ({', '.join(options)})")
    except Exception as e:
        msg = f"Failed to store secret. Error: {e!s}"
        raise ImproperConfigurationError(msg) from e
|
|
259
|
+
|
|
260
|
+
@classmethod
|
|
261
|
+
def _configure_extension(cls, connection: "DuckDBPyConnection", extension: "ExtensionConfig") -> None:
|
|
125
262
|
"""Configure a single extension for the connection.
|
|
126
263
|
|
|
127
264
|
Args:
|
|
@@ -132,16 +269,32 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
132
269
|
ImproperConfigurationError: If extension installation or configuration fails.
|
|
133
270
|
"""
|
|
134
271
|
try:
|
|
135
|
-
if
|
|
272
|
+
# Install extension if needed
|
|
273
|
+
if (
|
|
274
|
+
not cls._extension_installed(connection, extension["name"])
|
|
275
|
+
and extension.get("install_if_missing", True)
|
|
276
|
+
) or extension.get("force_install", False):
|
|
277
|
+
repository = extension.get("repository", None)
|
|
278
|
+
repository_url = (
|
|
279
|
+
"https://community-extensions.duckdb.org"
|
|
280
|
+
if repository is None
|
|
281
|
+
and cls._is_community_extension(connection, extension["name"])
|
|
282
|
+
and extension.get("repository_url") is None
|
|
283
|
+
else extension.get("repository_url", None)
|
|
284
|
+
)
|
|
136
285
|
connection.install_extension(
|
|
137
286
|
extension=extension["name"],
|
|
138
287
|
force_install=extension.get("force_install", False),
|
|
139
|
-
repository=
|
|
140
|
-
repository_url=
|
|
288
|
+
repository=repository,
|
|
289
|
+
repository_url=repository_url,
|
|
141
290
|
version=extension.get("version"),
|
|
142
291
|
)
|
|
143
|
-
connection.load_extension(extension["name"])
|
|
144
292
|
|
|
293
|
+
# Load extension if not already loaded
|
|
294
|
+
if not cls._extension_loaded(connection, extension["name"]):
|
|
295
|
+
connection.load_extension(extension["name"])
|
|
296
|
+
|
|
297
|
+
# Apply any configuration settings
|
|
145
298
|
if extension.get("config"):
|
|
146
299
|
for key, value in extension.get("config", {}).items():
|
|
147
300
|
connection.execute(f"SET {key}={value}")
|
|
@@ -156,7 +309,20 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
156
309
|
Returns:
|
|
157
310
|
A string keyed dict of config kwargs for the duckdb.connect() function.
|
|
158
311
|
"""
|
|
159
|
-
config = dataclass_to_dict(
|
|
312
|
+
config = dataclass_to_dict(
|
|
313
|
+
self,
|
|
314
|
+
exclude_empty=True,
|
|
315
|
+
exclude={
|
|
316
|
+
"extensions",
|
|
317
|
+
"pool_instance",
|
|
318
|
+
"secrets",
|
|
319
|
+
"on_connection_create",
|
|
320
|
+
"auto_update_extensions",
|
|
321
|
+
"driver_type",
|
|
322
|
+
"connection_type",
|
|
323
|
+
},
|
|
324
|
+
convert_nested=False,
|
|
325
|
+
)
|
|
160
326
|
if not config.get("database"):
|
|
161
327
|
config["database"] = ":memory:"
|
|
162
328
|
return config
|
|
@@ -175,7 +341,11 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
175
341
|
try:
|
|
176
342
|
connection = duckdb.connect(**self.connection_config_dict) # pyright: ignore[reportUnknownMemberType]
|
|
177
343
|
self._configure_extensions(connection)
|
|
344
|
+
self._configure_secrets(connection, cast("list[SecretConfig]", self.secrets))
|
|
178
345
|
self._configure_connection(connection)
|
|
346
|
+
if self.on_connection_create:
|
|
347
|
+
self.on_connection_create(connection)
|
|
348
|
+
|
|
179
349
|
except Exception as e:
|
|
180
350
|
msg = f"Could not configure the DuckDB connection. Error: {e!s}"
|
|
181
351
|
raise ImproperConfigurationError(msg) from e
|
|
@@ -196,3 +366,15 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
|
|
|
196
366
|
yield connection
|
|
197
367
|
finally:
|
|
198
368
|
connection.close()
|
|
369
|
+
|
|
370
|
+
@contextmanager
def provide_session(self, *args: Any, **kwargs: Any) -> "Generator[DuckDBDriver, None, None]":
    """Create a connection and provide a driver session bound to it.

    The connection comes from ``provide_connection`` (all positional and keyword
    arguments are forwarded to it). The session is only valid inside the ``with``
    block; cleanup of the connection is left to ``provide_connection``.

    Yields:
        A ``driver_type`` instance (``DuckDBDriver`` by default) wrapping the
        connection, constructed with ``use_cursor=True``.
    """
    with self.provide_connection(*args, **kwargs) as connection:
        yield self.driver_type(connection, use_cursor=True)
|