sqlspec 0.7.1__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (49)
  1. sqlspec/__init__.py +15 -0
  2. sqlspec/_serialization.py +16 -2
  3. sqlspec/_typing.py +1 -1
  4. sqlspec/adapters/adbc/__init__.py +7 -0
  5. sqlspec/adapters/adbc/config.py +160 -17
  6. sqlspec/adapters/adbc/driver.py +333 -0
  7. sqlspec/adapters/aiosqlite/__init__.py +6 -2
  8. sqlspec/adapters/aiosqlite/config.py +25 -7
  9. sqlspec/adapters/aiosqlite/driver.py +275 -0
  10. sqlspec/adapters/asyncmy/__init__.py +7 -2
  11. sqlspec/adapters/asyncmy/config.py +75 -14
  12. sqlspec/adapters/asyncmy/driver.py +255 -0
  13. sqlspec/adapters/asyncpg/__init__.py +9 -0
  14. sqlspec/adapters/asyncpg/config.py +99 -20
  15. sqlspec/adapters/asyncpg/driver.py +288 -0
  16. sqlspec/adapters/duckdb/__init__.py +6 -2
  17. sqlspec/adapters/duckdb/config.py +195 -13
  18. sqlspec/adapters/duckdb/driver.py +225 -0
  19. sqlspec/adapters/oracledb/__init__.py +11 -8
  20. sqlspec/adapters/oracledb/config/__init__.py +6 -6
  21. sqlspec/adapters/oracledb/config/_asyncio.py +98 -13
  22. sqlspec/adapters/oracledb/config/_common.py +1 -1
  23. sqlspec/adapters/oracledb/config/_sync.py +99 -14
  24. sqlspec/adapters/oracledb/driver.py +498 -0
  25. sqlspec/adapters/psycopg/__init__.py +11 -0
  26. sqlspec/adapters/psycopg/config/__init__.py +6 -6
  27. sqlspec/adapters/psycopg/config/_async.py +105 -13
  28. sqlspec/adapters/psycopg/config/_common.py +2 -2
  29. sqlspec/adapters/psycopg/config/_sync.py +105 -13
  30. sqlspec/adapters/psycopg/driver.py +616 -0
  31. sqlspec/adapters/sqlite/__init__.py +7 -0
  32. sqlspec/adapters/sqlite/config.py +25 -7
  33. sqlspec/adapters/sqlite/driver.py +303 -0
  34. sqlspec/base.py +416 -36
  35. sqlspec/extensions/litestar/__init__.py +19 -0
  36. sqlspec/extensions/litestar/_utils.py +56 -0
  37. sqlspec/extensions/litestar/config.py +81 -0
  38. sqlspec/extensions/litestar/handlers.py +188 -0
  39. sqlspec/extensions/litestar/plugin.py +100 -11
  40. sqlspec/typing.py +72 -17
  41. sqlspec/utils/__init__.py +3 -0
  42. sqlspec/utils/fixtures.py +4 -5
  43. sqlspec/utils/sync_tools.py +335 -0
  44. {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/METADATA +1 -1
  45. sqlspec-0.8.0.dist-info/RECORD +57 -0
  46. sqlspec-0.7.1.dist-info/RECORD +0 -46
  47. {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/WHEEL +0 -0
  48. {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/licenses/LICENSE +0 -0
  49. {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,288 @@
1
+ from typing import TYPE_CHECKING, Any, Optional, Union, cast
2
+
3
+ from asyncpg import Connection
4
+ from typing_extensions import TypeAlias
5
+
6
+ from sqlspec.base import AsyncDriverAdapterProtocol, T
7
+
8
+ if TYPE_CHECKING:
9
+ from asyncpg.connection import Connection
10
+ from asyncpg.pool import PoolConnectionProxy
11
+
12
+ from sqlspec.typing import ModelDTOT, StatementParameterType
13
+
__all__ = ("AsyncpgDriver",)


PgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]"  # pyright: ignore[reportMissingTypeArgument]


class AsyncpgDriver(AsyncDriverAdapterProtocol["PgConnection"]):
    """AsyncPG Postgres Driver Adapter.

    Runs statements through an asyncpg connection (or pool connection proxy) via
    ``fetch``, ``fetchrow``, ``fetchval`` and ``execute``. Rows are returned as plain
    dictionaries unless a ``schema_type`` is supplied, in which case each row's columns
    are passed to it as keyword arguments.
    """

    connection: "PgConnection"

    def __init__(self, connection: "PgConnection") -> None:
        self.connection = connection

    def _process_sql_params(
        self, sql: str, parameters: "Optional[StatementParameterType]" = None
    ) -> "tuple[str, Union[tuple[Any, ...], list[Any], dict[str, Any]]]":
        """Prepare the statement and parameters for an asyncpg call.

        Delegates to the base implementation and substitutes an empty tuple for a
        ``None`` parameter set, so every caller can unpack the result with ``*``.

        Args:
            sql: SQL statement.
            parameters: Query parameters.

        Returns:
            The processed SQL and a parameter collection that is never ``None``.
        """
        sql, parameters = super()._process_sql_params(sql, parameters)
        # NOTE(review): asyncpg only accepts positional ($1, $2, ...) arguments. If the base
        # implementation can return a dict here, ``*parameters`` would pass the dict's keys
        # rather than its values -- confirm named parameters are converted before this point.
        return sql, parameters if parameters is not None else ()

    @staticmethod
    def _record_to_result(
        record: Any, schema_type: "Optional[type[ModelDTOT]]" = None
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Convert one asyncpg record into a dict or a ``schema_type`` instance.

        Args:
            record: A row returned by ``fetch``/``fetchrow``.
            schema_type: Optional schema class; when given it is called with the row's
                columns as keyword arguments.

        Returns:
            The row as a dictionary, or as an instance of ``schema_type``.
        """
        data = dict(record.items())  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
        if schema_type is None:
            return data
        return cast("ModelDTOT", schema_type(**data))

    async def select(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch data from the database.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            List of row data as either model instances or dictionaries (empty when no rows match).
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        records = await connection.fetch(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        return [self._record_to_result(record, schema_type) for record in records]  # pyright: ignore[reportUnknownVariableType]

    async def select_one(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Fetch exactly one row from the database.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            The first row of the query results.

        Raises:
            Whatever ``check_not_found`` raises when the query returns no row.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        record = await connection.fetchrow(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        record = self.check_not_found(record)
        return self._record_to_result(record, schema_type)

    async def select_one_or_none(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch one row from the database, or ``None`` if there is no row.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            The first row of the query results, or ``None`` when the query returns no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        record = await connection.fetchrow(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        # Unlike ``select_one`` a missing row is not an error here.
        if record is None:
            return None
        return self._record_to_result(record, schema_type)

    async def select_value(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[PgConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Union[T, Any]":
        """Fetch a single value from the database.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional callable/type used to convert the value.

        Returns:
            The first column of the first row of results.

        Raises:
            Whatever ``check_not_found`` raises when no value is returned.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        # ``fetchval`` already returns the scalar of the first column; it must not be indexed.
        value = await connection.fetchval(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        value = self.check_not_found(value)
        if schema_type is None:
            return value
        return schema_type(value)  # type: ignore[call-arg]

    async def select_value_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[PgConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Optional[Union[T, Any]]":
        """Fetch a single value from the database, or ``None``.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional callable/type used to convert the value.

        Returns:
            The first column of the first row of results, or ``None`` if there are no results.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        # ``fetchval`` already returns the scalar of the first column; it must not be indexed.
        value = await connection.fetchval(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        if value is None:
            return None
        if schema_type is None:
            return value
        return schema_type(value)  # type: ignore[call-arg]

    async def insert_update_delete(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
    ) -> int:
        """Insert, update, or delete data from the database.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.

        Returns:
            Row count affected by the operation, or -1 if the status string has no trailing count.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        status = await connection.execute(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        # AsyncPG returns a status string like "INSERT 0 1" where the last number is the affected rows.
        try:
            return int(status.split()[-1])  # pyright: ignore[reportUnknownMemberType]
        except (ValueError, IndexError, AttributeError):
            return -1  # Fallback if we can't parse the status

    async def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Insert, update, or delete data from the database and return result.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            The first returned row, or ``None`` if the statement returned no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        record = await connection.fetchrow(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        if record is None:
            return None
        return self._record_to_result(record, schema_type)

    async def execute_script(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
    ) -> str:
        """Execute a script.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.

        Returns:
            Status message for the operation.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        return await connection.execute(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]

    async def execute_script_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["PgConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a script and return result.

        Args:
            sql: SQL statement.
            parameters: Query parameters.
            connection: Optional connection to use.
            schema_type: Optional schema class for the result.

        Returns:
            The first row of results, or ``None`` if the script returned no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        record = await connection.fetchrow(sql, *parameters)  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
        if record is None:
            return None
        return self._record_to_result(record, schema_type)
@@ -1,3 +1,7 @@
1
- from sqlspec.adapters.duckdb.config import DuckDBConfig
1
+ from sqlspec.adapters.duckdb.config import DuckDB
2
+ from sqlspec.adapters.duckdb.driver import DuckDBDriver
2
3
 
3
- __all__ = ("DuckDBConfig",)
4
+ __all__ = (
5
+ "DuckDB",
6
+ "DuckDBDriver",
7
+ )
@@ -1,10 +1,11 @@
1
1
  from contextlib import contextmanager
2
- from dataclasses import dataclass
3
- from typing import TYPE_CHECKING, Any, Union, cast
2
+ from dataclasses import dataclass, field
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
4
4
 
5
5
  from duckdb import DuckDBPyConnection
6
- from typing_extensions import NotRequired, TypedDict
6
+ from typing_extensions import Literal, NotRequired, TypedDict
7
7
 
8
+ from sqlspec.adapters.duckdb.driver import DuckDBDriver
8
9
  from sqlspec.base import NoPoolSyncConfig
9
10
  from sqlspec.exceptions import ImproperConfigurationError
10
11
  from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
@@ -13,7 +14,7 @@ if TYPE_CHECKING:
13
14
  from collections.abc import Generator, Sequence
14
15
 
15
16
 
# Public API of the DuckDB config module; SecretConfig is exported alongside
# ExtensionConfig because it appears in the public ``DuckDB.secrets`` field type.
__all__ = ("DuckDB", "ExtensionConfig", "SecretConfig")
17
18
 
18
19
 
19
20
  class ExtensionConfig(TypedDict):
@@ -29,6 +30,8 @@ class ExtensionConfig(TypedDict):
29
30
  """The name of the extension to install"""
30
31
  config: "NotRequired[dict[str, Any]]"
31
32
  """Optional configuration settings to apply after installation"""
33
+ install_if_missing: "NotRequired[bool]"
34
+ """Whether to install if missing"""
32
35
  force_install: "NotRequired[bool]"
33
36
  """Whether to force reinstall if already present"""
34
37
  repository: "NotRequired[str]"
@@ -39,8 +42,34 @@ class ExtensionConfig(TypedDict):
39
42
  """Optional version of the extension to install"""
40
43
 
41
44
 
45
+ class SecretConfig(TypedDict):
46
+ """Configuration for a secret to store in a connection.
47
+
48
+ This class provides configuration options for storing a secret in a connection for later retrieval.
49
+
50
+ For details see: https://duckdb.org/docs/stable/configuration/secrets_manager
51
+ """
52
+
53
+ secret_type: Union[
54
+ Literal[
55
+ "azure", "gcs", "s3", "r2", "huggingface", "http", "mysql", "postgres", "bigquery", "openai", "open_prompt" # noqa: PYI051
56
+ ],
57
+ str,
58
+ ]
59
+ provider: NotRequired[str]
60
+ """The provider of the secret"""
61
+ name: str
62
+ """The name of the secret to store"""
63
+ value: dict[str, Any]
64
+ """The secret value to store"""
65
+ persist: NotRequired[bool]
66
+ """Whether to persist the secret"""
67
+ replace_if_exists: NotRequired[bool]
68
+ """Whether to replace the secret if it already exists"""
69
+
70
+
42
71
  @dataclass
43
- class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
72
+ class DuckDB(NoPoolSyncConfig["DuckDBPyConnection", "DuckDBDriver"]):
44
73
  """Configuration for DuckDB database connections.
45
74
 
46
75
  This class provides configuration options for DuckDB database connections, wrapping all parameters
@@ -49,7 +78,7 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
49
78
  For details see: https://duckdb.org/docs/api/python/overview#connection-options
50
79
  """
51
80
 
52
- database: "Union[str, EmptyType]" = Empty
81
+ database: "Union[str, EmptyType]" = field(default=":memory:")
53
82
  """The path to the database file to be opened. Pass ":memory:" to open a connection to a database that resides in RAM instead of on disk. If not specified, an in-memory database will be created."""
54
83
 
55
84
  read_only: "Union[bool, EmptyType]" = Empty
@@ -63,6 +92,18 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
63
92
 
64
93
  extensions: "Union[Sequence[ExtensionConfig], ExtensionConfig, EmptyType]" = Empty
65
94
  """A sequence of extension configurations to install and configure upon connection creation."""
95
+ secrets: "Union[Sequence[SecretConfig], SecretConfig , EmptyType]" = Empty
96
+ """A dictionary of secrets to store in the connection for later retrieval."""
97
+ auto_update_extensions: "bool" = False
98
+ """Whether to automatically update on connection creation"""
99
+ on_connection_create: "Optional[Callable[[DuckDBPyConnection], Optional[DuckDBPyConnection]]]" = None
100
+ """A callable to be called after the connection is created."""
101
+ connection_type: "type[DuckDBPyConnection]" = field(init=False, default_factory=lambda: DuckDBPyConnection)
102
+ """The type of connection to create. Defaults to DuckDBPyConnection."""
103
+ driver_type: "type[DuckDBDriver]" = field(init=False, default_factory=lambda: DuckDBDriver) # type: ignore[type-abstract,unused-ignore]
104
+ """The type of driver to use. Defaults to DuckDBDriver."""
105
+ pool_instance: "None" = field(init=False, default=None)
106
+ """The pool instance to use. Defaults to None."""
66
107
 
67
108
  def __post_init__(self) -> None:
68
109
  """Post-initialization validation and processing.
@@ -73,9 +114,10 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
73
114
  """
74
115
  if self.config is Empty:
75
116
  self.config = {}
76
-
77
117
  if self.extensions is Empty:
78
118
  self.extensions = []
119
+ if self.secrets is Empty:
120
+ self.secrets = []
79
121
  if isinstance(self.extensions, dict):
80
122
  self.extensions = [self.extensions]
81
123
  # this is purely for mypy
@@ -119,9 +161,104 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
119
161
 
120
162
  for extension in cast("list[ExtensionConfig]", self.extensions):
121
163
  self._configure_extension(connection, extension)
164
+ if self.auto_update_extensions:
165
+ connection.execute("update extensions")
122
166
 
123
167
  @staticmethod
124
- def _configure_extension(connection: "DuckDBPyConnection", extension: ExtensionConfig) -> None:
168
+ def _secret_exists(connection: "DuckDBPyConnection", name: "str") -> bool:
169
+ """Check if a secret exists in the connection.
170
+
171
+ Args:
172
+ connection: The DuckDB connection to check for the secret.
173
+ name: The name of the secret to check for.
174
+
175
+ Returns:
176
+ bool: True if the secret exists, False otherwise.
177
+ """
178
+ results = connection.execute("select 1 from duckdb_secrets() where name=?", [name]).fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
179
+ return results is not None
180
+
181
+ @classmethod
182
+ def _is_community_extension(cls, connection: "DuckDBPyConnection", name: "str") -> bool:
183
+ """Check if an extension is a community extension.
184
+
185
+ Args:
186
+ connection: The DuckDB connection to check for the extension.
187
+ name: The name of the extension to check.
188
+
189
+ Returns:
190
+ bool: True if the extension is a community extension, False otherwise.
191
+ """
192
+ results = connection.execute( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
193
+ "select 1 from duckdb_extensions() where extension_name=?", [name]
194
+ ).fetchone()
195
+ return results is None
196
+
197
+ @classmethod
198
+ def _extension_installed(cls, connection: "DuckDBPyConnection", name: "str") -> bool:
199
+ """Check if a extension exists in the connection.
200
+
201
+ Args:
202
+ connection: The DuckDB connection to check for the secret.
203
+ name: The name of the secret to check for.
204
+
205
+ Returns:
206
+ bool: True if the extension is installed, False otherwise.
207
+ """
208
+ results = connection.execute( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
209
+ "select 1 from duckdb_extensions() where extension_name=? and installed=true", [name]
210
+ ).fetchone()
211
+ return results is not None
212
+
213
+ @classmethod
214
+ def _extension_loaded(cls, connection: "DuckDBPyConnection", name: "str") -> bool:
215
+ """Check if a extension is loaded in the connection.
216
+
217
+ Args:
218
+ connection: The DuckDB connection to check for the extension.
219
+ name: The name of the extension to check for.
220
+
221
+ Returns:
222
+ bool: True if the extension is loaded, False otherwise.
223
+ """
224
+ results = connection.execute( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
225
+ "select 1 from duckdb_extensions() where extension_name=? and loaded=true", [name]
226
+ ).fetchone()
227
+ return results is not None
228
+
229
+ @classmethod
230
+ def _configure_secrets(
231
+ cls,
232
+ connection: "DuckDBPyConnection",
233
+ secrets: "Sequence[SecretConfig]",
234
+ ) -> None:
235
+ """Configure persistent secrets for the connection.
236
+
237
+ Args:
238
+ connection: The DuckDB connection to configure secrets for.
239
+ secrets: The list of secrets to store in the connection.
240
+
241
+ Raises:
242
+ ImproperConfigurationError: If a secret could not be stored in the connection.
243
+ """
244
+ try:
245
+ for secret in secrets:
246
+ secret_exists = cls._secret_exists(connection, secret["name"])
247
+ if not secret_exists or secret.get("replace_if_exists", False):
248
+ provider_type = "" if not secret.get("provider") else f"provider {secret.get('provider')},"
249
+ connection.execute(
250
+ f"""create or replace {"persistent" if secret.get("persist", False) else ""} secret {secret["name"]} (
251
+ type {secret["secret_type"]},
252
+ {provider_type}
253
+ {" ,".join([f"{k} '{v}'" for k, v in secret["value"].items()])}
254
+ ) """
255
+ )
256
+ except Exception as e:
257
+ msg = f"Failed to store secret. Error: {e!s}"
258
+ raise ImproperConfigurationError(msg) from e
259
+
260
+ @classmethod
261
+ def _configure_extension(cls, connection: "DuckDBPyConnection", extension: "ExtensionConfig") -> None:
125
262
  """Configure a single extension for the connection.
126
263
 
127
264
  Args:
@@ -132,16 +269,32 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
132
269
  ImproperConfigurationError: If extension installation or configuration fails.
133
270
  """
134
271
  try:
135
- if extension.get("force_install"):
272
+ # Install extension if needed
273
+ if (
274
+ not cls._extension_installed(connection, extension["name"])
275
+ and extension.get("install_if_missing", True)
276
+ ) or extension.get("force_install", False):
277
+ repository = extension.get("repository", None)
278
+ repository_url = (
279
+ "https://community-extensions.duckdb.org"
280
+ if repository is None
281
+ and cls._is_community_extension(connection, extension["name"])
282
+ and extension.get("repository_url") is None
283
+ else extension.get("repository_url", None)
284
+ )
136
285
  connection.install_extension(
137
286
  extension=extension["name"],
138
287
  force_install=extension.get("force_install", False),
139
- repository=extension.get("repository"),
140
- repository_url=extension.get("repository_url"),
288
+ repository=repository,
289
+ repository_url=repository_url,
141
290
  version=extension.get("version"),
142
291
  )
143
- connection.load_extension(extension["name"])
144
292
 
293
+ # Load extension if not already loaded
294
+ if not cls._extension_loaded(connection, extension["name"]):
295
+ connection.load_extension(extension["name"])
296
+
297
+ # Apply any configuration settings
145
298
  if extension.get("config"):
146
299
  for key, value in extension.get("config", {}).items():
147
300
  connection.execute(f"SET {key}={value}")
@@ -156,7 +309,20 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
156
309
  Returns:
157
310
  A string keyed dict of config kwargs for the duckdb.connect() function.
158
311
  """
159
- config = dataclass_to_dict(self, exclude_empty=True, exclude={"extensions"}, convert_nested=False)
312
+ config = dataclass_to_dict(
313
+ self,
314
+ exclude_empty=True,
315
+ exclude={
316
+ "extensions",
317
+ "pool_instance",
318
+ "secrets",
319
+ "on_connection_create",
320
+ "auto_update_extensions",
321
+ "driver_type",
322
+ "connection_type",
323
+ },
324
+ convert_nested=False,
325
+ )
160
326
  if not config.get("database"):
161
327
  config["database"] = ":memory:"
162
328
  return config
@@ -175,7 +341,11 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
175
341
  try:
176
342
  connection = duckdb.connect(**self.connection_config_dict) # pyright: ignore[reportUnknownMemberType]
177
343
  self._configure_extensions(connection)
344
+ self._configure_secrets(connection, cast("list[SecretConfig]", self.secrets))
178
345
  self._configure_connection(connection)
346
+ if self.on_connection_create:
347
+ self.on_connection_create(connection)
348
+
179
349
  except Exception as e:
180
350
  msg = f"Could not configure the DuckDB connection. Error: {e!s}"
181
351
  raise ImproperConfigurationError(msg) from e
@@ -196,3 +366,15 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
196
366
  yield connection
197
367
  finally:
198
368
  connection.close()
369
+
370
+ @contextmanager
371
+ def provide_session(self, *args: Any, **kwargs: Any) -> "Generator[DuckDBDriver, None, None]":
372
+ """Create and provide a database connection.
373
+
374
+ Yields:
375
+ A DuckDB connection instance.
376
+
377
+
378
+ """
379
+ with self.provide_connection(*args, **kwargs) as connection:
380
+ yield self.driver_type(connection, use_cursor=True)