sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/adbc/driver.py
CHANGED
|
@@ -2,12 +2,14 @@ import contextlib
|
|
|
2
2
|
import logging
|
|
3
3
|
from collections.abc import Iterator
|
|
4
4
|
from contextlib import contextmanager
|
|
5
|
+
from dataclasses import replace
|
|
5
6
|
from decimal import Decimal
|
|
6
|
-
from typing import TYPE_CHECKING, Any, ClassVar, Optional,
|
|
7
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
|
|
7
8
|
|
|
8
9
|
from adbc_driver_manager.dbapi import Connection, Cursor
|
|
9
10
|
|
|
10
11
|
from sqlspec.driver import SyncDriverAdapterProtocol
|
|
12
|
+
from sqlspec.driver.connection import managed_transaction_sync
|
|
11
13
|
from sqlspec.driver.mixins import (
|
|
12
14
|
SQLTranslatorMixin,
|
|
13
15
|
SyncPipelinedExecutionMixin,
|
|
@@ -15,11 +17,12 @@ from sqlspec.driver.mixins import (
|
|
|
15
17
|
ToSchemaMixin,
|
|
16
18
|
TypeCoercionMixin,
|
|
17
19
|
)
|
|
20
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
18
21
|
from sqlspec.exceptions import wrap_exceptions
|
|
19
22
|
from sqlspec.statement.parameters import ParameterStyle
|
|
20
|
-
from sqlspec.statement.result import ArrowResult,
|
|
23
|
+
from sqlspec.statement.result import ArrowResult, SQLResult
|
|
21
24
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
22
|
-
from sqlspec.typing import DictRow,
|
|
25
|
+
from sqlspec.typing import DictRow, RowT
|
|
23
26
|
from sqlspec.utils.serializers import to_json
|
|
24
27
|
|
|
25
28
|
if TYPE_CHECKING:
|
|
@@ -65,8 +68,15 @@ class AdbcDriver(
|
|
|
65
68
|
config: "Optional[SQLConfig]" = None,
|
|
66
69
|
default_row_type: "type[DictRow]" = DictRow,
|
|
67
70
|
) -> None:
|
|
71
|
+
dialect = self._get_dialect(connection)
|
|
72
|
+
if config and not config.dialect:
|
|
73
|
+
config = replace(config, dialect=dialect)
|
|
74
|
+
elif not config:
|
|
75
|
+
# Create config with dialect
|
|
76
|
+
config = SQLConfig(dialect=dialect)
|
|
77
|
+
|
|
68
78
|
super().__init__(connection=connection, config=config, default_row_type=default_row_type)
|
|
69
|
-
self.dialect: DialectType =
|
|
79
|
+
self.dialect: DialectType = dialect
|
|
70
80
|
self.default_parameter_style = self._get_parameter_style_for_dialect(self.dialect)
|
|
71
81
|
# Override supported parameter styles based on actual dialect capabilities
|
|
72
82
|
self.supported_parameter_styles = self._get_supported_parameter_styles_for_dialect(self.dialect)
|
|
@@ -169,13 +179,13 @@ class AdbcDriver(
|
|
|
169
179
|
|
|
170
180
|
def _execute_statement(
|
|
171
181
|
self, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
172
|
-
) ->
|
|
182
|
+
) -> SQLResult[RowT]:
|
|
173
183
|
if statement.is_script:
|
|
174
184
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
175
185
|
return self._execute_script(sql, connection=connection, **kwargs)
|
|
176
186
|
|
|
177
|
-
# Determine if we need to convert parameter style
|
|
178
187
|
detected_styles = {p.style for p in statement.parameter_info}
|
|
188
|
+
|
|
179
189
|
target_style = self.default_parameter_style
|
|
180
190
|
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
181
191
|
|
|
@@ -196,69 +206,107 @@ class AdbcDriver(
|
|
|
196
206
|
|
|
197
207
|
def _execute(
|
|
198
208
|
self, sql: str, parameters: Any, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
199
|
-
) ->
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
209
|
+
) -> SQLResult[RowT]:
|
|
210
|
+
# Use provided connection or driver's default connection
|
|
211
|
+
conn = connection if connection is not None else self._connection(None)
|
|
212
|
+
|
|
213
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
214
|
+
normalized_params = normalize_parameter_sequence(parameters)
|
|
215
|
+
if normalized_params is not None and not isinstance(normalized_params, (list, tuple)):
|
|
216
|
+
cursor_params = [normalized_params]
|
|
205
217
|
else:
|
|
206
|
-
cursor_params =
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
218
|
+
cursor_params = normalized_params
|
|
219
|
+
|
|
220
|
+
with self._get_cursor(txn_conn) as cursor:
|
|
221
|
+
try:
|
|
222
|
+
cursor.execute(sql, cursor_params or [])
|
|
223
|
+
except Exception as e:
|
|
224
|
+
# Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
|
|
225
|
+
if self.dialect == "postgres":
|
|
226
|
+
with contextlib.suppress(Exception):
|
|
227
|
+
cursor.execute("ROLLBACK")
|
|
228
|
+
raise e from e
|
|
229
|
+
|
|
230
|
+
if self.returns_rows(statement.expression):
|
|
231
|
+
fetched_data = cursor.fetchall()
|
|
232
|
+
column_names = [col[0] for col in cursor.description or []]
|
|
233
|
+
|
|
234
|
+
if fetched_data and isinstance(fetched_data[0], tuple):
|
|
235
|
+
dict_data: list[dict[Any, Any]] = [dict(zip(column_names, row)) for row in fetched_data]
|
|
236
|
+
else:
|
|
237
|
+
dict_data = fetched_data # type: ignore[assignment]
|
|
238
|
+
|
|
239
|
+
return SQLResult(
|
|
240
|
+
statement=statement,
|
|
241
|
+
data=cast("list[RowT]", dict_data),
|
|
242
|
+
column_names=column_names,
|
|
243
|
+
rows_affected=len(dict_data),
|
|
244
|
+
operation_type="SELECT",
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
operation_type = self._determine_operation_type(statement)
|
|
248
|
+
return SQLResult(
|
|
249
|
+
statement=statement,
|
|
250
|
+
data=cast("list[RowT]", []),
|
|
251
|
+
rows_affected=cursor.rowcount,
|
|
252
|
+
operation_type=operation_type,
|
|
253
|
+
metadata={"status_message": "OK"},
|
|
254
|
+
)
|
|
229
255
|
|
|
230
256
|
def _execute_many(
|
|
231
257
|
self, sql: str, param_list: Any, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
232
|
-
) ->
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
258
|
+
) -> SQLResult[RowT]:
|
|
259
|
+
# Use provided connection or driver's default connection
|
|
260
|
+
conn = connection if connection is not None else self._connection(None)
|
|
261
|
+
|
|
262
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
263
|
+
# Normalize parameter list using consolidated utility
|
|
264
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
265
|
+
|
|
266
|
+
with self._get_cursor(txn_conn) as cursor:
|
|
267
|
+
try:
|
|
268
|
+
cursor.executemany(sql, normalized_param_list or [])
|
|
269
|
+
except Exception as e:
|
|
270
|
+
if self.dialect == "postgres":
|
|
271
|
+
with contextlib.suppress(Exception):
|
|
272
|
+
cursor.execute("ROLLBACK")
|
|
273
|
+
# Always re-raise the original exception
|
|
274
|
+
raise e from e
|
|
275
|
+
|
|
276
|
+
return SQLResult(
|
|
277
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
278
|
+
data=[],
|
|
279
|
+
rows_affected=cursor.rowcount,
|
|
280
|
+
operation_type="EXECUTE",
|
|
281
|
+
metadata={"status_message": "OK"},
|
|
282
|
+
)
|
|
246
283
|
|
|
247
284
|
def _execute_script(
|
|
248
285
|
self, script: str, connection: Optional["AdbcConnection"] = None, **kwargs: Any
|
|
249
|
-
) ->
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
286
|
+
) -> SQLResult[RowT]:
|
|
287
|
+
# Use provided connection or driver's default connection
|
|
288
|
+
conn = connection if connection is not None else self._connection(None)
|
|
289
|
+
|
|
290
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
291
|
+
# ADBC drivers don't support multiple statements in a single execute
|
|
292
|
+
statements = self._split_script_statements(script)
|
|
293
|
+
|
|
294
|
+
executed_count = 0
|
|
295
|
+
with self._get_cursor(txn_conn) as cursor:
|
|
296
|
+
for statement in statements:
|
|
297
|
+
if statement.strip():
|
|
298
|
+
self._execute_single_script_statement(cursor, statement)
|
|
299
|
+
executed_count += 1
|
|
300
|
+
|
|
301
|
+
return SQLResult(
|
|
302
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
303
|
+
data=[],
|
|
304
|
+
rows_affected=0,
|
|
305
|
+
operation_type="SCRIPT",
|
|
306
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
307
|
+
total_statements=executed_count,
|
|
308
|
+
successful_statements=executed_count,
|
|
309
|
+
)
|
|
262
310
|
|
|
263
311
|
def _execute_single_script_statement(self, cursor: "Cursor", statement: str) -> int:
|
|
264
312
|
"""Execute a single statement from a script and handle errors.
|
|
@@ -273,7 +321,7 @@ class AdbcDriver(
|
|
|
273
321
|
try:
|
|
274
322
|
cursor.execute(statement)
|
|
275
323
|
except Exception as e:
|
|
276
|
-
# Rollback transaction on error for PostgreSQL
|
|
324
|
+
# Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
|
|
277
325
|
if self.dialect == "postgres":
|
|
278
326
|
with contextlib.suppress(Exception):
|
|
279
327
|
cursor.execute("ROLLBACK")
|
|
@@ -281,59 +329,6 @@ class AdbcDriver(
|
|
|
281
329
|
else:
|
|
282
330
|
return 1
|
|
283
331
|
|
|
284
|
-
def _wrap_select_result(
|
|
285
|
-
self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
|
|
286
|
-
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
287
|
-
# result must be a dict with keys: data, column_names, rows_affected
|
|
288
|
-
|
|
289
|
-
rows_as_dicts = [dict(zip(result["column_names"], row)) for row in result["data"]]
|
|
290
|
-
|
|
291
|
-
if schema_type:
|
|
292
|
-
return SQLResult[ModelDTOT](
|
|
293
|
-
statement=statement,
|
|
294
|
-
data=list(self.to_schema(data=rows_as_dicts, schema_type=schema_type)),
|
|
295
|
-
column_names=result["column_names"],
|
|
296
|
-
rows_affected=result["rows_affected"],
|
|
297
|
-
operation_type="SELECT",
|
|
298
|
-
)
|
|
299
|
-
return SQLResult[RowT](
|
|
300
|
-
statement=statement,
|
|
301
|
-
data=rows_as_dicts,
|
|
302
|
-
column_names=result["column_names"],
|
|
303
|
-
rows_affected=result["rows_affected"],
|
|
304
|
-
operation_type="SELECT",
|
|
305
|
-
)
|
|
306
|
-
|
|
307
|
-
def _wrap_execute_result(
|
|
308
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
309
|
-
) -> SQLResult[RowT]:
|
|
310
|
-
operation_type = (
|
|
311
|
-
str(statement.expression.key).upper()
|
|
312
|
-
if statement.expression and hasattr(statement.expression, "key")
|
|
313
|
-
else "UNKNOWN"
|
|
314
|
-
)
|
|
315
|
-
|
|
316
|
-
# Handle TypedDict results
|
|
317
|
-
if is_dict_with_field(result, "statements_executed"):
|
|
318
|
-
return SQLResult[RowT](
|
|
319
|
-
statement=statement,
|
|
320
|
-
data=[],
|
|
321
|
-
rows_affected=0,
|
|
322
|
-
total_statements=result["statements_executed"],
|
|
323
|
-
operation_type="SCRIPT", # Scripts always have operation_type SCRIPT
|
|
324
|
-
metadata={"status_message": result["status_message"]},
|
|
325
|
-
)
|
|
326
|
-
if is_dict_with_field(result, "rows_affected"):
|
|
327
|
-
return SQLResult[RowT](
|
|
328
|
-
statement=statement,
|
|
329
|
-
data=[],
|
|
330
|
-
rows_affected=result["rows_affected"],
|
|
331
|
-
operation_type=operation_type,
|
|
332
|
-
metadata={"status_message": result["status_message"]},
|
|
333
|
-
)
|
|
334
|
-
msg = f"Unexpected result type: {type(result)}"
|
|
335
|
-
raise ValueError(msg)
|
|
336
|
-
|
|
337
332
|
def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
|
|
338
333
|
"""ADBC native Arrow table fetching.
|
|
339
334
|
|
|
@@ -379,10 +374,17 @@ class AdbcDriver(
|
|
|
379
374
|
|
|
380
375
|
conn = self._connection(None)
|
|
381
376
|
with self._get_cursor(conn) as cursor:
|
|
382
|
-
# Handle different modes
|
|
383
377
|
if mode == "replace":
|
|
384
|
-
cursor.execute(
|
|
378
|
+
cursor.execute(
|
|
379
|
+
SQL(f"TRUNCATE TABLE {table_name}", _dialect=self.dialect).to_sql(
|
|
380
|
+
placeholder_style=ParameterStyle.STATIC
|
|
381
|
+
)
|
|
382
|
+
)
|
|
385
383
|
elif mode == "create":
|
|
386
384
|
msg = "'create' mode is not supported for ADBC ingestion"
|
|
387
385
|
raise NotImplementedError(msg)
|
|
388
386
|
return cursor.adbc_ingest(table_name, table, mode=mode, **options) # type: ignore[arg-type]
|
|
387
|
+
|
|
388
|
+
def _connection(self, connection: Optional["AdbcConnection"] = None) -> "AdbcConnection":
|
|
389
|
+
"""Get the connection to use for the operation."""
|
|
390
|
+
return connection or self.connection
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
|
8
7
|
|
|
9
8
|
import aiosqlite
|
|
@@ -172,15 +171,16 @@ class AiosqliteConfig(AsyncDatabaseConfig[AiosqliteConnection, None, AiosqliteDr
|
|
|
172
171
|
An AiosqliteDriver instance.
|
|
173
172
|
"""
|
|
174
173
|
async with self.provide_connection(*args, **kwargs) as connection:
|
|
175
|
-
# Create statement config with parameter style info if not already set
|
|
176
174
|
statement_config = self.statement_config
|
|
175
|
+
# Inject parameter style info if not already set
|
|
177
176
|
if statement_config.allowed_parameter_styles is None:
|
|
177
|
+
from dataclasses import replace
|
|
178
|
+
|
|
178
179
|
statement_config = replace(
|
|
179
180
|
statement_config,
|
|
180
181
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
181
182
|
target_parameter_style=self.preferred_parameter_style,
|
|
182
183
|
)
|
|
183
|
-
|
|
184
184
|
yield self.driver_type(connection=connection, config=statement_config)
|
|
185
185
|
|
|
186
186
|
async def provide_pool(self, *args: Any, **kwargs: Any) -> None:
|
|
@@ -3,11 +3,12 @@ import logging
|
|
|
3
3
|
from collections.abc import AsyncGenerator, Sequence
|
|
4
4
|
from contextlib import asynccontextmanager
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
7
7
|
|
|
8
8
|
import aiosqlite
|
|
9
9
|
|
|
10
10
|
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
11
|
+
from sqlspec.driver.connection import managed_transaction_async
|
|
11
12
|
from sqlspec.driver.mixins import (
|
|
12
13
|
AsyncPipelinedExecutionMixin,
|
|
13
14
|
AsyncStorageMixin,
|
|
@@ -15,10 +16,11 @@ from sqlspec.driver.mixins import (
|
|
|
15
16
|
ToSchemaMixin,
|
|
16
17
|
TypeCoercionMixin,
|
|
17
18
|
)
|
|
18
|
-
from sqlspec.
|
|
19
|
-
from sqlspec.statement.
|
|
19
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
20
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
21
|
+
from sqlspec.statement.result import SQLResult
|
|
20
22
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
21
|
-
from sqlspec.typing import DictRow,
|
|
23
|
+
from sqlspec.typing import DictRow, RowT
|
|
22
24
|
from sqlspec.utils.serializers import to_json
|
|
23
25
|
|
|
24
26
|
if TYPE_CHECKING:
|
|
@@ -97,22 +99,24 @@ class AiosqliteDriver(
|
|
|
97
99
|
|
|
98
100
|
async def _execute_statement(
|
|
99
101
|
self, statement: SQL, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
|
|
100
|
-
) ->
|
|
102
|
+
) -> SQLResult[RowT]:
|
|
101
103
|
if statement.is_script:
|
|
102
104
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
103
105
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
104
106
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
+
detected_styles = set()
|
|
108
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
109
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
110
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
111
|
+
if param_infos:
|
|
112
|
+
detected_styles = {p.style for p in param_infos}
|
|
113
|
+
|
|
107
114
|
target_style = self.default_parameter_style
|
|
108
115
|
|
|
109
|
-
# Check if any detected style is not supported
|
|
110
116
|
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
111
117
|
if unsupported_styles:
|
|
112
|
-
# Convert to default style if we have unsupported styles
|
|
113
118
|
target_style = self.default_parameter_style
|
|
114
119
|
elif detected_styles:
|
|
115
|
-
# Use the first detected style if all are supported
|
|
116
120
|
# Prefer the first supported style found
|
|
117
121
|
for style in detected_styles:
|
|
118
122
|
if style in self.supported_parameter_styles:
|
|
@@ -122,85 +126,111 @@ class AiosqliteDriver(
|
|
|
122
126
|
if statement.is_many:
|
|
123
127
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
124
128
|
|
|
125
|
-
# Process parameter list through type coercion
|
|
126
129
|
params = self._process_parameters(params)
|
|
127
130
|
|
|
128
131
|
return await self._execute_many(sql, params, connection=connection, **kwargs)
|
|
129
132
|
|
|
130
133
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
131
134
|
|
|
132
|
-
# Process parameters through type coercion
|
|
133
135
|
params = self._process_parameters(params)
|
|
134
136
|
|
|
135
137
|
return await self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
136
138
|
|
|
137
139
|
async def _execute(
|
|
138
140
|
self, sql: str, parameters: Any, statement: SQL, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
|
|
139
|
-
) ->
|
|
141
|
+
) -> SQLResult[RowT]:
|
|
140
142
|
conn = self._connection(connection)
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
)
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
143
|
+
|
|
144
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
145
|
+
normalized_params = normalize_parameter_sequence(parameters)
|
|
146
|
+
|
|
147
|
+
# Extract the actual parameters from the normalized list
|
|
148
|
+
if normalized_params and len(normalized_params) == 1:
|
|
149
|
+
actual_params = normalized_params[0]
|
|
150
|
+
else:
|
|
151
|
+
actual_params = normalized_params
|
|
152
|
+
|
|
153
|
+
# AIOSQLite expects tuple or dict - handle parameter conversion
|
|
154
|
+
if ":param_" in sql or (isinstance(actual_params, dict)):
|
|
155
|
+
# SQL has named placeholders, ensure params are dict
|
|
156
|
+
converted_params = self._convert_parameters_to_driver_format(
|
|
157
|
+
sql, actual_params, target_style=ParameterStyle.NAMED_COLON
|
|
158
|
+
)
|
|
159
|
+
else:
|
|
160
|
+
# SQL has positional placeholders, ensure params are list/tuple
|
|
161
|
+
converted_params = self._convert_parameters_to_driver_format(
|
|
162
|
+
sql, actual_params, target_style=ParameterStyle.QMARK
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
166
|
+
# Aiosqlite handles both dict and tuple parameters
|
|
167
|
+
await cursor.execute(sql, converted_params or ())
|
|
168
|
+
if self.returns_rows(statement.expression):
|
|
169
|
+
fetched_data = await cursor.fetchall()
|
|
170
|
+
column_names = [desc[0] for desc in cursor.description or []]
|
|
171
|
+
data_list: list[Any] = list(fetched_data) if fetched_data else []
|
|
172
|
+
return SQLResult(
|
|
173
|
+
statement=statement,
|
|
174
|
+
data=data_list,
|
|
175
|
+
column_names=column_names,
|
|
176
|
+
rows_affected=len(data_list),
|
|
177
|
+
operation_type="SELECT",
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
return SQLResult(
|
|
181
|
+
statement=statement,
|
|
182
|
+
data=[],
|
|
183
|
+
rows_affected=cursor.rowcount,
|
|
184
|
+
operation_type=self._determine_operation_type(statement),
|
|
185
|
+
metadata={"status_message": "OK"},
|
|
186
|
+
)
|
|
169
187
|
|
|
170
188
|
async def _execute_many(
|
|
171
189
|
self, sql: str, param_list: Any, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
|
|
172
|
-
) ->
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
param_set
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
190
|
+
) -> SQLResult[RowT]:
|
|
191
|
+
# Use provided connection or driver's default connection
|
|
192
|
+
conn = connection if connection is not None else self._connection(None)
|
|
193
|
+
|
|
194
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
195
|
+
# Normalize parameter list using consolidated utility
|
|
196
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
197
|
+
|
|
198
|
+
params_list: list[tuple[Any, ...]] = []
|
|
199
|
+
if normalized_param_list and isinstance(normalized_param_list, Sequence):
|
|
200
|
+
for param_set in normalized_param_list:
|
|
201
|
+
if isinstance(param_set, (list, tuple)):
|
|
202
|
+
params_list.append(tuple(param_set))
|
|
203
|
+
elif param_set is None:
|
|
204
|
+
params_list.append(())
|
|
205
|
+
|
|
206
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
207
|
+
await cursor.executemany(sql, params_list)
|
|
208
|
+
return SQLResult(
|
|
209
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
210
|
+
data=[],
|
|
211
|
+
rows_affected=cursor.rowcount,
|
|
212
|
+
operation_type="EXECUTE",
|
|
213
|
+
metadata={"status_message": "OK"},
|
|
214
|
+
)
|
|
192
215
|
|
|
193
216
|
async def _execute_script(
|
|
194
217
|
self, script: str, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
|
|
195
|
-
) ->
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
218
|
+
) -> SQLResult[RowT]:
|
|
219
|
+
# Use provided connection or driver's default connection
|
|
220
|
+
conn = connection if connection is not None else self._connection(None)
|
|
221
|
+
|
|
222
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
223
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
224
|
+
await cursor.executescript(script)
|
|
225
|
+
return SQLResult(
|
|
226
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
227
|
+
data=[],
|
|
228
|
+
rows_affected=0,
|
|
229
|
+
operation_type="SCRIPT",
|
|
230
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
231
|
+
total_statements=-1, # AIOSQLite doesn't provide this info
|
|
232
|
+
successful_statements=-1,
|
|
233
|
+
)
|
|
204
234
|
|
|
205
235
|
async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
|
|
206
236
|
"""Database-specific bulk load implementation using storage backend."""
|
|
@@ -234,66 +264,6 @@ class AiosqliteDriver(
|
|
|
234
264
|
finally:
|
|
235
265
|
await conn.close()
|
|
236
266
|
|
|
237
|
-
async def _wrap_select_result(
|
|
238
|
-
self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
|
|
239
|
-
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
240
|
-
fetched_data = result["data"]
|
|
241
|
-
column_names = result["column_names"]
|
|
242
|
-
rows_affected = result["rows_affected"]
|
|
243
|
-
|
|
244
|
-
rows_as_dicts: list[dict[str, Any]] = [dict(row) for row in fetched_data]
|
|
245
|
-
|
|
246
|
-
if self.returns_rows(statement.expression):
|
|
247
|
-
converted_data_seq = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
|
|
248
|
-
return SQLResult[ModelDTOT](
|
|
249
|
-
statement=statement,
|
|
250
|
-
data=list(converted_data_seq),
|
|
251
|
-
column_names=column_names,
|
|
252
|
-
rows_affected=rows_affected,
|
|
253
|
-
operation_type="SELECT",
|
|
254
|
-
)
|
|
255
|
-
return SQLResult[RowT](
|
|
256
|
-
statement=statement,
|
|
257
|
-
data=rows_as_dicts,
|
|
258
|
-
column_names=column_names,
|
|
259
|
-
rows_affected=rows_affected,
|
|
260
|
-
operation_type="SELECT",
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
async def _wrap_execute_result(
|
|
264
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
265
|
-
) -> SQLResult[RowT]:
|
|
266
|
-
operation_type = "UNKNOWN"
|
|
267
|
-
if statement.expression:
|
|
268
|
-
operation_type = str(statement.expression.key).upper()
|
|
269
|
-
|
|
270
|
-
if "statements_executed" in result:
|
|
271
|
-
script_result = cast("ScriptResultDict", result)
|
|
272
|
-
return SQLResult[RowT](
|
|
273
|
-
statement=statement,
|
|
274
|
-
data=[],
|
|
275
|
-
rows_affected=0,
|
|
276
|
-
operation_type="SCRIPT",
|
|
277
|
-
total_statements=script_result.get("statements_executed", -1),
|
|
278
|
-
metadata={"status_message": script_result.get("status_message", "")},
|
|
279
|
-
)
|
|
280
|
-
|
|
281
|
-
if "rows_affected" in result:
|
|
282
|
-
dml_result = cast("DMLResultDict", result)
|
|
283
|
-
rows_affected = dml_result["rows_affected"]
|
|
284
|
-
status_message = dml_result["status_message"]
|
|
285
|
-
return SQLResult[RowT](
|
|
286
|
-
statement=statement,
|
|
287
|
-
data=[],
|
|
288
|
-
rows_affected=rows_affected,
|
|
289
|
-
operation_type=operation_type,
|
|
290
|
-
metadata={"status_message": status_message},
|
|
291
|
-
)
|
|
292
|
-
|
|
293
|
-
# This shouldn't happen with TypedDict approach
|
|
294
|
-
msg = f"Unexpected result type: {type(result)}"
|
|
295
|
-
raise ValueError(msg)
|
|
296
|
-
|
|
297
267
|
def _connection(self, connection: Optional[AiosqliteConnection] = None) -> AiosqliteConnection:
|
|
298
268
|
"""Get the connection to use for the operation."""
|
|
299
269
|
return connection or self.connection
|