sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
8
7
|
|
|
9
8
|
import asyncmy
|
|
@@ -193,7 +192,6 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
|
|
|
193
192
|
if getattr(self, field, None) is not None and getattr(self, field) is not Empty
|
|
194
193
|
}
|
|
195
194
|
|
|
196
|
-
# Add connection-specific extras (not pool-specific ones)
|
|
197
195
|
config.update(self.extras)
|
|
198
196
|
|
|
199
197
|
return config
|
|
@@ -264,15 +262,16 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
|
|
|
264
262
|
An AsyncmyDriver instance.
|
|
265
263
|
"""
|
|
266
264
|
async with self.provide_connection(*args, **kwargs) as connection:
|
|
267
|
-
# Create statement config with parameter style info if not already set
|
|
268
265
|
statement_config = self.statement_config
|
|
266
|
+
# Inject parameter style info if not already set
|
|
269
267
|
if statement_config.allowed_parameter_styles is None:
|
|
268
|
+
from dataclasses import replace
|
|
269
|
+
|
|
270
270
|
statement_config = replace(
|
|
271
271
|
statement_config,
|
|
272
272
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
273
273
|
target_parameter_style=self.preferred_parameter_style,
|
|
274
274
|
)
|
|
275
|
-
|
|
276
275
|
yield self.driver_type(connection=connection, config=statement_config)
|
|
277
276
|
|
|
278
277
|
async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: ignore
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from collections.abc import AsyncGenerator, Sequence
|
|
3
3
|
from contextlib import asynccontextmanager
|
|
4
|
-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
4
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
5
5
|
|
|
6
6
|
from asyncmy import Connection
|
|
7
7
|
from typing_extensions import TypeAlias
|
|
8
8
|
|
|
9
9
|
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
10
|
+
from sqlspec.driver.connection import managed_transaction_async
|
|
10
11
|
from sqlspec.driver.mixins import (
|
|
11
12
|
AsyncPipelinedExecutionMixin,
|
|
12
13
|
AsyncStorageMixin,
|
|
@@ -14,10 +15,11 @@ from sqlspec.driver.mixins import (
|
|
|
14
15
|
ToSchemaMixin,
|
|
15
16
|
TypeCoercionMixin,
|
|
16
17
|
)
|
|
17
|
-
from sqlspec.
|
|
18
|
-
from sqlspec.statement.
|
|
18
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
19
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
20
|
+
from sqlspec.statement.result import SQLResult
|
|
19
21
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
20
|
-
from sqlspec.typing import DictRow,
|
|
22
|
+
from sqlspec.typing import DictRow, RowT
|
|
21
23
|
|
|
22
24
|
if TYPE_CHECKING:
|
|
23
25
|
from asyncmy.cursors import Cursor, DictCursor
|
|
@@ -60,7 +62,7 @@ class AsyncmyDriver(
|
|
|
60
62
|
self, connection: "Optional[AsyncmyConnection]" = None
|
|
61
63
|
) -> "AsyncGenerator[Union[Cursor, DictCursor], None]":
|
|
62
64
|
conn = self._connection(connection)
|
|
63
|
-
cursor =
|
|
65
|
+
cursor = conn.cursor()
|
|
64
66
|
try:
|
|
65
67
|
yield cursor
|
|
66
68
|
finally:
|
|
@@ -68,95 +70,146 @@ class AsyncmyDriver(
|
|
|
68
70
|
|
|
69
71
|
async def _execute_statement(
|
|
70
72
|
self, statement: SQL, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
|
|
71
|
-
) ->
|
|
73
|
+
) -> SQLResult[RowT]:
|
|
72
74
|
if statement.is_script:
|
|
73
75
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
74
76
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
75
77
|
|
|
76
|
-
#
|
|
77
|
-
|
|
78
|
+
# Detect parameter styles in the SQL
|
|
79
|
+
detected_styles = set()
|
|
80
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
81
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
82
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
83
|
+
if param_infos:
|
|
84
|
+
detected_styles = {p.style for p in param_infos}
|
|
85
|
+
|
|
86
|
+
# Determine target style based on what's in the SQL
|
|
87
|
+
target_style = self.default_parameter_style
|
|
88
|
+
|
|
89
|
+
# Check if there are unsupported styles
|
|
90
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
91
|
+
if unsupported_styles:
|
|
92
|
+
# Force conversion to default style
|
|
93
|
+
target_style = self.default_parameter_style
|
|
94
|
+
elif detected_styles:
|
|
95
|
+
# Prefer the first supported style found
|
|
96
|
+
for style in detected_styles:
|
|
97
|
+
if style in self.supported_parameter_styles:
|
|
98
|
+
target_style = style
|
|
99
|
+
break
|
|
100
|
+
|
|
101
|
+
# Compile with the determined style
|
|
102
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
78
103
|
|
|
79
104
|
if statement.is_many:
|
|
80
|
-
# Process parameter list through type coercion
|
|
81
105
|
params = self._process_parameters(params)
|
|
82
106
|
return await self._execute_many(sql, params, connection=connection, **kwargs)
|
|
83
107
|
|
|
84
|
-
# Process parameters through type coercion
|
|
85
108
|
params = self._process_parameters(params)
|
|
86
109
|
return await self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
87
110
|
|
|
88
111
|
async def _execute(
|
|
89
112
|
self, sql: str, parameters: Any, statement: SQL, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
|
|
90
|
-
) ->
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
113
|
+
) -> SQLResult[RowT]:
|
|
114
|
+
# Use provided connection or driver's default connection
|
|
115
|
+
conn = connection if connection is not None else self._connection(None)
|
|
116
|
+
|
|
117
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
118
|
+
# Normalize parameters using consolidated utility
|
|
119
|
+
normalized_params = normalize_parameter_sequence(parameters)
|
|
120
|
+
# AsyncMy doesn't like empty lists/tuples, convert to None
|
|
121
|
+
final_params = (
|
|
122
|
+
normalized_params[0] if normalized_params and len(normalized_params) == 1 else normalized_params
|
|
123
|
+
)
|
|
124
|
+
if not final_params:
|
|
125
|
+
final_params = None
|
|
126
|
+
|
|
127
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
128
|
+
# AsyncMy expects list/tuple parameters or dict for named params
|
|
129
|
+
await cursor.execute(sql, final_params)
|
|
130
|
+
|
|
131
|
+
if self.returns_rows(statement.expression):
|
|
132
|
+
# For SELECT queries, fetch data and return SQLResult
|
|
133
|
+
data = await cursor.fetchall()
|
|
134
|
+
column_names = [desc[0] for desc in cursor.description or []]
|
|
135
|
+
return SQLResult(
|
|
136
|
+
statement=statement,
|
|
137
|
+
data=data,
|
|
138
|
+
column_names=column_names,
|
|
139
|
+
rows_affected=len(data),
|
|
140
|
+
operation_type="SELECT",
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# For DML/DDL queries
|
|
144
|
+
return SQLResult(
|
|
145
|
+
statement=statement,
|
|
146
|
+
data=[],
|
|
147
|
+
rows_affected=cursor.rowcount if cursor.rowcount is not None else -1,
|
|
148
|
+
operation_type=self._determine_operation_type(statement),
|
|
149
|
+
metadata={"status_message": "OK"},
|
|
150
|
+
)
|
|
112
151
|
|
|
113
152
|
async def _execute_many(
|
|
114
153
|
self, sql: str, param_list: Any, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
|
|
115
|
-
) ->
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
154
|
+
) -> SQLResult[RowT]:
|
|
155
|
+
# Use provided connection or driver's default connection
|
|
156
|
+
conn = connection if connection is not None else self._connection(None)
|
|
157
|
+
|
|
158
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
159
|
+
# Normalize parameter list using consolidated utility
|
|
160
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
161
|
+
|
|
162
|
+
params_list: list[Union[list[Any], tuple[Any, ...]]] = []
|
|
163
|
+
if normalized_param_list and isinstance(normalized_param_list, Sequence):
|
|
164
|
+
for param_set in normalized_param_list:
|
|
165
|
+
if isinstance(param_set, (list, tuple)):
|
|
166
|
+
params_list.append(param_set)
|
|
167
|
+
elif param_set is None:
|
|
168
|
+
params_list.append([])
|
|
169
|
+
else:
|
|
170
|
+
params_list.append([param_set])
|
|
171
|
+
|
|
172
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
173
|
+
await cursor.executemany(sql, params_list)
|
|
174
|
+
return SQLResult(
|
|
175
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
176
|
+
data=[],
|
|
177
|
+
rows_affected=cursor.rowcount if cursor.rowcount != -1 else len(params_list),
|
|
178
|
+
operation_type="EXECUTE",
|
|
179
|
+
metadata={"status_message": "OK"},
|
|
180
|
+
)
|
|
136
181
|
|
|
137
182
|
async def _execute_script(
|
|
138
183
|
self, script: str, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
|
|
139
|
-
) ->
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
184
|
+
) -> SQLResult[RowT]:
|
|
185
|
+
# Use provided connection or driver's default connection
|
|
186
|
+
conn = connection if connection is not None else self._connection(None)
|
|
187
|
+
|
|
188
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
189
|
+
# AsyncMy may not support multi-statement scripts without CLIENT_MULTI_STATEMENTS flag
|
|
190
|
+
statements = self._split_script_statements(script)
|
|
191
|
+
statements_executed = 0
|
|
192
|
+
|
|
193
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
194
|
+
for statement_str in statements:
|
|
195
|
+
if statement_str:
|
|
196
|
+
await cursor.execute(statement_str)
|
|
197
|
+
statements_executed += 1
|
|
198
|
+
|
|
199
|
+
return SQLResult(
|
|
200
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
201
|
+
data=[],
|
|
202
|
+
rows_affected=0,
|
|
203
|
+
operation_type="SCRIPT",
|
|
204
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
205
|
+
total_statements=statements_executed,
|
|
206
|
+
successful_statements=statements_executed,
|
|
207
|
+
)
|
|
154
208
|
|
|
155
209
|
async def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
|
|
156
210
|
self._ensure_pyarrow_installed()
|
|
157
211
|
conn = self._connection(None)
|
|
158
|
-
|
|
159
|
-
async with self._get_cursor(conn) as cursor:
|
|
212
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
|
|
160
213
|
if mode == "replace":
|
|
161
214
|
await cursor.execute(f"TRUNCATE TABLE {table_name}")
|
|
162
215
|
elif mode == "create":
|
|
@@ -174,71 +227,6 @@ class AsyncmyDriver(
|
|
|
174
227
|
await cursor.executemany(sql, data_for_ingest)
|
|
175
228
|
return cursor.rowcount if cursor.rowcount is not None else -1
|
|
176
229
|
|
|
177
|
-
async def _wrap_select_result(
|
|
178
|
-
self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
|
|
179
|
-
) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
|
|
180
|
-
data = result["data"]
|
|
181
|
-
column_names = result["column_names"]
|
|
182
|
-
rows_affected = result["rows_affected"]
|
|
183
|
-
|
|
184
|
-
if not data:
|
|
185
|
-
return SQLResult[RowT](
|
|
186
|
-
statement=statement, data=[], column_names=column_names, rows_affected=0, operation_type="SELECT"
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
rows_as_dicts = [dict(zip(column_names, row)) for row in data]
|
|
190
|
-
|
|
191
|
-
if schema_type:
|
|
192
|
-
converted_data = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
|
|
193
|
-
return SQLResult[ModelDTOT](
|
|
194
|
-
statement=statement,
|
|
195
|
-
data=list(converted_data),
|
|
196
|
-
column_names=column_names,
|
|
197
|
-
rows_affected=rows_affected,
|
|
198
|
-
operation_type="SELECT",
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
return SQLResult[RowT](
|
|
202
|
-
statement=statement,
|
|
203
|
-
data=rows_as_dicts,
|
|
204
|
-
column_names=column_names,
|
|
205
|
-
rows_affected=rows_affected,
|
|
206
|
-
operation_type="SELECT",
|
|
207
|
-
)
|
|
208
|
-
|
|
209
|
-
async def _wrap_execute_result(
|
|
210
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
211
|
-
) -> SQLResult[RowT]:
|
|
212
|
-
operation_type = "UNKNOWN"
|
|
213
|
-
if statement.expression:
|
|
214
|
-
operation_type = str(statement.expression.key).upper()
|
|
215
|
-
|
|
216
|
-
# Handle script results
|
|
217
|
-
if "statements_executed" in result:
|
|
218
|
-
script_result = cast("ScriptResultDict", result)
|
|
219
|
-
return SQLResult[RowT](
|
|
220
|
-
statement=statement,
|
|
221
|
-
data=[],
|
|
222
|
-
rows_affected=0,
|
|
223
|
-
operation_type="SCRIPT",
|
|
224
|
-
metadata={
|
|
225
|
-
"status_message": script_result.get("status_message", ""),
|
|
226
|
-
"statements_executed": script_result.get("statements_executed", -1),
|
|
227
|
-
},
|
|
228
|
-
)
|
|
229
|
-
|
|
230
|
-
# Handle DML results
|
|
231
|
-
dml_result = cast("DMLResultDict", result)
|
|
232
|
-
rows_affected = dml_result.get("rows_affected", -1)
|
|
233
|
-
status_message = dml_result.get("status_message", "")
|
|
234
|
-
return SQLResult[RowT](
|
|
235
|
-
statement=statement,
|
|
236
|
-
data=[],
|
|
237
|
-
rows_affected=rows_affected,
|
|
238
|
-
operation_type=operation_type,
|
|
239
|
-
metadata={"status_message": status_message},
|
|
240
|
-
)
|
|
241
|
-
|
|
242
230
|
def _connection(self, connection: Optional["AsyncmyConnection"] = None) -> "AsyncmyConnection":
|
|
243
231
|
"""Get the connection to use for the operation."""
|
|
244
232
|
return connection or self.connection
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
|
|
8
7
|
|
|
9
8
|
from asyncpg import Record
|
|
@@ -224,7 +223,6 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
224
223
|
|
|
225
224
|
super().__init__()
|
|
226
225
|
|
|
227
|
-
# Set pool_instance after super().__init__() to ensure it's not overridden
|
|
228
226
|
if pool_instance_from_kwargs is not None:
|
|
229
227
|
self.pool_instance = pool_instance_from_kwargs
|
|
230
228
|
|
|
@@ -241,7 +239,6 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
241
239
|
if getattr(self, field, None) is not None and getattr(self, field) is not Empty
|
|
242
240
|
}
|
|
243
241
|
|
|
244
|
-
# Add connection-specific extras (not pool-specific ones)
|
|
245
242
|
config.update(self.extras)
|
|
246
243
|
|
|
247
244
|
return config
|
|
@@ -318,15 +315,16 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
318
315
|
An AsyncpgDriver instance.
|
|
319
316
|
"""
|
|
320
317
|
async with self.provide_connection(*args, **kwargs) as connection:
|
|
321
|
-
# Create statement config with parameter style info if not already set
|
|
322
318
|
statement_config = self.statement_config
|
|
319
|
+
# Inject parameter style info if not already set
|
|
323
320
|
if statement_config is not None and statement_config.allowed_parameter_styles is None:
|
|
321
|
+
from dataclasses import replace
|
|
322
|
+
|
|
324
323
|
statement_config = replace(
|
|
325
324
|
statement_config,
|
|
326
325
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
327
326
|
target_parameter_style=self.preferred_parameter_style,
|
|
328
327
|
)
|
|
329
|
-
|
|
330
328
|
yield self.driver_type(connection=connection, config=statement_config)
|
|
331
329
|
|
|
332
330
|
async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
|
|
@@ -348,10 +346,8 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
348
346
|
Returns:
|
|
349
347
|
Dictionary mapping type names to types.
|
|
350
348
|
"""
|
|
351
|
-
# Get base types from parent
|
|
352
349
|
namespace = super().get_signature_namespace()
|
|
353
350
|
|
|
354
|
-
# Add AsyncPG-specific types
|
|
355
351
|
try:
|
|
356
352
|
from asyncpg import Connection, Record
|
|
357
353
|
from asyncpg.connection import ConnectionMeta
|