sqlspec-0.7.1-py3-none-any.whl → sqlspec-0.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +15 -0
- sqlspec/_serialization.py +16 -2
- sqlspec/_typing.py +1 -1
- sqlspec/adapters/adbc/__init__.py +7 -0
- sqlspec/adapters/adbc/config.py +160 -17
- sqlspec/adapters/adbc/driver.py +333 -0
- sqlspec/adapters/aiosqlite/__init__.py +6 -2
- sqlspec/adapters/aiosqlite/config.py +25 -7
- sqlspec/adapters/aiosqlite/driver.py +275 -0
- sqlspec/adapters/asyncmy/__init__.py +7 -2
- sqlspec/adapters/asyncmy/config.py +75 -14
- sqlspec/adapters/asyncmy/driver.py +255 -0
- sqlspec/adapters/asyncpg/__init__.py +9 -0
- sqlspec/adapters/asyncpg/config.py +99 -20
- sqlspec/adapters/asyncpg/driver.py +288 -0
- sqlspec/adapters/duckdb/__init__.py +6 -2
- sqlspec/adapters/duckdb/config.py +195 -13
- sqlspec/adapters/duckdb/driver.py +225 -0
- sqlspec/adapters/oracledb/__init__.py +11 -8
- sqlspec/adapters/oracledb/config/__init__.py +6 -6
- sqlspec/adapters/oracledb/config/_asyncio.py +98 -13
- sqlspec/adapters/oracledb/config/_common.py +1 -1
- sqlspec/adapters/oracledb/config/_sync.py +99 -14
- sqlspec/adapters/oracledb/driver.py +498 -0
- sqlspec/adapters/psycopg/__init__.py +11 -0
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +105 -13
- sqlspec/adapters/psycopg/config/_common.py +2 -2
- sqlspec/adapters/psycopg/config/_sync.py +105 -13
- sqlspec/adapters/psycopg/driver.py +616 -0
- sqlspec/adapters/sqlite/__init__.py +7 -0
- sqlspec/adapters/sqlite/config.py +25 -7
- sqlspec/adapters/sqlite/driver.py +303 -0
- sqlspec/base.py +416 -36
- sqlspec/extensions/litestar/__init__.py +19 -0
- sqlspec/extensions/litestar/_utils.py +56 -0
- sqlspec/extensions/litestar/config.py +81 -0
- sqlspec/extensions/litestar/handlers.py +188 -0
- sqlspec/extensions/litestar/plugin.py +100 -11
- sqlspec/typing.py +72 -17
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/fixtures.py +4 -5
- sqlspec/utils/sync_tools.py +335 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/METADATA +1 -1
- sqlspec-0.8.0.dist-info/RECORD +57 -0
- sqlspec-0.7.1.dist-info/RECORD +0 -46
- {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# type: ignore
|
|
2
|
+
from collections.abc import AsyncGenerator
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
5
|
+
|
|
6
|
+
from sqlspec.base import AsyncDriverAdapterProtocol, T
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from asyncmy import Connection
|
|
10
|
+
from asyncmy.cursors import Cursor
|
|
11
|
+
|
|
12
|
+
from sqlspec.typing import ModelDTOT, StatementParameterType
|
|
13
|
+
|
|
14
|
+
# Public names exported by this module.
__all__ = ("AsyncmyDriver",)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AsyncmyDriver(AsyncDriverAdapterProtocol["Connection"]):
    """Asyncmy MySQL/MariaDB Driver Adapter.

    Every public method follows the same steps:
    - resolve the connection through ``self._connection``;
    - normalize the SQL and parameters with ``self._process_sql_params``;
    - execute one statement on a fresh cursor, which is always closed before
      the method returns.

    Rows are returned as ``dict`` objects keyed by column name. When a
    ``schema_type`` is supplied they are returned as ``schema_type(**row)``
    instead.
    """

    connection: "Connection"

    def __init__(self, connection: "Connection") -> None:
        self.connection = connection

    @staticmethod
    async def _cursor(connection: "Connection") -> "Cursor":
        """Return a new cursor for ``connection``.

        ``connection.cursor()`` is called without ``await``, the same way
        ``_with_cursor`` calls it.
        NOTE(review): assumes asyncmy's ``Connection.cursor()`` returns the
        cursor directly rather than an awaitable; confirm against asyncmy.
        """
        return connection.cursor()

    @staticmethod
    @asynccontextmanager
    async def _with_cursor(connection: "Connection") -> AsyncGenerator["Cursor", None]:
        """Yield a new cursor for ``connection``, closing it on exit even on error."""
        cursor = connection.cursor()
        try:
            yield cursor
        finally:
            await cursor.close()

    @staticmethod
    def _column_names(cursor: "Cursor") -> "list[str]":
        """Return the column names from ``cursor.description`` (empty list if it is unset)."""
        return [column[0] for column in cursor.description or []]

    async def select(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch data from the database.

        Args:
            sql: SQL statement to execute.
            parameters: Optional bind parameters for ``sql``.
            connection: Optional connection; passed to ``self._connection`` to
                choose the connection actually used.
            schema_type: Optional type; each row is converted via ``schema_type(**row)``.

        Returns:
            List of row data as either model instances or dictionaries. The
            list is empty when the query returns no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            results = await cursor.fetchall()
            if not results:
                return []
            column_names = self._column_names(cursor)
            if schema_type is None:
                return [dict(zip(column_names, row)) for row in results]
            return [schema_type(**dict(zip(column_names, row))) for row in results]

    async def select_one(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Fetch exactly one row from the database.

        The fetched row is passed through ``self.check_not_found``. That
        method is expected to raise when no row was found (NOTE(review):
        behaviour defined in the base class; confirm).

        Returns:
            The first row of the query results, as a dict or a ``schema_type`` instance.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            result = await cursor.fetchone()
            result = self.check_not_found(result)
            column_names = self._column_names(cursor)
            if schema_type is None:
                return dict(zip(column_names, result))
            return cast("ModelDTOT", schema_type(**dict(zip(column_names, result))))

    async def select_one_or_none(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch one row from the database, or ``None`` if the query returns no rows.

        Returns:
            The first row of the query results, or ``None``.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            result = await cursor.fetchone()
            if result is None:
                return None
            column_names = self._column_names(cursor)
            if schema_type is None:
                return dict(zip(column_names, result))
            return cast("ModelDTOT", schema_type(**dict(zip(column_names, result))))

    async def select_value(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[Connection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Union[T, Any]":
        """Fetch a single value from the database.

        The fetched row is passed through ``self.check_not_found``, so a query
        that returns no rows is handled there rather than yielding ``None``.
        NOTE(review): presumably it raises; confirm in the base class.

        Returns:
            The first column of the first row, converted with ``schema_type``
            when one is given.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            result = await cursor.fetchone()
            result = self.check_not_found(result)

            value = result[0]
            if schema_type is not None:
                return schema_type(value)  # type: ignore[call-arg]
            return value

    async def select_value_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[Connection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Optional[Union[T, Any]]":
        """Fetch a single value from the database.

        Returns:
            The first value from the first row of results (converted with
            ``schema_type`` when given), or ``None`` if there are no results.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            result = await cursor.fetchone()

            if result is None:
                return None

            value = result[0]
            if schema_type is not None:
                return schema_type(value)  # type: ignore[call-arg]
            return value

    async def insert_update_delete(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
    ) -> int:
        """Insert, update, or delete data from the database.

        NOTE(review): no commit is issued here. Confirm that the connection
        runs in autocommit mode or that callers manage the transaction.

        Returns:
            Row count affected by the operation (``cursor.rowcount``).
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            return cursor.rowcount

    async def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Insert, update, or delete data from the database and return result.

        NOTE(review): this only produces a row when the statement returns one
        (e.g. ``RETURNING`` on servers that support it). Confirm this for the
        target server.

        Returns:
            The first row of results, or ``None`` if no row was returned.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            result = await cursor.fetchone()
            if result is None:
                return None
            column_names = self._column_names(cursor)
            if schema_type is not None:
                return cast("ModelDTOT", schema_type(**dict(zip(column_names, result))))
            return dict(zip(column_names, result))

    async def execute_script(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
    ) -> str:
        """Execute a script through a single ``cursor.execute`` call.

        Returns:
            The fixed status message ``"DONE"``.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            return "DONE"

    async def execute_script_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a script and return its first result row.

        Returns:
            The first row of results, or ``None`` if no row was returned.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)
            result = await cursor.fetchone()
            if result is None:
                return None
            column_names = self._column_names(cursor)
            if schema_type is not None:
                return cast("ModelDTOT", schema_type(**dict(zip(column_names, result))))
            return dict(zip(column_names, result))
|
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
from contextlib import asynccontextmanager
|
|
2
|
-
from dataclasses import dataclass
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
3
|
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
|
4
4
|
|
|
5
5
|
from asyncpg import Record
|
|
6
6
|
from asyncpg import create_pool as asyncpg_create_pool
|
|
7
|
-
from asyncpg.pool import
|
|
7
|
+
from asyncpg.pool import PoolConnectionProxy
|
|
8
8
|
from typing_extensions import TypeAlias
|
|
9
9
|
|
|
10
10
|
from sqlspec._serialization import decode_json, encode_json
|
|
11
|
+
from sqlspec.adapters.asyncpg.driver import AsyncpgDriver
|
|
11
12
|
from sqlspec.base import AsyncDatabaseConfig, GenericPoolConfig
|
|
12
13
|
from sqlspec.exceptions import ImproperConfigurationError
|
|
13
14
|
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
|
|
@@ -17,21 +18,22 @@ if TYPE_CHECKING:
|
|
|
17
18
|
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
|
|
18
19
|
|
|
19
20
|
from asyncpg.connection import Connection
|
|
21
|
+
from asyncpg.pool import Pool
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
__all__ = (
|
|
23
|
-
"
|
|
24
|
-
"
|
|
25
|
+
"Asyncpg",
|
|
26
|
+
"AsyncpgPool",
|
|
25
27
|
)
|
|
26
28
|
|
|
27
29
|
|
|
28
30
|
T = TypeVar("T")
|
|
29
31
|
|
|
30
|
-
PgConnection: TypeAlias = "Union[Connection, PoolConnectionProxy]"
|
|
32
|
+
PgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]"
|
|
31
33
|
|
|
32
34
|
|
|
33
35
|
@dataclass
|
|
34
|
-
class
|
|
36
|
+
class AsyncpgPool(GenericPoolConfig):
|
|
35
37
|
"""Configuration for Asyncpg's :class:`Pool <asyncpg.pool.Pool>`.
|
|
36
38
|
|
|
37
39
|
For details see: https://magicstack.github.io/asyncpg/current/api/index.html#connection-pools
|
|
@@ -52,7 +54,7 @@ class AsyncPgPoolConfig(GenericPoolConfig):
|
|
|
52
54
|
min_size: "Union[int, EmptyType]" = Empty
|
|
53
55
|
"""The number of connections to keep open inside the connection pool."""
|
|
54
56
|
max_size: "Union[int, EmptyType]" = Empty
|
|
55
|
-
"""The number of connections to allow in connection pool
|
|
57
|
+
"""The number of connections to allow in connection pool "overflow", that is connections that can be opened above
|
|
56
58
|
and beyond the pool_size setting, which defaults to 10."""
|
|
57
59
|
|
|
58
60
|
max_queries: "Union[int, EmptyType]" = Empty
|
|
@@ -71,10 +73,10 @@ class AsyncPgPoolConfig(GenericPoolConfig):
|
|
|
71
73
|
|
|
72
74
|
|
|
73
75
|
@dataclass
|
|
74
|
-
class
|
|
76
|
+
class Asyncpg(AsyncDatabaseConfig["PgConnection", "Pool", "AsyncpgDriver"]): # pyright: ignore[reportMissingTypeArgument]
|
|
75
77
|
"""Asyncpg Configuration."""
|
|
76
78
|
|
|
77
|
-
pool_config: "Optional[
|
|
79
|
+
pool_config: "Optional[AsyncpgPool]" = None
|
|
78
80
|
"""Asyncpg Pool configuration"""
|
|
79
81
|
json_deserializer: "Callable[[str], Any]" = decode_json
|
|
80
82
|
"""For dialects that support the :class:`JSON <sqlalchemy.types.JSON>` datatype, this is a Python callable that will
|
|
@@ -83,11 +85,41 @@ class AsyncPgConfig(AsyncDatabaseConfig[PgConnection, Pool]): # pyright: ignore
|
|
|
83
85
|
json_serializer: "Callable[[Any], str]" = encode_json
|
|
84
86
|
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
|
|
85
87
|
By default, SQLSpec's :attr:`encode_json() <sqlspec._serialization.encode_json>` is used."""
|
|
88
|
+
connection_type: "type[PgConnection]" = field(init=False, default_factory=lambda: PoolConnectionProxy)
|
|
89
|
+
"""Type of the connection object"""
|
|
90
|
+
driver_type: "type[AsyncpgDriver]" = field(init=False, default_factory=lambda: AsyncpgDriver) # type: ignore[type-abstract,unused-ignore]
|
|
91
|
+
"""Type of the driver object"""
|
|
86
92
|
pool_instance: "Optional[Pool[Any]]" = None
|
|
87
|
-
"""
|
|
93
|
+
"""The connection pool instance. If set, this will be used instead of creating a new pool."""
|
|
88
94
|
|
|
89
|
-
|
|
90
|
-
""
|
|
95
|
+
@property
def connection_config_dict(self) -> "dict[str, Any]":
    """Return the connection configuration as a dict.

    Builds keyword arguments for ``asyncpg.connect`` from ``pool_config``:
    - its ``dsn``, when that attribute exists and is not the ``Empty`` sentinel;
    - the entries of its ``connect_kwargs``, when that is a dict.

    Returns:
        A string keyed dict of config kwargs for the asyncpg.connect function.

    Raises:
        ImproperConfigurationError: If the connection configuration is not provided.
    """
    if self.pool_config:
        connect_dict: dict[str, Any] = {}

        # ``Empty`` marks "not configured". A ``hasattr`` check is always true
        # for a dataclass field, so test the value itself to avoid forwarding
        # the sentinel to asyncpg.connect.
        dsn = getattr(self.pool_config, "dsn", Empty)
        if dsn is not Empty:
            connect_dict["dsn"] = dsn

        connect_kwargs = getattr(self.pool_config, "connect_kwargs", Empty)
        if connect_kwargs is not Empty and isinstance(connect_kwargs, dict):
            connect_dict.update(connect_kwargs)

        return connect_dict
    msg = "You must provide a 'pool_config' for this adapter."
    raise ImproperConfigurationError(msg)
|
|
91
123
|
|
|
92
124
|
@property
|
|
93
125
|
def pool_config_dict(self) -> "dict[str, Any]":
|
|
@@ -96,9 +128,17 @@ class AsyncPgConfig(AsyncDatabaseConfig[PgConnection, Pool]): # pyright: ignore
|
|
|
96
128
|
Returns:
|
|
97
129
|
A string keyed dict of config kwargs for the Asyncpg :func:`create_pool <asyncpg.pool.create_pool>`
|
|
98
130
|
function.
|
|
131
|
+
|
|
132
|
+
Raises:
|
|
133
|
+
ImproperConfigurationError: If no pool_config is provided but a pool_instance is set.
|
|
99
134
|
"""
|
|
100
135
|
if self.pool_config:
|
|
101
|
-
return dataclass_to_dict(
|
|
136
|
+
return dataclass_to_dict(
|
|
137
|
+
self.pool_config,
|
|
138
|
+
exclude_empty=True,
|
|
139
|
+
exclude={"pool_instance", "driver_type", "connection_type"},
|
|
140
|
+
convert_nested=False,
|
|
141
|
+
)
|
|
102
142
|
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
|
|
103
143
|
raise ImproperConfigurationError(msg)
|
|
104
144
|
|
|
@@ -107,6 +147,10 @@ class AsyncPgConfig(AsyncDatabaseConfig[PgConnection, Pool]): # pyright: ignore
|
|
|
107
147
|
|
|
108
148
|
Returns:
|
|
109
149
|
Getter that returns the pool instance used by the plugin.
|
|
150
|
+
|
|
151
|
+
Raises:
|
|
152
|
+
ImproperConfigurationError: If neither pool_config nor pool_instance are provided,
|
|
153
|
+
or if the pool could not be configured.
|
|
110
154
|
"""
|
|
111
155
|
if self.pool_instance is not None:
|
|
112
156
|
return self.pool_instance
|
|
@@ -117,11 +161,9 @@ class AsyncPgConfig(AsyncDatabaseConfig[PgConnection, Pool]): # pyright: ignore
|
|
|
117
161
|
|
|
118
162
|
pool_config = self.pool_config_dict
|
|
119
163
|
self.pool_instance = await asyncpg_create_pool(**pool_config)
|
|
120
|
-
if self.pool_instance is None:
|
|
121
|
-
msg = "Could not configure the 'pool_instance'. Please check your configuration."
|
|
122
|
-
raise ImproperConfigurationError(
|
|
123
|
-
msg,
|
|
124
|
-
)
|
|
164
|
+
if self.pool_instance is None: # pyright: ignore[reportUnnecessaryComparison]
|
|
165
|
+
msg = "Could not configure the 'pool_instance'. Please check your configuration." # type: ignore[unreachable]
|
|
166
|
+
raise ImproperConfigurationError(msg)
|
|
125
167
|
return self.pool_instance
|
|
126
168
|
|
|
127
169
|
def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[Pool]": # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType]
|
|
@@ -132,13 +174,50 @@ class AsyncPgConfig(AsyncDatabaseConfig[PgConnection, Pool]): # pyright: ignore
|
|
|
132
174
|
"""
|
|
133
175
|
return self.create_pool() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
|
134
176
|
|
|
177
|
+
async def create_connection(self) -> "PgConnection":
    """Open a standalone asyncpg connection, independent of any pool.

    The kwargs for ``asyncpg.connect`` come from ``connection_config_dict``.

    Returns:
        A Connection instance.

    Raises:
        ImproperConfigurationError: If the connection could not be created.
            Any exception raised while importing asyncpg, building the config,
            or connecting is wrapped and chained.
    """
    try:
        import asyncpg

        connect_kwargs = self.connection_config_dict
        connection = await asyncpg.connect(**connect_kwargs)
    except Exception as exc:
        error_message = f"Could not configure the asyncpg connection. Error: {exc!s}"
        raise ImproperConfigurationError(error_message) from exc
    return connection  # type: ignore[no-any-return]
|
|
193
|
+
|
|
135
194
|
@asynccontextmanager
|
|
136
|
-
async def provide_connection(
|
|
195
|
+
async def provide_connection(
|
|
196
|
+
self, *args: "Any", **kwargs: "Any"
|
|
197
|
+
) -> "AsyncGenerator[PoolConnectionProxy[Any], None]": # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType]
|
|
137
198
|
"""Create a connection instance.
|
|
138
199
|
|
|
139
|
-
|
|
200
|
+
Yields:
|
|
140
201
|
A connection instance.
|
|
141
202
|
"""
|
|
142
203
|
db_pool = await self.provide_pool(*args, **kwargs) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
|
143
204
|
async with db_pool.acquire() as connection: # pyright: ignore[reportUnknownVariableType]
|
|
144
205
|
yield connection
|
|
206
|
+
|
|
207
|
+
async def close_pool(self) -> None:
    """Close the current pool, if any, and drop the reference to it.

    Does nothing when ``pool_instance`` is ``None``.
    """
    pool = self.pool_instance
    if pool is None:
        return
    await pool.close()
    self.pool_instance = None
|
|
212
|
+
|
|
213
|
+
@asynccontextmanager
async def provide_session(self, *args: Any, **kwargs: Any) -> "AsyncGenerator[AsyncpgDriver, None]":
    """Create and provide a database session.

    A connection is acquired via ``provide_connection``, with ``*args`` and
    ``**kwargs`` forwarded to it. That connection is wrapped in
    ``self.driver_type`` and is released when this context exits.

    Yields:
        An Asyncpg driver instance.
    """
    async with self.provide_connection(*args, **kwargs) as connection:
        yield self.driver_type(connection)
|