sqlspec 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec has been flagged as potentially problematic; see the registry's advisory page for details.
- sqlspec/__init__.py +15 -0
- sqlspec/_serialization.py +16 -2
- sqlspec/_typing.py +1 -1
- sqlspec/adapters/adbc/__init__.py +7 -0
- sqlspec/adapters/adbc/config.py +160 -17
- sqlspec/adapters/adbc/driver.py +333 -0
- sqlspec/adapters/aiosqlite/__init__.py +6 -2
- sqlspec/adapters/aiosqlite/config.py +25 -7
- sqlspec/adapters/aiosqlite/driver.py +275 -0
- sqlspec/adapters/asyncmy/__init__.py +7 -2
- sqlspec/adapters/asyncmy/config.py +75 -14
- sqlspec/adapters/asyncmy/driver.py +255 -0
- sqlspec/adapters/asyncpg/__init__.py +9 -0
- sqlspec/adapters/asyncpg/config.py +99 -20
- sqlspec/adapters/asyncpg/driver.py +288 -0
- sqlspec/adapters/duckdb/__init__.py +6 -2
- sqlspec/adapters/duckdb/config.py +197 -15
- sqlspec/adapters/duckdb/driver.py +225 -0
- sqlspec/adapters/oracledb/__init__.py +11 -8
- sqlspec/adapters/oracledb/config/__init__.py +6 -6
- sqlspec/adapters/oracledb/config/_asyncio.py +98 -13
- sqlspec/adapters/oracledb/config/_common.py +1 -1
- sqlspec/adapters/oracledb/config/_sync.py +99 -14
- sqlspec/adapters/oracledb/driver.py +498 -0
- sqlspec/adapters/psycopg/__init__.py +11 -0
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +105 -13
- sqlspec/adapters/psycopg/config/_common.py +2 -2
- sqlspec/adapters/psycopg/config/_sync.py +105 -13
- sqlspec/adapters/psycopg/driver.py +616 -0
- sqlspec/adapters/sqlite/__init__.py +7 -0
- sqlspec/adapters/sqlite/config.py +25 -7
- sqlspec/adapters/sqlite/driver.py +303 -0
- sqlspec/base.py +416 -36
- sqlspec/extensions/litestar/__init__.py +19 -0
- sqlspec/extensions/litestar/_utils.py +56 -0
- sqlspec/extensions/litestar/config.py +81 -0
- sqlspec/extensions/litestar/handlers.py +188 -0
- sqlspec/extensions/litestar/plugin.py +103 -11
- sqlspec/typing.py +72 -17
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/deprecation.py +1 -1
- sqlspec/utils/fixtures.py +4 -5
- sqlspec/utils/sync_tools.py +335 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/METADATA +1 -1
- sqlspec-0.8.0.dist-info/RECORD +57 -0
- sqlspec-0.7.0.dist-info/RECORD +0 -46
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -2,7 +2,10 @@ from contextlib import asynccontextmanager
|
|
|
2
2
|
from dataclasses import dataclass, field
|
|
3
3
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from aiosqlite import Connection
|
|
6
|
+
|
|
7
|
+
from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver
|
|
8
|
+
from sqlspec.base import NoPoolAsyncConfig
|
|
6
9
|
from sqlspec.exceptions import ImproperConfigurationError
|
|
7
10
|
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
|
|
8
11
|
|
|
@@ -11,13 +14,12 @@ if TYPE_CHECKING:
|
|
|
11
14
|
from sqlite3 import Connection as SQLite3Connection
|
|
12
15
|
from typing import Literal
|
|
13
16
|
|
|
14
|
-
from aiosqlite import Connection
|
|
15
17
|
|
|
16
|
-
__all__ = ("
|
|
18
|
+
__all__ = ("Aiosqlite",)
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
@dataclass
|
|
20
|
-
class
|
|
22
|
+
class Aiosqlite(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]):
|
|
21
23
|
"""Configuration for Aiosqlite database connections.
|
|
22
24
|
|
|
23
25
|
This class provides configuration options for Aiosqlite database connections, wrapping all parameters
|
|
@@ -42,6 +44,10 @@ class AiosqliteConfig(NoPoolSyncConfig["Connection"]):
|
|
|
42
44
|
"""The number of statements that SQLite will cache for this connection. The default is 128."""
|
|
43
45
|
uri: "Union[bool, EmptyType]" = field(default=Empty)
|
|
44
46
|
"""If set to True, database is interpreted as a URI with supported options."""
|
|
47
|
+
connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection)
|
|
48
|
+
"""Type of the connection object"""
|
|
49
|
+
driver_type: "type[AiosqliteDriver]" = field(init=False, default_factory=lambda: AiosqliteDriver) # type: ignore[type-abstract,unused-ignore]
|
|
50
|
+
"""Type of the driver object"""
|
|
45
51
|
|
|
46
52
|
@property
|
|
47
53
|
def connection_config_dict(self) -> "dict[str, Any]":
|
|
@@ -50,7 +56,9 @@ class AiosqliteConfig(NoPoolSyncConfig["Connection"]):
|
|
|
50
56
|
Returns:
|
|
51
57
|
A string keyed dict of config kwargs for the aiosqlite.connect() function.
|
|
52
58
|
"""
|
|
53
|
-
return dataclass_to_dict(
|
|
59
|
+
return dataclass_to_dict(
|
|
60
|
+
self, exclude_empty=True, convert_nested=False, exclude={"pool_instance", "connection_type", "driver_type"}
|
|
61
|
+
)
|
|
54
62
|
|
|
55
63
|
async def create_connection(self) -> "Connection":
|
|
56
64
|
"""Create and return a new database connection.
|
|
@@ -76,11 +84,21 @@ class AiosqliteConfig(NoPoolSyncConfig["Connection"]):
|
|
|
76
84
|
Yields:
|
|
77
85
|
An Aiosqlite connection instance.
|
|
78
86
|
|
|
79
|
-
Raises:
|
|
80
|
-
ImproperConfigurationError: If the connection could not be established.
|
|
81
87
|
"""
|
|
82
88
|
connection = await self.create_connection()
|
|
83
89
|
try:
|
|
84
90
|
yield connection
|
|
85
91
|
finally:
|
|
86
92
|
await connection.close()
|
|
93
|
+
|
|
94
|
+
@asynccontextmanager
async def provide_session(self, *args: Any, **kwargs: Any) -> "AsyncGenerator[AiosqliteDriver, None]":
    """Open a connection and expose it through the configured driver class.

    Args:
        *args: Forwarded unchanged to ``provide_connection``.
        **kwargs: Forwarded unchanged to ``provide_connection``.

    Yields:
        An instance of ``self.driver_type`` (``AiosqliteDriver`` by default) wrapping the
        connection that ``provide_connection`` yields. The connection's lifetime is
        governed by ``provide_connection``.
    """
    async with self.provide_connection(*args, **kwargs) as conn:
        session = self.driver_type(conn)
        yield session
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
from contextlib import asynccontextmanager
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
3
|
+
|
|
4
|
+
from sqlspec.base import AsyncDriverAdapterProtocol, T
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from collections.abc import AsyncGenerator
|
|
8
|
+
|
|
9
|
+
from aiosqlite import Connection, Cursor
|
|
10
|
+
|
|
11
|
+
from sqlspec.typing import ModelDTOT, StatementParameterType
|
|
12
|
+
|
|
13
|
+
__all__ = ("AiosqliteDriver",)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AiosqliteDriver(AsyncDriverAdapterProtocol["Connection"]):
    """SQLite async driver adapter built on ``aiosqlite``.

    Holds one ``aiosqlite.Connection``. Each query helper opens its own cursor through
    ``_with_cursor`` and closes it when the helper finishes. Dict parameters using
    ``:name`` placeholders are rewritten to positional ``?`` placeholders by
    ``_process_sql_params``.
    """

    connection: "Connection"

    def __init__(self, connection: "Connection") -> None:
        self.connection = connection

    @staticmethod
    async def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor":
        """Open a new cursor on ``connection``, forwarding any extra arguments."""
        return await connection.cursor(*args, **kwargs)

    @asynccontextmanager
    async def _with_cursor(self, connection: "Connection") -> "AsyncGenerator[Cursor, None]":
        """Yield a fresh cursor and close it afterwards, even if the body raises."""
        cursor = await self._cursor(connection)
        try:
            yield cursor
        finally:
            await cursor.close()

    def _process_sql_params(
        self, sql: str, parameters: "Optional[StatementParameterType]" = None
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Process SQL query and parameters for DB-API execution.

        Converts named parameters (``:name``) to positional parameters (``?``) for SQLite.
        Placeholders are rewritten in the order they appear in ``sql``, and the values are
        emitted in that same order. As a result:

        * the ordering of the dict does not matter;
        * a name may be used more than once;
        * a name that is a prefix of another (``:id`` / ``:id2``) is not mangled.

        Text inside single-quoted strings, double-quoted identifiers and SQL comments is
        left untouched, as are ``:name`` tokens that have no matching key.

        Args:
            sql: The SQL query string.
            parameters: The parameters for the query (dict, tuple, list, or None).

        Returns:
            A tuple containing the processed SQL string and the processed parameters.
            Non-dict or empty parameters are returned unchanged.
        """
        import re

        if not isinstance(parameters, dict) or not parameters:
            # Positional (tuple/list) or absent parameters are handed to the driver as-is.
            return sql, parameters

        named_parameters = parameters
        ordered_values: list[Any] = []
        # Every alternative except the last matches (and thereby skips) a string literal,
        # a quoted identifier or a comment. Only the final alternative captures a name.
        pattern = r"""'[^']*'|"[^"]*"|--[^\n]*|/\*[\s\S]*?\*/|(?<![:\w]):([A-Za-z_]\w*)"""

        def _replace(match: "re.Match[str]") -> str:
            name = match.group(1)
            if name is None or name not in named_parameters:
                return match.group(0)
            ordered_values.append(named_parameters[name])
            return "?"

        processed_sql = re.sub(pattern, _replace, sql)
        return processed_sql, tuple(ordered_values)

    async def select(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch all rows produced by ``sql``.

        Args:
            sql: The SQL query string.
            parameters: Query parameters (dict, tuple, list, or None).
            connection: Optional connection override, resolved through ``self._connection``.
            schema_type: Optional type; when given, each row is passed to it as keyword arguments.

        Returns:
            List of row data as either model instances or dictionaries (empty if no rows).
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            results = await cursor.fetchall()  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if not results:
                return []
            column_names = [c[0] for c in cursor.description or []]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if schema_type is None:
                return [dict(zip(column_names, row)) for row in results]  # pyright: ignore[reportUnknownArgumentType]
            return [cast("ModelDTOT", schema_type(**dict(zip(column_names, row)))) for row in results]  # pyright: ignore[reportUnknownArgumentType]

    async def select_one(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Fetch exactly one row from the database.

        Returns:
            The first row of the query results, as a dict or a ``schema_type`` instance.
            When there is no row, the result is handed to ``self.check_not_found``
            (defined on the base protocol) to deal with.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = await cursor.fetchone()  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            result = self.check_not_found(result)
            column_names = [c[0] for c in cursor.description or []]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if schema_type is None:
                return dict(zip(column_names, result))  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
            return cast("ModelDTOT", schema_type(**dict(zip(column_names, result))))  # pyright: ignore[reportUnknownArgumentType]

    async def select_one_or_none(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch one row from the database, if there is one.

        Returns:
            The first row of the query results, or None when the query returns no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = await cursor.fetchone()  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if result is None:
                return None
            column_names = [c[0] for c in cursor.description or []]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if schema_type is None:
                return dict(zip(column_names, result))  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
            return cast("ModelDTOT", schema_type(**dict(zip(column_names, result))))  # pyright: ignore[reportUnknownArgumentType]

    async def select_value(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[Connection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Union[T, Any]":
        """Fetch a single value from the database.

        Returns:
            The first value from the first row of results, optionally converted with
            ``schema_type``. When there is no row, the result is handed to
            ``self.check_not_found`` (defined on the base protocol).
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = await cursor.fetchone()  # pyright: ignore[reportUnknownMemberType]
            result = self.check_not_found(result)
            if schema_type is None:
                return result[0]
            return schema_type(result[0])  # type: ignore[call-arg]

    async def select_value_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[Connection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Optional[Union[T, Any]]":
        """Fetch a single value from the database, if there is one.

        Returns:
            The first value from the first row of results (optionally converted with
            ``schema_type``), or None if there are no results.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = await cursor.fetchone()  # pyright: ignore[reportUnknownMemberType]
            if result is None:
                return None
            if schema_type is None:
                return result[0]
            return schema_type(result[0])  # type: ignore[call-arg]

    async def insert_update_delete(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
    ) -> int:
        """Insert, update, or delete data from the database.

        Returns:
            Row count affected by the operation, or -1 if the cursor exposes no ``rowcount``.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            return cursor.rowcount if hasattr(cursor, "rowcount") else -1  # pyright: ignore[reportUnknownVariableType, reportGeneralTypeIssues]

    async def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Insert, update, or delete data from the database and return result.

        Returns:
            The first row produced by the statement (e.g. via ``RETURNING``), or None if
            no rows were produced.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            results = list(await cursor.fetchall())  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if not results:
                return None
            column_names = [c[0] for c in cursor.description or []]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if schema_type is not None:
                return cast("ModelDTOT", schema_type(**dict(zip(column_names, results[0]))))  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
            return dict(zip(column_names, results[0]))  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]

    async def execute_script(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
    ) -> str:
        """Execute a script.

        Without parameters, the script is run through ``executescript``, so it may contain
        several ``;``-separated statements. SQLite's ``execute`` accepts only a single
        statement. Per the sqlite3 documentation, ``executescript`` first commits any
        pending transaction. With parameters, the (single) statement is executed with
        them bound.

        Returns:
            Status message for the operation.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            if parameters:
                await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            else:
                await cursor.executescript(sql)  # pyright: ignore[reportUnknownMemberType]
            return "DONE"

    async def execute_script_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a statement and return its first result row.

        Returns:
            The first row of results, or None if no rows were produced.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        async with self._with_cursor(connection) as cursor:
            await cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            results = list(await cursor.fetchall())  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if not results:
                return None
            column_names = [c[0] for c in cursor.description or []]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if schema_type is not None:
                return cast("ModelDTOT", schema_type(**dict(zip(column_names, results[0]))))  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
            return dict(zip(column_names, results[0]))  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
|
|
@@ -1,3 +1,8 @@
|
|
|
1
|
-
from sqlspec.adapters.asyncmy.config import
|
|
1
|
+
from sqlspec.adapters.asyncmy.config import Asyncmy, AsyncmyPool
|
|
2
|
+
from sqlspec.adapters.asyncmy.driver import AsyncmyDriver # type: ignore[attr-defined]
|
|
2
3
|
|
|
3
|
-
__all__ = (
|
|
4
|
+
__all__ = (
|
|
5
|
+
"Asyncmy",
|
|
6
|
+
"AsyncmyDriver",
|
|
7
|
+
"AsyncmyPool",
|
|
8
|
+
)
|
|
@@ -1,23 +1,23 @@
|
|
|
1
1
|
from contextlib import asynccontextmanager
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
from typing import TYPE_CHECKING, Optional, TypeVar, Union
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
|
4
4
|
|
|
5
5
|
from asyncmy.connection import Connection # pyright: ignore[reportUnknownVariableType]
|
|
6
|
-
from asyncmy.pool import Pool # pyright: ignore[reportUnknownVariableType]
|
|
7
6
|
|
|
7
|
+
from sqlspec.adapters.asyncmy.driver import AsyncmyDriver # type: ignore[attr-defined]
|
|
8
8
|
from sqlspec.base import AsyncDatabaseConfig, GenericPoolConfig
|
|
9
9
|
from sqlspec.exceptions import ImproperConfigurationError
|
|
10
10
|
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from collections.abc import AsyncGenerator
|
|
14
|
-
from typing import Any
|
|
15
14
|
|
|
16
15
|
from asyncmy.cursors import Cursor, DictCursor # pyright: ignore[reportUnknownVariableType]
|
|
16
|
+
from asyncmy.pool import Pool # pyright: ignore[reportUnknownVariableType]
|
|
17
17
|
|
|
18
18
|
__all__ = (
|
|
19
|
-
"
|
|
20
|
-
"
|
|
19
|
+
"Asyncmy",
|
|
20
|
+
"AsyncmyPool",
|
|
21
21
|
)
|
|
22
22
|
|
|
23
23
|
|
|
@@ -25,7 +25,7 @@ T = TypeVar("T")
|
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
@dataclass
|
|
28
|
-
class
|
|
28
|
+
class AsyncmyPool(GenericPoolConfig):
|
|
29
29
|
"""Configuration for Asyncmy's connection pool.
|
|
30
30
|
|
|
31
31
|
This class provides configuration options for Asyncmy database connection pools.
|
|
@@ -104,20 +104,42 @@ class AsyncmyPoolConfig(GenericPoolConfig):
|
|
|
104
104
|
|
|
105
105
|
|
|
106
106
|
@dataclass
|
|
107
|
-
class
|
|
107
|
+
class Asyncmy(AsyncDatabaseConfig["Connection", "Pool", "AsyncmyDriver"]):
|
|
108
108
|
"""Asyncmy Configuration."""
|
|
109
109
|
|
|
110
110
|
__is_async__ = True
|
|
111
111
|
__supports_connection_pooling__ = True
|
|
112
112
|
|
|
113
|
-
pool_config: "Optional[
|
|
113
|
+
pool_config: "Optional[AsyncmyPool]" = None
|
|
114
114
|
"""Asyncmy Pool configuration"""
|
|
115
|
-
|
|
115
|
+
connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) # pyright: ignore
|
|
116
|
+
"""Type of the connection object"""
|
|
117
|
+
driver_type: "type[AsyncmyDriver]" = field(init=False, default_factory=lambda: AsyncmyDriver)
|
|
118
|
+
"""Type of the driver object"""
|
|
116
119
|
pool_instance: "Optional[Pool]" = None # pyright: ignore[reportUnknownVariableType]
|
|
117
|
-
"""
|
|
120
|
+
"""Instance of the pool"""
|
|
118
121
|
|
|
119
|
-
|
|
120
|
-
""
|
|
122
|
+
@property
def connection_config_dict(self) -> "dict[str, Any]":
    """Build the keyword arguments for a direct asyncmy connection.

    The values are derived from ``pool_config``. Pool-only settings and the config
    object's bookkeeping fields are left out.

    Returns:
        A string keyed dict of config kwargs for the Asyncmy connect function.

    Raises:
        ImproperConfigurationError: If the connection configuration is not provided.
    """
    if not self.pool_config:
        msg = "You must provide a 'pool_config' for this adapter."
        raise ImproperConfigurationError(msg)
    # Pool-specific parameters are filtered out together with the adapter fields.
    excluded_keys = {"minsize", "maxsize", "echo", "pool_recycle"} | {
        "pool_instance",
        "driver_type",
        "connection_type",
    }
    return dataclass_to_dict(
        self.pool_config,
        exclude_empty=True,
        convert_nested=False,
        exclude=excluded_keys,
    )
|
|
121
143
|
|
|
122
144
|
@property
|
|
123
145
|
def pool_config_dict(self) -> "dict[str, Any]":
|
|
@@ -130,10 +152,32 @@ class AsyncMyConfig(AsyncDatabaseConfig[Connection, Pool]):
|
|
|
130
152
|
ImproperConfigurationError: If the pool configuration is not provided.
|
|
131
153
|
"""
|
|
132
154
|
if self.pool_config:
|
|
133
|
-
return dataclass_to_dict(
|
|
155
|
+
return dataclass_to_dict(
|
|
156
|
+
self.pool_config,
|
|
157
|
+
exclude_empty=True,
|
|
158
|
+
convert_nested=False,
|
|
159
|
+
exclude={"pool_instance", "driver_type", "connection_type"},
|
|
160
|
+
)
|
|
134
161
|
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
|
|
135
162
|
raise ImproperConfigurationError(msg)
|
|
136
163
|
|
|
164
|
+
async def create_connection(self) -> "Connection":  # pyright: ignore[reportUnknownParameterType]
    """Open a brand-new asyncmy connection, independent of any pool.

    The connection kwargs come from ``connection_config_dict``.

    Returns:
        A Connection instance.

    Raises:
        ImproperConfigurationError: If the connection could not be created. Any exception
            raised while importing ``asyncmy``, building the kwargs or connecting is
            wrapped, with the original chained as the cause.
    """
    try:
        import asyncmy  # pyright: ignore[reportMissingTypeStubs]

        connect_kwargs = self.connection_config_dict
        connection = await asyncmy.connect(**connect_kwargs)  # pyright: ignore
    except Exception as exc:
        msg = f"Could not configure the Asyncmy connection. Error: {exc!s}"
        raise ImproperConfigurationError(msg) from exc
    return connection
|
|
180
|
+
|
|
137
181
|
async def create_pool(self) -> "Pool": # pyright: ignore[reportUnknownParameterType]
|
|
138
182
|
"""Return a pool. If none exists yet, create one.
|
|
139
183
|
|
|
@@ -179,3 +223,20 @@ class AsyncMyConfig(AsyncDatabaseConfig[Connection, Pool]):
|
|
|
179
223
|
pool = await self.provide_pool(*args, **kwargs) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
|
180
224
|
async with pool.acquire() as connection: # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
|
181
225
|
yield connection # pyright: ignore[reportUnknownMemberType]
|
|
226
|
+
|
|
227
|
+
@asynccontextmanager
async def provide_session(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[AsyncmyDriver, None]":
    """Borrow a connection and expose it through the configured driver class.

    Args:
        *args: Forwarded unchanged to ``provide_connection``.
        **kwargs: Forwarded unchanged to ``provide_connection``.

    Yields:
        An instance of ``self.driver_type`` (``AsyncmyDriver`` by default) wrapping the
        connection yielded by ``provide_connection``.
    """
    async with self.provide_connection(*args, **kwargs) as pooled_connection:  # pyright: ignore[reportUnknownVariableType]
        driver = self.driver_type(pooled_connection)  # pyright: ignore[reportUnknownArgumentType]
        yield driver
|
|
237
|
+
|
|
238
|
+
async def close_pool(self) -> None:
    """Close the connection pool, if one has been created, and forget it.

    asyncmy's pool follows the aiomysql design: ``close()`` is a plain method that marks
    the pool closed, and ``wait_closed()`` is the coroutine that waits for its
    connections to shut down. Awaiting ``close()`` directly would therefore await
    ``None`` and raise ``TypeError``. Both a synchronous and an awaitable ``close()``
    are accepted here.
    """
    import inspect

    pool = self.pool_instance  # pyright: ignore[reportUnknownMemberType]
    if pool is None:
        return
    try:
        close_result = pool.close()  # pyright: ignore[reportUnknownMemberType]
        if inspect.isawaitable(close_result):
            await close_result
        wait_closed = getattr(pool, "wait_closed", None)
        if wait_closed is not None:
            await wait_closed()
    finally:
        # Never hand out a (possibly half-) closed pool again; the next provide_pool
        # call is expected to build a fresh one.
        self.pool_instance = None
|