sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -3,11 +3,12 @@ import logging
|
|
|
3
3
|
from collections.abc import AsyncGenerator, Sequence
|
|
4
4
|
from contextlib import asynccontextmanager
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
7
7
|
|
|
8
8
|
import aiosqlite
|
|
9
9
|
|
|
10
10
|
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
11
|
+
from sqlspec.driver.connection import managed_transaction_async
|
|
11
12
|
from sqlspec.driver.mixins import (
|
|
12
13
|
AsyncPipelinedExecutionMixin,
|
|
13
14
|
AsyncStorageMixin,
|
|
@@ -15,10 +16,11 @@ from sqlspec.driver.mixins import (
|
|
|
15
16
|
ToSchemaMixin,
|
|
16
17
|
TypeCoercionMixin,
|
|
17
18
|
)
|
|
18
|
-
from sqlspec.
|
|
19
|
-
from sqlspec.statement.
|
|
19
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
20
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
21
|
+
from sqlspec.statement.result import SQLResult
|
|
20
22
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
21
|
-
from sqlspec.typing import DictRow,
|
|
23
|
+
from sqlspec.typing import DictRow, RowT
|
|
22
24
|
from sqlspec.utils.serializers import to_json
|
|
23
25
|
|
|
24
26
|
if TYPE_CHECKING:
|
|
@@ -97,22 +99,24 @@ class AiosqliteDriver(
|
|
|
97
99
|
|
|
98
100
|
async def _execute_statement(
|
|
99
101
|
self, statement: SQL, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
|
|
100
|
-
) ->
|
|
102
|
+
) -> SQLResult[RowT]:
|
|
101
103
|
if statement.is_script:
|
|
102
104
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
103
105
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
104
106
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
+
detected_styles = set()
|
|
108
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
109
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
110
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
111
|
+
if param_infos:
|
|
112
|
+
detected_styles = {p.style for p in param_infos}
|
|
113
|
+
|
|
107
114
|
target_style = self.default_parameter_style
|
|
108
115
|
|
|
109
|
-
# Check if any detected style is not supported
|
|
110
116
|
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
111
117
|
if unsupported_styles:
|
|
112
|
-
# Convert to default style if we have unsupported styles
|
|
113
118
|
target_style = self.default_parameter_style
|
|
114
119
|
elif detected_styles:
|
|
115
|
-
# Use the first detected style if all are supported
|
|
116
120
|
# Prefer the first supported style found
|
|
117
121
|
for style in detected_styles:
|
|
118
122
|
if style in self.supported_parameter_styles:
|
|
@@ -122,89 +126,114 @@ class AiosqliteDriver(
|
|
|
122
126
|
if statement.is_many:
|
|
123
127
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
124
128
|
|
|
125
|
-
# Process parameter list through type coercion
|
|
126
129
|
params = self._process_parameters(params)
|
|
127
130
|
|
|
128
131
|
return await self._execute_many(sql, params, connection=connection, **kwargs)
|
|
129
132
|
|
|
130
133
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
131
134
|
|
|
132
|
-
# Process parameters through type coercion
|
|
133
135
|
params = self._process_parameters(params)
|
|
134
136
|
|
|
135
137
|
return await self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
136
138
|
|
|
137
139
|
async def _execute(
|
|
138
140
|
self, sql: str, parameters: Any, statement: SQL, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
|
|
139
|
-
) ->
|
|
141
|
+
) -> SQLResult[RowT]:
|
|
140
142
|
conn = self._connection(connection)
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
)
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
143
|
+
|
|
144
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
145
|
+
normalized_params = normalize_parameter_sequence(parameters)
|
|
146
|
+
|
|
147
|
+
# Extract the actual parameters from the normalized list
|
|
148
|
+
if normalized_params and len(normalized_params) == 1:
|
|
149
|
+
actual_params = normalized_params[0]
|
|
150
|
+
else:
|
|
151
|
+
actual_params = normalized_params
|
|
152
|
+
|
|
153
|
+
# AIOSQLite expects tuple or dict - handle parameter conversion
|
|
154
|
+
if ":param_" in sql or (isinstance(actual_params, dict)):
|
|
155
|
+
# SQL has named placeholders, ensure params are dict
|
|
156
|
+
converted_params = self._convert_parameters_to_driver_format(
|
|
157
|
+
sql, actual_params, target_style=ParameterStyle.NAMED_COLON
|
|
158
|
+
)
|
|
159
|
+
else:
|
|
160
|
+
# SQL has positional placeholders, ensure params are list/tuple
|
|
161
|
+
converted_params = self._convert_parameters_to_driver_format(
|
|
162
|
+
sql, actual_params, target_style=ParameterStyle.QMARK
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
166
|
+
# Aiosqlite handles both dict and tuple parameters
|
|
167
|
+
await cursor.execute(sql, converted_params or ())
|
|
168
|
+
if self.returns_rows(statement.expression):
|
|
169
|
+
fetched_data = await cursor.fetchall()
|
|
170
|
+
column_names = [desc[0] for desc in cursor.description or []]
|
|
171
|
+
data_list: list[Any] = list(fetched_data) if fetched_data else []
|
|
172
|
+
return SQLResult(
|
|
173
|
+
statement=statement,
|
|
174
|
+
data=data_list,
|
|
175
|
+
column_names=column_names,
|
|
176
|
+
rows_affected=len(data_list),
|
|
177
|
+
operation_type="SELECT",
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
return SQLResult(
|
|
181
|
+
statement=statement,
|
|
182
|
+
data=[],
|
|
183
|
+
rows_affected=cursor.rowcount,
|
|
184
|
+
operation_type=self._determine_operation_type(statement),
|
|
185
|
+
metadata={"status_message": "OK"},
|
|
186
|
+
)
|
|
169
187
|
|
|
170
188
|
async def _execute_many(
|
|
171
189
|
self, sql: str, param_list: Any, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
|
|
172
|
-
) ->
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
param_set
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
190
|
+
) -> SQLResult[RowT]:
|
|
191
|
+
# Use provided connection or driver's default connection
|
|
192
|
+
conn = connection if connection is not None else self._connection(None)
|
|
193
|
+
|
|
194
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
195
|
+
# Normalize parameter list using consolidated utility
|
|
196
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
197
|
+
|
|
198
|
+
params_list: list[tuple[Any, ...]] = []
|
|
199
|
+
if normalized_param_list and isinstance(normalized_param_list, Sequence):
|
|
200
|
+
for param_set in normalized_param_list:
|
|
201
|
+
if isinstance(param_set, (list, tuple)):
|
|
202
|
+
params_list.append(tuple(param_set))
|
|
203
|
+
elif param_set is None:
|
|
204
|
+
params_list.append(())
|
|
205
|
+
|
|
206
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
207
|
+
await cursor.executemany(sql, params_list)
|
|
208
|
+
return SQLResult(
|
|
209
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
210
|
+
data=[],
|
|
211
|
+
rows_affected=cursor.rowcount,
|
|
212
|
+
operation_type="EXECUTE",
|
|
213
|
+
metadata={"status_message": "OK"},
|
|
214
|
+
)
|
|
192
215
|
|
|
193
216
|
async def _execute_script(
|
|
194
217
|
self, script: str, connection: Optional[AiosqliteConnection] = None, **kwargs: Any
|
|
195
|
-
) ->
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
218
|
+
) -> SQLResult[RowT]:
|
|
219
|
+
# Use provided connection or driver's default connection
|
|
220
|
+
conn = connection if connection is not None else self._connection(None)
|
|
221
|
+
|
|
222
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
223
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
224
|
+
await cursor.executescript(script)
|
|
225
|
+
return SQLResult(
|
|
226
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
227
|
+
data=[],
|
|
228
|
+
rows_affected=0,
|
|
229
|
+
operation_type="SCRIPT",
|
|
230
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
231
|
+
total_statements=-1, # AIOSQLite doesn't provide this info
|
|
232
|
+
successful_statements=-1,
|
|
233
|
+
)
|
|
204
234
|
|
|
205
235
|
async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
|
|
206
|
-
"""Database-specific bulk load implementation."""
|
|
207
|
-
# TODO: convert this to use the storage backend. it has async support
|
|
236
|
+
"""Database-specific bulk load implementation using storage backend."""
|
|
208
237
|
if format != "csv":
|
|
209
238
|
msg = f"aiosqlite driver only supports CSV for bulk loading, not {format}."
|
|
210
239
|
raise NotImplementedError(msg)
|
|
@@ -215,80 +244,26 @@ class AiosqliteDriver(
|
|
|
215
244
|
if mode == "replace":
|
|
216
245
|
await cursor.execute(f"DELETE FROM {table_name}")
|
|
217
246
|
|
|
218
|
-
#
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
247
|
+
# Use async storage backend to read the file
|
|
248
|
+
file_path_str = str(file_path)
|
|
249
|
+
backend = self._get_storage_backend(file_path_str)
|
|
250
|
+
content = await backend.read_text_async(file_path_str, encoding="utf-8")
|
|
251
|
+
# Parse CSV content
|
|
252
|
+
import io
|
|
253
|
+
|
|
254
|
+
csv_file = io.StringIO(content)
|
|
255
|
+
reader = csv.reader(csv_file, **options)
|
|
256
|
+
header = next(reader) # Skip header
|
|
257
|
+
placeholders = ", ".join("?" for _ in header)
|
|
258
|
+
sql = f"INSERT INTO {table_name} VALUES ({placeholders})"
|
|
259
|
+
data_iter = list(reader)
|
|
260
|
+
await cursor.executemany(sql, data_iter)
|
|
261
|
+
rowcount = cursor.rowcount
|
|
227
262
|
await conn.commit()
|
|
228
263
|
return rowcount
|
|
229
264
|
finally:
|
|
230
265
|
await conn.close()
|
|
231
266
|
|
|
232
|
-
async def _wrap_select_result(
|
|
233
|
-
self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
|
|
234
|
-
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
235
|
-
fetched_data = result["data"]
|
|
236
|
-
column_names = result["column_names"]
|
|
237
|
-
rows_affected = result["rows_affected"]
|
|
238
|
-
|
|
239
|
-
rows_as_dicts: list[dict[str, Any]] = [dict(row) for row in fetched_data]
|
|
240
|
-
|
|
241
|
-
if self.returns_rows(statement.expression):
|
|
242
|
-
converted_data_seq = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
|
|
243
|
-
return SQLResult[ModelDTOT](
|
|
244
|
-
statement=statement,
|
|
245
|
-
data=list(converted_data_seq),
|
|
246
|
-
column_names=column_names,
|
|
247
|
-
rows_affected=rows_affected,
|
|
248
|
-
operation_type="SELECT",
|
|
249
|
-
)
|
|
250
|
-
return SQLResult[RowT](
|
|
251
|
-
statement=statement,
|
|
252
|
-
data=rows_as_dicts,
|
|
253
|
-
column_names=column_names,
|
|
254
|
-
rows_affected=rows_affected,
|
|
255
|
-
operation_type="SELECT",
|
|
256
|
-
)
|
|
257
|
-
|
|
258
|
-
async def _wrap_execute_result(
|
|
259
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
260
|
-
) -> SQLResult[RowT]:
|
|
261
|
-
operation_type = "UNKNOWN"
|
|
262
|
-
if statement.expression:
|
|
263
|
-
operation_type = str(statement.expression.key).upper()
|
|
264
|
-
|
|
265
|
-
if "statements_executed" in result:
|
|
266
|
-
script_result = cast("ScriptResultDict", result)
|
|
267
|
-
return SQLResult[RowT](
|
|
268
|
-
statement=statement,
|
|
269
|
-
data=[],
|
|
270
|
-
rows_affected=0,
|
|
271
|
-
operation_type="SCRIPT",
|
|
272
|
-
total_statements=script_result.get("statements_executed", -1),
|
|
273
|
-
metadata={"status_message": script_result.get("status_message", "")},
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
if "rows_affected" in result:
|
|
277
|
-
dml_result = cast("DMLResultDict", result)
|
|
278
|
-
rows_affected = dml_result["rows_affected"]
|
|
279
|
-
status_message = dml_result["status_message"]
|
|
280
|
-
return SQLResult[RowT](
|
|
281
|
-
statement=statement,
|
|
282
|
-
data=[],
|
|
283
|
-
rows_affected=rows_affected,
|
|
284
|
-
operation_type=operation_type,
|
|
285
|
-
metadata={"status_message": status_message},
|
|
286
|
-
)
|
|
287
|
-
|
|
288
|
-
# This shouldn't happen with TypedDict approach
|
|
289
|
-
msg = f"Unexpected result type: {type(result)}"
|
|
290
|
-
raise ValueError(msg)
|
|
291
|
-
|
|
292
267
|
def _connection(self, connection: Optional[AiosqliteConnection] = None) -> AiosqliteConnection:
|
|
293
268
|
"""Get the connection to use for the operation."""
|
|
294
269
|
return connection or self.connection
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
8
7
|
|
|
9
8
|
import asyncmy
|
|
@@ -193,7 +192,6 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
|
|
|
193
192
|
if getattr(self, field, None) is not None and getattr(self, field) is not Empty
|
|
194
193
|
}
|
|
195
194
|
|
|
196
|
-
# Add connection-specific extras (not pool-specific ones)
|
|
197
195
|
config.update(self.extras)
|
|
198
196
|
|
|
199
197
|
return config
|
|
@@ -264,15 +262,16 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
|
|
|
264
262
|
An AsyncmyDriver instance.
|
|
265
263
|
"""
|
|
266
264
|
async with self.provide_connection(*args, **kwargs) as connection:
|
|
267
|
-
# Create statement config with parameter style info if not already set
|
|
268
265
|
statement_config = self.statement_config
|
|
266
|
+
# Inject parameter style info if not already set
|
|
269
267
|
if statement_config.allowed_parameter_styles is None:
|
|
268
|
+
from dataclasses import replace
|
|
269
|
+
|
|
270
270
|
statement_config = replace(
|
|
271
271
|
statement_config,
|
|
272
272
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
273
273
|
target_parameter_style=self.preferred_parameter_style,
|
|
274
274
|
)
|
|
275
|
-
|
|
276
275
|
yield self.driver_type(connection=connection, config=statement_config)
|
|
277
276
|
|
|
278
277
|
async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: ignore
|