sqlspec 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +15 -0
- sqlspec/_serialization.py +16 -2
- sqlspec/_typing.py +40 -7
- sqlspec/adapters/adbc/__init__.py +7 -0
- sqlspec/adapters/adbc/config.py +183 -17
- sqlspec/adapters/adbc/driver.py +392 -0
- sqlspec/adapters/aiosqlite/__init__.py +5 -1
- sqlspec/adapters/aiosqlite/config.py +24 -6
- sqlspec/adapters/aiosqlite/driver.py +264 -0
- sqlspec/adapters/asyncmy/__init__.py +7 -2
- sqlspec/adapters/asyncmy/config.py +71 -11
- sqlspec/adapters/asyncmy/driver.py +246 -0
- sqlspec/adapters/asyncpg/__init__.py +9 -0
- sqlspec/adapters/asyncpg/config.py +102 -25
- sqlspec/adapters/asyncpg/driver.py +444 -0
- sqlspec/adapters/duckdb/__init__.py +5 -1
- sqlspec/adapters/duckdb/config.py +194 -12
- sqlspec/adapters/duckdb/driver.py +225 -0
- sqlspec/adapters/oracledb/__init__.py +7 -4
- sqlspec/adapters/oracledb/config/__init__.py +4 -4
- sqlspec/adapters/oracledb/config/_asyncio.py +96 -12
- sqlspec/adapters/oracledb/config/_common.py +1 -1
- sqlspec/adapters/oracledb/config/_sync.py +96 -12
- sqlspec/adapters/oracledb/driver.py +571 -0
- sqlspec/adapters/psqlpy/__init__.py +0 -0
- sqlspec/adapters/psqlpy/config.py +258 -0
- sqlspec/adapters/psqlpy/driver.py +335 -0
- sqlspec/adapters/psycopg/__init__.py +16 -0
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +107 -15
- sqlspec/adapters/psycopg/config/_common.py +2 -2
- sqlspec/adapters/psycopg/config/_sync.py +107 -15
- sqlspec/adapters/psycopg/driver.py +578 -0
- sqlspec/adapters/sqlite/__init__.py +7 -0
- sqlspec/adapters/sqlite/config.py +24 -6
- sqlspec/adapters/sqlite/driver.py +305 -0
- sqlspec/base.py +565 -63
- sqlspec/exceptions.py +30 -0
- sqlspec/extensions/litestar/__init__.py +19 -0
- sqlspec/extensions/litestar/_utils.py +56 -0
- sqlspec/extensions/litestar/config.py +87 -0
- sqlspec/extensions/litestar/handlers.py +213 -0
- sqlspec/extensions/litestar/plugin.py +105 -11
- sqlspec/statement.py +373 -0
- sqlspec/typing.py +81 -17
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/fixtures.py +4 -5
- sqlspec/utils/sync_tools.py +335 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/METADATA +4 -1
- sqlspec-0.9.0.dist-info/RECORD +61 -0
- sqlspec-0.7.1.dist-info/RECORD +0 -46
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from collections.abc import Generator
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
|
|
7
|
+
|
|
8
|
+
from adbc_driver_manager.dbapi import Connection, Cursor
|
|
9
|
+
|
|
10
|
+
from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T
|
|
11
|
+
from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError
|
|
12
|
+
from sqlspec.statement import SQLStatement
|
|
13
|
+
from sqlspec.typing import ArrowTable, StatementParameterType
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType
|
|
17
|
+
|
|
18
|
+
# Public API exported by this module.
__all__ = ("AdbcDriver",)

# Uses the package-wide "sqlspec" logger name rather than a per-module logger.
logger = logging.getLogger("sqlspec")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Tokenizer used to locate named placeholders in SQL text. Double-quoted text,
# single-quoted strings and comments are matched only so that callers can skip
# over them; a placeholder is reported through the ``lead`` (":" or "$") and
# ``var_name`` groups. The negative lookbehind applies to the placeholder branch
# only, so quoted literals directly preceded by a word character (e.g. N'...')
# are still recognised and skipped.
PARAM_REGEX = re.compile(
    r"""
    (?:
        (?P<dquote>"(?:[^"]|"")*") |                # Double-quoted strings
        (?P<squote>'(?:[^']|'')*') |                # Single-quoted strings
        (?P<comment>--[^\n]*|\/\*.*?\*\/) |         # SQL comments; a line comment may end at end of input
        (?<![:\w\$])                                # Avoid matching ::, word-prefixed and other vendor prefixes
        (?P<lead>[:\$])(?P<var_name>[a-zA-Z_][a-zA-Z0-9_]*)  # :name or $name identifier
    )
    """,
    re.VERBOSE | re.DOTALL,
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AdbcDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]):
    """ADBC Sync Driver Adapter.

    Executes SQL over an ADBC DB-API connection, returning rows as dicts (or
    ``schema_type`` instances) and, via ``select_arrow``, as Arrow tables.
    """

    connection: Connection
    __supports_arrow__: ClassVar[bool] = True

    def __init__(self, connection: "Connection") -> None:
        """Initialize the ADBC driver adapter.

        Args:
            connection: The ADBC connection to wrap. Its reported vendor name
                determines ``self.dialect``.
        """
        self.connection = connection
        self.dialect = self._get_dialect(connection)

    @staticmethod
    def _get_dialect(connection: "Connection") -> str:
        """Infer the SQL dialect from the connection's reported vendor name.

        The ``vendor_name`` entry of ``adbc_get_info()`` is lower-cased and
        matched by substring against the known dialects, in order. A missing
        or unrecognised vendor name falls back to ``"postgres"``.

        Args:
            connection: The ADBC connection object.

        Returns:
            The database dialect.
        """
        vendor_name = str(connection.adbc_get_info().get("vendor_name") or "").lower()
        for dialect in ("postgres", "bigquery", "sqlite", "duckdb", "mysql", "snowflake"):
            if dialect in vendor_name:
                return dialect
        return "postgres"  # default to postgresql dialect

    @staticmethod
    def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor":
        """Open a new cursor on ``connection``, forwarding any extra arguments."""
        return connection.cursor(*args, **kwargs)

    @contextmanager
    def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]:
        """Yield a fresh cursor and always attempt to close it afterwards.

        Errors raised while closing are suppressed so that cleanup cannot mask
        an exception raised inside the caller's ``with`` block.
        """
        cursor = self._cursor(connection)
        try:
            yield cursor
        finally:
            with contextlib.suppress(Exception):
                cursor.close()  # type: ignore[no-untyped-call]

    @staticmethod
    def _column_names(cursor: "Cursor") -> "list[str]":
        """Return the result column names from ``cursor.description`` (empty when absent)."""
        return [column[0] for column in cursor.description or []]  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]

    def _process_sql_params(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        **kwargs: Any,
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Prepare SQL text and parameters for ``cursor.execute``.

        For the sqlite/duckdb dialects, named parameters (a dict and/or keyword
        arguments) are bound by rewriting ``:name`` / ``$name`` placeholders to
        positional ``?`` markers here. Every other case is delegated to
        ``SQLStatement``.

        Args:
            sql: The SQL text.
            parameters: Positional (sequence) or named (dict) parameters.
            **kwargs: Additional named parameters, merged over a dict ``parameters``.

        Returns:
            The SQL to execute and the parameters to pass alongside it.

        Raises:
            ParameterStyleMismatchError: On sqlite/duckdb, if positional
                parameters are combined with keyword parameters, or if named
                parameters are supplied but the SQL has no named placeholders.
            SQLParsingError: On sqlite/duckdb, if the SQL references a named
                parameter that was not supplied.
        """
        uses_named_params = bool(kwargs) or isinstance(parameters, dict)
        if self.dialect in {"sqlite", "duckdb"} and uses_named_params:
            if parameters is not None and not isinstance(parameters, dict):
                # Previously the positional values were silently dropped here.
                msg = (
                    f"ADBC/{self.dialect}: positional parameters cannot be combined with keyword parameters. SQL: {sql}"
                )
                raise ParameterStyleMismatchError(msg)
            logger.debug(
                "ADBC/%s with dict params; bypassing SQLStatement conversion, manually converting to '?' positional.",
                self.dialect,
            )
            # Keyword arguments take precedence over same-named dict entries.
            parameter_dict: dict[str, Any] = {**(parameters or {}), **kwargs}  # pyright: ignore[reportUnknownVariableType]
            return self._convert_named_to_qmark(sql, parameter_dict)

        # For all other cases (other dialects, or non-dict params for sqlite/duckdb),
        # use the standard SQLStatement processing.
        stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None)
        return stmt.process()

    def _convert_named_to_qmark(self, sql: str, parameter_dict: "dict[str, Any]") -> "tuple[str, tuple[Any, ...]]":
        """Rewrite ``:name`` / ``$name`` placeholders in ``sql`` to ``?`` and order the values to match.

        Quoted text and comments (as matched by ``PARAM_REGEX``) are left untouched.
        Extra entries in ``parameter_dict`` that the SQL never references are
        allowed and only logged at debug level.

        Args:
            sql: The SQL text containing named placeholders.
            parameter_dict: Values for the named placeholders.

        Returns:
            The rewritten SQL and a tuple of values in placeholder order.

        Raises:
            SQLParsingError: If a placeholder has no value in ``parameter_dict``.
            ParameterStyleMismatchError: If values were supplied but no named
                placeholders were found.
        """
        sql_parts: list[str] = []
        ordered_params: list[Any] = []
        found_names: list[str] = []
        last_end = 0

        for match in PARAM_REGEX.finditer(sql):
            var_name = match.group("var_name")
            if var_name is None:
                # Quoted string or comment: nothing to bind.
                continue
            leading_char = match.group("lead")  # : or $
            if var_name not in parameter_dict:
                msg = f"Named parameter '{leading_char}{var_name}' found in SQL but not provided. SQL: {sql}"
                raise SQLParsingError(msg)
            found_names.append(var_name)
            sql_parts.extend((sql[last_end : match.start()], "?"))  # Force ? style
            ordered_params.append(parameter_dict[var_name])
            last_end = match.end()
        sql_parts.append(sql[last_end:])

        if not found_names and parameter_dict:
            msg = f"ADBC/{self.dialect}: Dict params provided, but no :name or $name placeholders found. SQL: {sql}"
            raise ParameterStyleMismatchError(msg)

        extra_keys = parameter_dict.keys() - set(found_names)
        if extra_keys:
            logger.debug("Extra parameters provided for ADBC/%s: %s", self.dialect, extra_keys)

        return "".join(sql_parts), tuple(ordered_params)

    def select(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        *,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
        **kwargs: Any,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch all rows for a query.

        Args:
            sql: The SQL query.
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            schema_type: Optional type each row is built into from keyword arguments.
            **kwargs: Additional named parameters.

        Returns:
            List of row data as either model instances or dictionaries.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            results = cursor.fetchall()  # pyright: ignore
            if not results:
                return []
            column_names = self._column_names(cursor)
            if schema_type is not None:
                return [cast("ModelDTOT", schema_type(**dict(zip(column_names, row)))) for row in results]  # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType]
            return [dict(zip(column_names, row)) for row in results]  # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType]

    def select_one(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        *,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
        **kwargs: Any,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Fetch exactly one row; a missing row is handled by ``self.check_not_found``.

        Args:
            sql: The SQL query.
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            schema_type: Optional type the row is built into from keyword arguments.
            **kwargs: Additional named parameters.

        Returns:
            The first row of the query results.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchone()  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            result = self.check_not_found(result)  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
            row = dict(zip(self._column_names(cursor), result))  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
            if schema_type is None:
                return row
            return schema_type(**row)  # type: ignore[return-value]

    def select_one_or_none(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        *,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
        **kwargs: Any,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch the first row, or ``None`` when the query returns no rows.

        Args:
            sql: The SQL query.
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            schema_type: Optional type the row is built into from keyword arguments.
            **kwargs: Additional named parameters.

        Returns:
            The first row of the query results, or ``None``.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchone()  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            if result is None:
                return None
            row = dict(zip(self._column_names(cursor), result))  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
            if schema_type is None:
                return row
            return schema_type(**row)  # type: ignore[return-value]

    def select_value(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        *,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[T]]" = None,
        **kwargs: Any,
    ) -> "Union[T, Any]":
        """Fetch the first column of the first row; a missing row is handled by ``self.check_not_found``.

        Args:
            sql: The SQL query.
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            schema_type: Optional callable applied to the value.
            **kwargs: Additional named parameters.

        Returns:
            The first value from the first row of results.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchone()  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            result = self.check_not_found(result)  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
            if schema_type is None:
                return result[0]  # pyright: ignore[reportUnknownVariableType]
            return schema_type(result[0])  # type: ignore[call-arg]

    def select_value_or_none(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        *,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[T]]" = None,
        **kwargs: Any,
    ) -> "Optional[Union[T, Any]]":
        """Fetch the first column of the first row, or ``None`` when there are no rows.

        Args:
            sql: The SQL query.
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            schema_type: Optional callable applied to the value.
            **kwargs: Additional named parameters.

        Returns:
            The first value from the first row of results, or None if no results.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchone()  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            if result is None:
                return None
            if schema_type is None:
                return result[0]  # pyright: ignore[reportUnknownVariableType]
            return schema_type(result[0])  # type: ignore[call-arg]

    def insert_update_delete(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        *,
        connection: Optional["Connection"] = None,
        **kwargs: Any,
    ) -> int:
        """Execute an INSERT, UPDATE or DELETE statement.

        Args:
            sql: The SQL statement.
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            **kwargs: Additional named parameters.

        Returns:
            Row count affected by the operation, or -1 if the cursor exposes no ``rowcount``.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            return cursor.rowcount if hasattr(cursor, "rowcount") else -1

    def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        *,
        connection: Optional["Connection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
        **kwargs: Any,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a data-modifying statement and return its first result row.

        Args:
            sql: The SQL statement (typically with a RETURNING clause).
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            schema_type: Optional type the row is built into from keyword arguments.
            **kwargs: Additional named parameters.

        Returns:
            The first row of results, or ``None`` if the statement returned no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchall()  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            if not result:
                return None
            result_dict = dict(zip(self._column_names(cursor), result[0]))  # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType]
            if schema_type is None:
                return result_dict
            return cast("ModelDTOT", schema_type(**result_dict))

    def execute_script(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        *,
        connection: Optional["Connection"] = None,
        **kwargs: Any,
    ) -> str:
        """Execute a script.

        Args:
            sql: The SQL script.
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            **kwargs: Additional named parameters (previously ignored; now bound
                like in every other method).

        Returns:
            The cursor's ``statusmessage`` if it has one, otherwise ``"DONE"``.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE"  # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]

    # --- Arrow Bulk Operations ---

    def select_arrow(  # pyright: ignore[reportUnknownParameterType]
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[Connection]" = None,
        **kwargs: Any,
    ) -> "ArrowTable":
        """Execute a SQL query and return results as an Apache Arrow Table.

        Args:
            sql: The SQL query.
            parameters: Positional (sequence) or named (dict) parameters.
            connection: Optional connection override (resolved via ``self._connection``).
            **kwargs: Additional named parameters.

        Returns:
            The results of the query as an Apache Arrow Table.
        """
        conn = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
        with self._with_cursor(conn) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            return cast("ArrowTable", cursor.fetch_arrow_table())  # pyright: ignore[reportUnknownMemberType]
|
|
@@ -2,7 +2,10 @@ from contextlib import asynccontextmanager
|
|
|
2
2
|
from dataclasses import dataclass, field
|
|
3
3
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from aiosqlite import Connection
|
|
6
|
+
|
|
7
|
+
from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver
|
|
8
|
+
from sqlspec.base import NoPoolAsyncConfig
|
|
6
9
|
from sqlspec.exceptions import ImproperConfigurationError
|
|
7
10
|
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
|
|
8
11
|
|
|
@@ -11,13 +14,12 @@ if TYPE_CHECKING:
|
|
|
11
14
|
from sqlite3 import Connection as SQLite3Connection
|
|
12
15
|
from typing import Literal
|
|
13
16
|
|
|
14
|
-
from aiosqlite import Connection
|
|
15
17
|
|
|
16
18
|
__all__ = ("AiosqliteConfig",)
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
@dataclass
|
|
20
|
-
class AiosqliteConfig(
|
|
22
|
+
class AiosqliteConfig(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]):
|
|
21
23
|
"""Configuration for Aiosqlite database connections.
|
|
22
24
|
|
|
23
25
|
This class provides configuration options for Aiosqlite database connections, wrapping all parameters
|
|
@@ -42,6 +44,10 @@ class AiosqliteConfig(NoPoolSyncConfig["Connection"]):
|
|
|
42
44
|
"""The number of statements that SQLite will cache for this connection. The default is 128."""
|
|
43
45
|
uri: "Union[bool, EmptyType]" = field(default=Empty)
|
|
44
46
|
"""If set to True, database is interpreted as a URI with supported options."""
|
|
47
|
+
connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection)
|
|
48
|
+
"""Type of the connection object"""
|
|
49
|
+
driver_type: "type[AiosqliteDriver]" = field(init=False, default_factory=lambda: AiosqliteDriver) # type: ignore[type-abstract,unused-ignore]
|
|
50
|
+
"""Type of the driver object"""
|
|
45
51
|
|
|
46
52
|
@property
|
|
47
53
|
def connection_config_dict(self) -> "dict[str, Any]":
|
|
@@ -50,7 +56,9 @@ class AiosqliteConfig(NoPoolSyncConfig["Connection"]):
|
|
|
50
56
|
Returns:
|
|
51
57
|
A string keyed dict of config kwargs for the aiosqlite.connect() function.
|
|
52
58
|
"""
|
|
53
|
-
return dataclass_to_dict(
|
|
59
|
+
return dataclass_to_dict(
|
|
60
|
+
self, exclude_empty=True, convert_nested=False, exclude={"pool_instance", "connection_type", "driver_type"}
|
|
61
|
+
)
|
|
54
62
|
|
|
55
63
|
async def create_connection(self) -> "Connection":
|
|
56
64
|
"""Create and return a new database connection.
|
|
@@ -76,11 +84,21 @@ class AiosqliteConfig(NoPoolSyncConfig["Connection"]):
|
|
|
76
84
|
Yields:
|
|
77
85
|
An Aiosqlite connection instance.
|
|
78
86
|
|
|
79
|
-
Raises:
|
|
80
|
-
ImproperConfigurationError: If the connection could not be established.
|
|
81
87
|
"""
|
|
82
88
|
connection = await self.create_connection()
|
|
83
89
|
try:
|
|
84
90
|
yield connection
|
|
85
91
|
finally:
|
|
86
92
|
await connection.close()
|
|
93
|
+
|
|
94
|
+
@asynccontextmanager
async def provide_session(self, *args: Any, **kwargs: Any) -> "AsyncGenerator[AiosqliteDriver, None]":
    """Open a connection and expose it wrapped in the configured driver.

    The connection is obtained from ``provide_connection`` (forwarding
    ``*args`` and ``**kwargs``) and passed to ``driver_type``. That connection
    context is exited when this context manager exits.

    Yields:
        An Aiosqlite driver instance.
    """
    async with self.provide_connection(*args, **kwargs) as conn:
        driver = self.driver_type(conn)
        yield driver
|