sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
import sqlite3
|
|
5
5
|
from contextlib import contextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
8
7
|
|
|
9
8
|
from sqlspec.adapters.sqlite.driver import SqliteConnection, SqliteDriver
|
|
@@ -88,7 +87,6 @@ class SqliteConfig(NoPoolSyncConfig[SqliteConnection, SqliteDriver]):
|
|
|
88
87
|
uri: Whether to interpret database as URI
|
|
89
88
|
**kwargs: Additional parameters (stored in extras)
|
|
90
89
|
"""
|
|
91
|
-
# Validate required parameters
|
|
92
90
|
if database is None:
|
|
93
91
|
msg = "database parameter cannot be None"
|
|
94
92
|
raise TypeError(msg)
|
|
@@ -164,11 +162,13 @@ class SqliteConfig(NoPoolSyncConfig[SqliteConnection, SqliteDriver]):
|
|
|
164
162
|
"""
|
|
165
163
|
with self.provide_connection(*args, **kwargs) as connection:
|
|
166
164
|
statement_config = self.statement_config
|
|
165
|
+
# Inject parameter style info if not already set
|
|
167
166
|
if statement_config.allowed_parameter_styles is None:
|
|
167
|
+
from dataclasses import replace
|
|
168
|
+
|
|
168
169
|
statement_config = replace(
|
|
169
170
|
statement_config,
|
|
170
171
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
171
172
|
target_parameter_style=self.preferred_parameter_style,
|
|
172
173
|
)
|
|
173
|
-
|
|
174
174
|
yield self.driver_type(connection=connection, config=statement_config)
|
|
@@ -4,11 +4,12 @@ import sqlite3
|
|
|
4
4
|
from collections.abc import Iterator
|
|
5
5
|
from contextlib import contextmanager
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import TYPE_CHECKING, Any, Optional,
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
8
8
|
|
|
9
9
|
from typing_extensions import TypeAlias
|
|
10
10
|
|
|
11
11
|
from sqlspec.driver import SyncDriverAdapterProtocol
|
|
12
|
+
from sqlspec.driver.connection import managed_transaction_sync
|
|
12
13
|
from sqlspec.driver.mixins import (
|
|
13
14
|
SQLTranslatorMixin,
|
|
14
15
|
SyncPipelinedExecutionMixin,
|
|
@@ -16,10 +17,11 @@ from sqlspec.driver.mixins import (
|
|
|
16
17
|
ToSchemaMixin,
|
|
17
18
|
TypeCoercionMixin,
|
|
18
19
|
)
|
|
19
|
-
from sqlspec.
|
|
20
|
-
from sqlspec.statement.
|
|
20
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
21
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
22
|
+
from sqlspec.statement.result import SQLResult
|
|
21
23
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
22
|
-
from sqlspec.typing import DictRow,
|
|
24
|
+
from sqlspec.typing import DictRow, RowT
|
|
23
25
|
from sqlspec.utils.logging import get_logger
|
|
24
26
|
from sqlspec.utils.serializers import to_json
|
|
25
27
|
|
|
@@ -102,19 +104,22 @@ class SqliteDriver(
|
|
|
102
104
|
|
|
103
105
|
def _execute_statement(
|
|
104
106
|
self, statement: SQL, connection: Optional[SqliteConnection] = None, **kwargs: Any
|
|
105
|
-
) ->
|
|
107
|
+
) -> SQLResult[RowT]:
|
|
106
108
|
if statement.is_script:
|
|
107
109
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
108
|
-
return self._execute_script(sql, connection=connection, **kwargs)
|
|
110
|
+
return self._execute_script(sql, connection=connection, statement=statement, **kwargs)
|
|
111
|
+
|
|
112
|
+
detected_styles = set()
|
|
113
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
114
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
115
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
116
|
+
if param_infos:
|
|
117
|
+
detected_styles = {p.style for p in param_infos}
|
|
109
118
|
|
|
110
|
-
# Determine if we need to convert parameter style
|
|
111
|
-
detected_styles = {p.style for p in statement.parameter_info}
|
|
112
119
|
target_style = self.default_parameter_style
|
|
113
120
|
|
|
114
|
-
# Check if any detected style is not supported
|
|
115
121
|
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
116
122
|
if unsupported_styles:
|
|
117
|
-
# Convert to default style if we have unsupported styles
|
|
118
123
|
target_style = self.default_parameter_style
|
|
119
124
|
elif len(detected_styles) > 1:
|
|
120
125
|
# Mixed styles detected - use default style for consistency
|
|
@@ -129,11 +134,10 @@ class SqliteDriver(
|
|
|
129
134
|
|
|
130
135
|
if statement.is_many:
|
|
131
136
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
132
|
-
return self._execute_many(sql, params, connection=connection, **kwargs)
|
|
137
|
+
return self._execute_many(sql, params, connection=connection, statement=statement, **kwargs)
|
|
133
138
|
|
|
134
139
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
135
140
|
|
|
136
|
-
# Process parameters through type coercion
|
|
137
141
|
params = self._process_parameters(params)
|
|
138
142
|
|
|
139
143
|
# SQLite expects tuples for positional parameters
|
|
@@ -144,58 +148,105 @@ class SqliteDriver(
|
|
|
144
148
|
|
|
145
149
|
def _execute(
|
|
146
150
|
self, sql: str, parameters: Any, statement: SQL, connection: Optional[SqliteConnection] = None, **kwargs: Any
|
|
147
|
-
) ->
|
|
151
|
+
) -> SQLResult[RowT]:
|
|
148
152
|
"""Execute a single statement with parameters."""
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
153
|
+
# Use provided connection or driver's default connection
|
|
154
|
+
conn = connection if connection is not None else self._connection(None)
|
|
155
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
|
|
156
|
+
# Normalize parameters using consolidated utility
|
|
157
|
+
normalized_params_list = normalize_parameter_sequence(parameters)
|
|
158
|
+
params_for_execute: Any
|
|
159
|
+
if normalized_params_list and len(normalized_params_list) == 1:
|
|
160
|
+
# Single parameter should be tuple for SQLite
|
|
161
|
+
if not isinstance(normalized_params_list[0], (tuple, list, dict)):
|
|
162
|
+
params_for_execute = (normalized_params_list[0],)
|
|
163
|
+
else:
|
|
164
|
+
params_for_execute = normalized_params_list[0]
|
|
165
|
+
else:
|
|
166
|
+
# Multiple parameters
|
|
167
|
+
params_for_execute = tuple(normalized_params_list) if normalized_params_list else ()
|
|
168
|
+
|
|
169
|
+
cursor.execute(sql, params_for_execute)
|
|
156
170
|
if self.returns_rows(statement.expression):
|
|
157
171
|
fetched_data: list[sqlite3.Row] = cursor.fetchall()
|
|
158
|
-
return
|
|
159
|
-
|
|
160
|
-
"
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
172
|
+
return SQLResult(
|
|
173
|
+
statement=statement,
|
|
174
|
+
data=cast("list[RowT]", fetched_data),
|
|
175
|
+
column_names=[col[0] for col in cursor.description or []],
|
|
176
|
+
rows_affected=len(fetched_data),
|
|
177
|
+
operation_type="SELECT",
|
|
178
|
+
)
|
|
179
|
+
operation_type = self._determine_operation_type(statement)
|
|
180
|
+
|
|
181
|
+
return SQLResult(
|
|
182
|
+
statement=statement,
|
|
183
|
+
data=[],
|
|
184
|
+
rows_affected=cursor.rowcount,
|
|
185
|
+
operation_type=operation_type,
|
|
186
|
+
metadata={"status_message": "OK"},
|
|
187
|
+
)
|
|
164
188
|
|
|
165
189
|
def _execute_many(
|
|
166
|
-
self,
|
|
167
|
-
|
|
190
|
+
self,
|
|
191
|
+
sql: str,
|
|
192
|
+
param_list: Any,
|
|
193
|
+
connection: Optional[SqliteConnection] = None,
|
|
194
|
+
statement: Optional[SQL] = None,
|
|
195
|
+
**kwargs: Any,
|
|
196
|
+
) -> SQLResult[RowT]:
|
|
168
197
|
"""Execute a statement many times with a list of parameter tuples."""
|
|
169
|
-
|
|
170
|
-
if
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
198
|
+
# Use provided connection or driver's default connection
|
|
199
|
+
conn = connection if connection is not None else self._connection(None)
|
|
200
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
201
|
+
# Normalize parameter list using consolidated utility
|
|
202
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
203
|
+
formatted_params: list[tuple[Any, ...]] = []
|
|
204
|
+
if normalized_param_list:
|
|
205
|
+
for param_set in normalized_param_list:
|
|
206
|
+
if isinstance(param_set, (list, tuple)):
|
|
207
|
+
formatted_params.append(tuple(param_set))
|
|
208
|
+
elif param_set is None:
|
|
209
|
+
formatted_params.append(())
|
|
210
|
+
else:
|
|
211
|
+
formatted_params.append((param_set,))
|
|
212
|
+
|
|
213
|
+
with self._get_cursor(txn_conn) as cursor:
|
|
214
|
+
cursor.executemany(sql, formatted_params)
|
|
215
|
+
|
|
216
|
+
if statement is None:
|
|
217
|
+
statement = SQL(sql, _dialect=self.dialect)
|
|
218
|
+
|
|
219
|
+
return SQLResult(
|
|
220
|
+
statement=statement,
|
|
221
|
+
data=[],
|
|
222
|
+
rows_affected=cursor.rowcount,
|
|
223
|
+
operation_type="EXECUTE",
|
|
224
|
+
metadata={"status_message": "OK"},
|
|
225
|
+
)
|
|
187
226
|
|
|
188
227
|
def _execute_script(
|
|
189
|
-
self, script: str, connection: Optional[SqliteConnection] = None, **kwargs: Any
|
|
190
|
-
) ->
|
|
228
|
+
self, script: str, connection: Optional[SqliteConnection] = None, statement: Optional[SQL] = None, **kwargs: Any
|
|
229
|
+
) -> SQLResult[RowT]:
|
|
191
230
|
"""Execute a script on the SQLite connection."""
|
|
192
|
-
|
|
231
|
+
# Use provided connection or driver's default connection
|
|
232
|
+
conn = connection if connection is not None else self._connection(None)
|
|
193
233
|
with self._get_cursor(conn) as cursor:
|
|
194
234
|
cursor.executescript(script)
|
|
195
|
-
# executescript doesn't auto-commit in some cases
|
|
235
|
+
# executescript doesn't auto-commit in some cases - force commit
|
|
196
236
|
conn.commit()
|
|
197
|
-
|
|
198
|
-
|
|
237
|
+
|
|
238
|
+
if statement is None:
|
|
239
|
+
statement = SQL(script, _dialect=self.dialect).as_script()
|
|
240
|
+
|
|
241
|
+
return SQLResult(
|
|
242
|
+
statement=statement,
|
|
243
|
+
data=[],
|
|
244
|
+
rows_affected=-1, # Unknown for scripts
|
|
245
|
+
operation_type="SCRIPT",
|
|
246
|
+
total_statements=-1, # SQLite doesn't provide this info
|
|
247
|
+
successful_statements=-1,
|
|
248
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
249
|
+
)
|
|
199
250
|
|
|
200
251
|
def _ingest_arrow_table(self, table: Any, table_name: str, mode: str = "create", **options: Any) -> int:
|
|
201
252
|
"""SQLite-specific Arrow table ingestion using CSV conversion.
|
|
@@ -208,12 +259,10 @@ class SqliteDriver(
|
|
|
208
259
|
|
|
209
260
|
import pyarrow.csv as pa_csv
|
|
210
261
|
|
|
211
|
-
# Convert Arrow table to CSV in memory
|
|
212
262
|
csv_buffer = io.BytesIO()
|
|
213
263
|
pa_csv.write_csv(table, csv_buffer)
|
|
214
264
|
csv_content = csv_buffer.getvalue()
|
|
215
265
|
|
|
216
|
-
# Create a temporary file path
|
|
217
266
|
temp_filename = f"sqlspec_temp_{table_name}_{id(self)}.csv"
|
|
218
267
|
temp_path = Path(tempfile.gettempdir()) / temp_filename
|
|
219
268
|
|
|
@@ -258,46 +307,3 @@ class SqliteDriver(
|
|
|
258
307
|
data_iter = list(reader) # Read all data into memory
|
|
259
308
|
cursor.executemany(sql, data_iter)
|
|
260
309
|
return cursor.rowcount
|
|
261
|
-
|
|
262
|
-
def _wrap_select_result(
|
|
263
|
-
self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
|
|
264
|
-
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
265
|
-
rows_as_dicts = [dict(row) for row in result["data"]]
|
|
266
|
-
if schema_type:
|
|
267
|
-
return SQLResult[ModelDTOT](
|
|
268
|
-
statement=statement,
|
|
269
|
-
data=list(self.to_schema(data=rows_as_dicts, schema_type=schema_type)),
|
|
270
|
-
column_names=result["column_names"],
|
|
271
|
-
rows_affected=result["rows_affected"],
|
|
272
|
-
operation_type="SELECT",
|
|
273
|
-
)
|
|
274
|
-
|
|
275
|
-
return SQLResult[RowT](
|
|
276
|
-
statement=statement,
|
|
277
|
-
data=rows_as_dicts,
|
|
278
|
-
column_names=result["column_names"],
|
|
279
|
-
rows_affected=result["rows_affected"],
|
|
280
|
-
operation_type="SELECT",
|
|
281
|
-
)
|
|
282
|
-
|
|
283
|
-
def _wrap_execute_result(
|
|
284
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
285
|
-
) -> SQLResult[RowT]:
|
|
286
|
-
if is_dict_with_field(result, "statements_executed"):
|
|
287
|
-
return SQLResult[RowT](
|
|
288
|
-
statement=statement,
|
|
289
|
-
data=[],
|
|
290
|
-
rows_affected=0,
|
|
291
|
-
operation_type="SCRIPT",
|
|
292
|
-
metadata={
|
|
293
|
-
"status_message": result.get("status_message", ""),
|
|
294
|
-
"statements_executed": result.get("statements_executed", -1),
|
|
295
|
-
},
|
|
296
|
-
)
|
|
297
|
-
return SQLResult[RowT](
|
|
298
|
-
statement=statement,
|
|
299
|
-
data=[],
|
|
300
|
-
rows_affected=cast("int", result.get("rows_affected", -1)),
|
|
301
|
-
operation_type=statement.expression.key.upper() if statement.expression else "UNKNOWN",
|
|
302
|
-
metadata={"status_message": result.get("status_message", "")},
|
|
303
|
-
)
|
sqlspec/config.py
CHANGED
|
@@ -97,7 +97,6 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
97
97
|
Returns:
|
|
98
98
|
The SQL dialect type.
|
|
99
99
|
"""
|
|
100
|
-
# Get dialect from driver_class (all drivers must have a dialect attribute)
|
|
101
100
|
return self.driver_type.dialect
|
|
102
101
|
|
|
103
102
|
@abstractmethod
|
|
@@ -154,17 +153,14 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
154
153
|
"""
|
|
155
154
|
namespace: dict[str, type[Any]] = {}
|
|
156
155
|
|
|
157
|
-
# Add the driver and config types
|
|
158
156
|
if hasattr(self, "driver_type") and self.driver_type:
|
|
159
157
|
namespace[self.driver_type.__name__] = self.driver_type
|
|
160
158
|
|
|
161
159
|
namespace[self.__class__.__name__] = self.__class__
|
|
162
160
|
|
|
163
|
-
# Add connection type(s)
|
|
164
161
|
if hasattr(self, "connection_type") and self.connection_type:
|
|
165
162
|
connection_type = self.connection_type
|
|
166
163
|
|
|
167
|
-
# Handle Union types (like AsyncPG's Union[Connection, PoolConnectionProxy])
|
|
168
164
|
if hasattr(connection_type, "__args__"):
|
|
169
165
|
# It's a generic type, extract the actual types
|
|
170
166
|
for arg_type in connection_type.__args__: # type: ignore[attr-defined]
|
sqlspec/driver/_async.py
CHANGED
|
@@ -1,17 +1,22 @@
|
|
|
1
1
|
"""Asynchronous driver protocol implementation."""
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from
|
|
4
|
+
from dataclasses import replace
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, overload
|
|
5
6
|
|
|
6
7
|
from sqlspec.driver._common import CommonDriverAttributesMixin
|
|
7
|
-
from sqlspec.
|
|
8
|
-
from sqlspec.statement.
|
|
8
|
+
from sqlspec.driver.parameters import process_execute_many_parameters
|
|
9
|
+
from sqlspec.statement.builder import Delete, Insert, QueryBuilder, Select, Update
|
|
9
10
|
from sqlspec.statement.result import SQLResult
|
|
10
11
|
from sqlspec.statement.sql import SQL, SQLConfig, Statement
|
|
11
12
|
from sqlspec.typing import ConnectionT, DictRow, ModelDTOT, RowT, StatementParameters
|
|
13
|
+
from sqlspec.utils.logging import get_logger
|
|
14
|
+
from sqlspec.utils.type_guards import can_convert_to_schema
|
|
12
15
|
|
|
13
16
|
if TYPE_CHECKING:
|
|
14
|
-
from sqlspec.statement.
|
|
17
|
+
from sqlspec.statement.filters import StatementFilter
|
|
18
|
+
|
|
19
|
+
logger = get_logger("sqlspec")
|
|
15
20
|
|
|
16
21
|
__all__ = ("AsyncDriverAdapterProtocol",)
|
|
17
22
|
|
|
@@ -49,42 +54,64 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
49
54
|
|
|
50
55
|
if isinstance(statement, QueryBuilder):
|
|
51
56
|
return statement.to_statement(config=_config)
|
|
52
|
-
# If statement is already a SQL object,
|
|
57
|
+
# If statement is already a SQL object, handle additional parameters
|
|
53
58
|
if isinstance(statement, SQL):
|
|
59
|
+
if parameters or kwargs:
|
|
60
|
+
new_config = _config
|
|
61
|
+
if self.dialect and not new_config.dialect:
|
|
62
|
+
new_config = replace(new_config, dialect=self.dialect)
|
|
63
|
+
# Use raw SQL if available to ensure proper parsing with dialect
|
|
64
|
+
sql_source = statement._raw_sql or statement._statement
|
|
65
|
+
# Preserve filters and state when creating new SQL object
|
|
66
|
+
existing_state = {
|
|
67
|
+
"is_many": statement._is_many,
|
|
68
|
+
"is_script": statement._is_script,
|
|
69
|
+
"original_parameters": statement._original_parameters,
|
|
70
|
+
"filters": statement._filters,
|
|
71
|
+
"positional_params": statement._positional_params,
|
|
72
|
+
"named_params": statement._named_params,
|
|
73
|
+
}
|
|
74
|
+
return SQL(sql_source, *parameters, config=new_config, _existing_state=existing_state, **kwargs)
|
|
75
|
+
# Even without additional parameters, ensure dialect is set
|
|
76
|
+
if self.dialect and (not statement._config.dialect or statement._config.dialect != self.dialect):
|
|
77
|
+
new_config = replace(statement._config, dialect=self.dialect)
|
|
78
|
+
# Use raw SQL if available to ensure proper parsing with dialect
|
|
79
|
+
sql_source = statement._raw_sql or statement._statement
|
|
80
|
+
# Preserve parameters and state when creating new SQL object
|
|
81
|
+
# Use the public parameters property which always has the right value
|
|
82
|
+
existing_state = {
|
|
83
|
+
"is_many": statement._is_many,
|
|
84
|
+
"is_script": statement._is_script,
|
|
85
|
+
"original_parameters": statement._original_parameters,
|
|
86
|
+
"filters": statement._filters,
|
|
87
|
+
"positional_params": statement._positional_params,
|
|
88
|
+
"named_params": statement._named_params,
|
|
89
|
+
}
|
|
90
|
+
if statement.parameters:
|
|
91
|
+
return SQL(
|
|
92
|
+
sql_source, parameters=statement.parameters, config=new_config, _existing_state=existing_state
|
|
93
|
+
)
|
|
94
|
+
return SQL(sql_source, config=new_config, _existing_state=existing_state)
|
|
54
95
|
return statement
|
|
55
|
-
|
|
96
|
+
new_config = _config
|
|
97
|
+
if self.dialect and not new_config.dialect:
|
|
98
|
+
new_config = replace(new_config, dialect=self.dialect)
|
|
99
|
+
return SQL(statement, *parameters, config=new_config, **kwargs)
|
|
56
100
|
|
|
57
101
|
@abstractmethod
|
|
58
102
|
async def _execute_statement(
|
|
59
103
|
self, statement: "SQL", connection: "Optional[ConnectionT]" = None, **kwargs: Any
|
|
60
|
-
) -> "
|
|
104
|
+
) -> "SQLResult[RowT]":
|
|
61
105
|
"""Actual execution implementation by concrete drivers, using the raw connection.
|
|
62
106
|
|
|
63
|
-
Returns
|
|
107
|
+
Returns SQLResult directly based on the statement type.
|
|
64
108
|
"""
|
|
65
109
|
raise NotImplementedError
|
|
66
110
|
|
|
67
|
-
@abstractmethod
|
|
68
|
-
async def _wrap_select_result(
|
|
69
|
-
self,
|
|
70
|
-
statement: "SQL",
|
|
71
|
-
result: "SelectResultDict",
|
|
72
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
73
|
-
**kwargs: Any,
|
|
74
|
-
) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
|
|
75
|
-
raise NotImplementedError
|
|
76
|
-
|
|
77
|
-
@abstractmethod
|
|
78
|
-
async def _wrap_execute_result(
|
|
79
|
-
self, statement: "SQL", result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
|
|
80
|
-
) -> "SQLResult[RowT]":
|
|
81
|
-
raise NotImplementedError
|
|
82
|
-
|
|
83
|
-
# Type-safe overloads based on the refactor plan pattern
|
|
84
111
|
@overload
|
|
85
112
|
async def execute(
|
|
86
113
|
self,
|
|
87
|
-
statement: "
|
|
114
|
+
statement: "Select",
|
|
88
115
|
/,
|
|
89
116
|
*parameters: "Union[StatementParameters, StatementFilter]",
|
|
90
117
|
schema_type: "type[ModelDTOT]",
|
|
@@ -96,7 +123,7 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
96
123
|
@overload
|
|
97
124
|
async def execute(
|
|
98
125
|
self,
|
|
99
|
-
statement: "
|
|
126
|
+
statement: "Select",
|
|
100
127
|
/,
|
|
101
128
|
*parameters: "Union[StatementParameters, StatementFilter]",
|
|
102
129
|
schema_type: None = None,
|
|
@@ -108,7 +135,7 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
108
135
|
@overload
|
|
109
136
|
async def execute(
|
|
110
137
|
self,
|
|
111
|
-
statement: "Union[
|
|
138
|
+
statement: "Union[Insert, Update, Delete]",
|
|
112
139
|
/,
|
|
113
140
|
*parameters: "Union[StatementParameters, StatementFilter]",
|
|
114
141
|
_connection: "Optional[ConnectionT]" = None,
|
|
@@ -155,51 +182,45 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
155
182
|
statement=sql_statement, connection=self._connection(_connection), **kwargs
|
|
156
183
|
)
|
|
157
184
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
185
|
+
# If schema_type is provided and we have data, convert it
|
|
186
|
+
if schema_type and result.data and can_convert_to_schema(self):
|
|
187
|
+
converted_data = list(self.to_schema(data=result.data, schema_type=schema_type))
|
|
188
|
+
return SQLResult[ModelDTOT](
|
|
189
|
+
statement=result.statement,
|
|
190
|
+
data=converted_data,
|
|
191
|
+
column_names=result.column_names,
|
|
192
|
+
rows_affected=result.rows_affected,
|
|
193
|
+
operation_type=result.operation_type,
|
|
194
|
+
last_inserted_id=result.last_inserted_id,
|
|
195
|
+
execution_time=result.execution_time,
|
|
196
|
+
metadata=result.metadata,
|
|
161
197
|
)
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
)
|
|
198
|
+
|
|
199
|
+
return result
|
|
165
200
|
|
|
166
201
|
async def execute_many(
|
|
167
202
|
self,
|
|
168
|
-
statement: "Union[SQL, Statement, QueryBuilder[Any]]",
|
|
203
|
+
statement: "Union[SQL, Statement, QueryBuilder[Any]]",
|
|
169
204
|
/,
|
|
170
205
|
*parameters: "Union[StatementParameters, StatementFilter]",
|
|
171
206
|
_connection: "Optional[ConnectionT]" = None,
|
|
172
207
|
_config: "Optional[SQLConfig]" = None,
|
|
173
208
|
**kwargs: Any,
|
|
174
209
|
) -> "SQLResult[RowT]":
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
for
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
# Ensure param_sequence is a list or None
|
|
190
|
-
if param_sequence is not None and not isinstance(param_sequence, list):
|
|
191
|
-
param_sequence = list(param_sequence) if hasattr(param_sequence, "__iter__") else None
|
|
192
|
-
sql_statement = self._build_statement(statement, _config=_config or self.config, **kwargs)
|
|
193
|
-
sql_statement = sql_statement.as_many(param_sequence)
|
|
194
|
-
result = await self._execute_statement(
|
|
195
|
-
statement=sql_statement,
|
|
196
|
-
connection=self._connection(_connection),
|
|
197
|
-
parameters=param_sequence,
|
|
198
|
-
is_many=True,
|
|
199
|
-
**kwargs,
|
|
200
|
-
)
|
|
201
|
-
return await self._wrap_execute_result(
|
|
202
|
-
sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
|
|
210
|
+
_filters, param_sequence = process_execute_many_parameters(parameters)
|
|
211
|
+
|
|
212
|
+
# For execute_many, disable transformations to prevent literal extraction
|
|
213
|
+
# since the SQL already has placeholders for bulk operations
|
|
214
|
+
many_config = _config or self.config
|
|
215
|
+
if many_config.enable_transformations:
|
|
216
|
+
from dataclasses import replace
|
|
217
|
+
|
|
218
|
+
many_config = replace(many_config, enable_transformations=False)
|
|
219
|
+
|
|
220
|
+
sql_statement = self._build_statement(statement, _config=many_config, **kwargs).as_many(param_sequence)
|
|
221
|
+
|
|
222
|
+
return await self._execute_statement(
|
|
223
|
+
statement=sql_statement, connection=self._connection(_connection), **kwargs
|
|
203
224
|
)
|
|
204
225
|
|
|
205
226
|
async def execute_script(
|
|
@@ -211,42 +232,12 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT],
|
|
|
211
232
|
_config: "Optional[SQLConfig]" = None,
|
|
212
233
|
**kwargs: Any,
|
|
213
234
|
) -> "SQLResult[RowT]":
|
|
214
|
-
param_values = []
|
|
215
|
-
filters = []
|
|
216
|
-
for param in parameters:
|
|
217
|
-
if isinstance(param, StatementFilter):
|
|
218
|
-
filters.append(param)
|
|
219
|
-
else:
|
|
220
|
-
param_values.append(param)
|
|
221
|
-
|
|
222
|
-
# Use first parameter as the primary parameter value, or None if no parameters
|
|
223
|
-
primary_params = param_values[0] if param_values else None
|
|
224
|
-
|
|
225
235
|
script_config = _config or self.config
|
|
226
236
|
if script_config.enable_validation:
|
|
227
|
-
script_config =
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
enable_transformations=script_config.enable_transformations,
|
|
231
|
-
enable_analysis=script_config.enable_analysis,
|
|
232
|
-
strict_mode=False,
|
|
233
|
-
cache_parsed_expression=script_config.cache_parsed_expression,
|
|
234
|
-
parameter_converter=script_config.parameter_converter,
|
|
235
|
-
parameter_validator=script_config.parameter_validator,
|
|
236
|
-
analysis_cache_size=script_config.analysis_cache_size,
|
|
237
|
-
allowed_parameter_styles=script_config.allowed_parameter_styles,
|
|
238
|
-
target_parameter_style=script_config.target_parameter_style,
|
|
239
|
-
allow_mixed_parameter_styles=script_config.allow_mixed_parameter_styles,
|
|
240
|
-
)
|
|
241
|
-
sql_statement = SQL(statement, primary_params, *filters, _dialect=self.dialect, _config=script_config, **kwargs)
|
|
237
|
+
script_config = replace(script_config, enable_validation=False, strict_mode=False)
|
|
238
|
+
|
|
239
|
+
sql_statement = self._build_statement(statement, *parameters, _config=script_config, **kwargs)
|
|
242
240
|
sql_statement = sql_statement.as_script()
|
|
243
|
-
|
|
244
|
-
statement=sql_statement, connection=self._connection(_connection),
|
|
241
|
+
return await self._execute_statement(
|
|
242
|
+
statement=sql_statement, connection=self._connection(_connection), **kwargs
|
|
245
243
|
)
|
|
246
|
-
if isinstance(script_output, str):
|
|
247
|
-
result = SQLResult[RowT](statement=sql_statement, data=[], operation_type="SCRIPT")
|
|
248
|
-
result.total_statements = 1
|
|
249
|
-
result.successful_statements = 1
|
|
250
|
-
return result
|
|
251
|
-
# Wrap the ScriptResultDict using the driver's wrapper
|
|
252
|
-
return await self._wrap_execute_result(sql_statement, cast("ScriptResultDict", script_output), **kwargs)
|