sqlspec 0.12.2__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +16 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +17 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +17 -29
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +32 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +18 -9
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +44 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/METADATA +1 -1
- sqlspec-0.13.1.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
8
7
|
|
|
9
8
|
import asyncmy
|
|
9
|
+
from asyncmy.pool import Pool as AsyncmyPool
|
|
10
10
|
|
|
11
11
|
from sqlspec.adapters.asyncmy.driver import AsyncmyConnection, AsyncmyDriver
|
|
12
12
|
from sqlspec.config import AsyncDatabaseConfig
|
|
@@ -193,7 +193,6 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
|
|
|
193
193
|
if getattr(self, field, None) is not None and getattr(self, field) is not Empty
|
|
194
194
|
}
|
|
195
195
|
|
|
196
|
-
# Add connection-specific extras (not pool-specific ones)
|
|
197
196
|
config.update(self.extras)
|
|
198
197
|
|
|
199
198
|
return config
|
|
@@ -264,15 +263,16 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
|
|
|
264
263
|
An AsyncmyDriver instance.
|
|
265
264
|
"""
|
|
266
265
|
async with self.provide_connection(*args, **kwargs) as connection:
|
|
267
|
-
# Create statement config with parameter style info if not already set
|
|
268
266
|
statement_config = self.statement_config
|
|
267
|
+
# Inject parameter style info if not already set
|
|
269
268
|
if statement_config.allowed_parameter_styles is None:
|
|
269
|
+
from dataclasses import replace
|
|
270
|
+
|
|
270
271
|
statement_config = replace(
|
|
271
272
|
statement_config,
|
|
272
273
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
273
274
|
target_parameter_style=self.preferred_parameter_style,
|
|
274
275
|
)
|
|
275
|
-
|
|
276
276
|
yield self.driver_type(connection=connection, config=statement_config)
|
|
277
277
|
|
|
278
278
|
async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: ignore
|
|
@@ -284,3 +284,16 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
|
|
|
284
284
|
if not self.pool_instance:
|
|
285
285
|
self.pool_instance = await self.create_pool()
|
|
286
286
|
return self.pool_instance
|
|
287
|
+
|
|
288
|
+
def get_signature_namespace(self) -> "dict[str, type[Any]]":
|
|
289
|
+
"""Get the signature namespace for Asyncmy types.
|
|
290
|
+
|
|
291
|
+
This provides all Asyncmy-specific types that Litestar needs to recognize
|
|
292
|
+
to avoid serialization attempts.
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
Dictionary mapping type names to types.
|
|
296
|
+
"""
|
|
297
|
+
namespace = super().get_signature_namespace()
|
|
298
|
+
namespace.update({"AsyncmyConnection": AsyncmyConnection, "AsyncmyPool": AsyncmyPool})
|
|
299
|
+
return namespace
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from collections.abc import AsyncGenerator, Sequence
|
|
3
3
|
from contextlib import asynccontextmanager
|
|
4
|
-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
4
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
|
5
5
|
|
|
6
6
|
from asyncmy import Connection
|
|
7
7
|
from typing_extensions import TypeAlias
|
|
8
8
|
|
|
9
9
|
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
10
|
+
from sqlspec.driver.connection import managed_transaction_async
|
|
10
11
|
from sqlspec.driver.mixins import (
|
|
11
12
|
AsyncPipelinedExecutionMixin,
|
|
12
13
|
AsyncStorageMixin,
|
|
@@ -14,10 +15,11 @@ from sqlspec.driver.mixins import (
|
|
|
14
15
|
ToSchemaMixin,
|
|
15
16
|
TypeCoercionMixin,
|
|
16
17
|
)
|
|
17
|
-
from sqlspec.
|
|
18
|
-
from sqlspec.statement.
|
|
18
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
19
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
20
|
+
from sqlspec.statement.result import SQLResult
|
|
19
21
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
20
|
-
from sqlspec.typing import DictRow,
|
|
22
|
+
from sqlspec.typing import DictRow, RowT
|
|
21
23
|
|
|
22
24
|
if TYPE_CHECKING:
|
|
23
25
|
from asyncmy.cursors import Cursor, DictCursor
|
|
@@ -60,7 +62,7 @@ class AsyncmyDriver(
|
|
|
60
62
|
self, connection: "Optional[AsyncmyConnection]" = None
|
|
61
63
|
) -> "AsyncGenerator[Union[Cursor, DictCursor], None]":
|
|
62
64
|
conn = self._connection(connection)
|
|
63
|
-
cursor =
|
|
65
|
+
cursor = conn.cursor()
|
|
64
66
|
try:
|
|
65
67
|
yield cursor
|
|
66
68
|
finally:
|
|
@@ -68,95 +70,146 @@ class AsyncmyDriver(
|
|
|
68
70
|
|
|
69
71
|
async def _execute_statement(
|
|
70
72
|
self, statement: SQL, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
|
|
71
|
-
) ->
|
|
73
|
+
) -> SQLResult[RowT]:
|
|
72
74
|
if statement.is_script:
|
|
73
75
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
74
76
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
75
77
|
|
|
76
|
-
#
|
|
77
|
-
|
|
78
|
+
# Detect parameter styles in the SQL
|
|
79
|
+
detected_styles = set()
|
|
80
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
81
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
82
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
83
|
+
if param_infos:
|
|
84
|
+
detected_styles = {p.style for p in param_infos}
|
|
85
|
+
|
|
86
|
+
# Determine target style based on what's in the SQL
|
|
87
|
+
target_style = self.default_parameter_style
|
|
88
|
+
|
|
89
|
+
# Check if there are unsupported styles
|
|
90
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
91
|
+
if unsupported_styles:
|
|
92
|
+
# Force conversion to default style
|
|
93
|
+
target_style = self.default_parameter_style
|
|
94
|
+
elif detected_styles:
|
|
95
|
+
# Prefer the first supported style found
|
|
96
|
+
for style in detected_styles:
|
|
97
|
+
if style in self.supported_parameter_styles:
|
|
98
|
+
target_style = style
|
|
99
|
+
break
|
|
100
|
+
|
|
101
|
+
# Compile with the determined style
|
|
102
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
78
103
|
|
|
79
104
|
if statement.is_many:
|
|
80
|
-
# Process parameter list through type coercion
|
|
81
105
|
params = self._process_parameters(params)
|
|
82
106
|
return await self._execute_many(sql, params, connection=connection, **kwargs)
|
|
83
107
|
|
|
84
|
-
# Process parameters through type coercion
|
|
85
108
|
params = self._process_parameters(params)
|
|
86
109
|
return await self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
87
110
|
|
|
88
111
|
async def _execute(
|
|
89
112
|
self, sql: str, parameters: Any, statement: SQL, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
|
|
90
|
-
) ->
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
113
|
+
) -> SQLResult[RowT]:
|
|
114
|
+
# Use provided connection or driver's default connection
|
|
115
|
+
conn = connection if connection is not None else self._connection(None)
|
|
116
|
+
|
|
117
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
118
|
+
# Normalize parameters using consolidated utility
|
|
119
|
+
normalized_params = normalize_parameter_sequence(parameters)
|
|
120
|
+
# AsyncMy doesn't like empty lists/tuples, convert to None
|
|
121
|
+
final_params = (
|
|
122
|
+
normalized_params[0] if normalized_params and len(normalized_params) == 1 else normalized_params
|
|
123
|
+
)
|
|
124
|
+
if not final_params:
|
|
125
|
+
final_params = None
|
|
126
|
+
|
|
127
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
128
|
+
# AsyncMy expects list/tuple parameters or dict for named params
|
|
129
|
+
await cursor.execute(sql, final_params)
|
|
130
|
+
|
|
131
|
+
if self.returns_rows(statement.expression):
|
|
132
|
+
# For SELECT queries, fetch data and return SQLResult
|
|
133
|
+
data = await cursor.fetchall()
|
|
134
|
+
column_names = [desc[0] for desc in cursor.description or []]
|
|
135
|
+
return SQLResult(
|
|
136
|
+
statement=statement,
|
|
137
|
+
data=data,
|
|
138
|
+
column_names=column_names,
|
|
139
|
+
rows_affected=len(data),
|
|
140
|
+
operation_type="SELECT",
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# For DML/DDL queries
|
|
144
|
+
return SQLResult(
|
|
145
|
+
statement=statement,
|
|
146
|
+
data=[],
|
|
147
|
+
rows_affected=cursor.rowcount if cursor.rowcount is not None else -1,
|
|
148
|
+
operation_type=self._determine_operation_type(statement),
|
|
149
|
+
metadata={"status_message": "OK"},
|
|
150
|
+
)
|
|
112
151
|
|
|
113
152
|
async def _execute_many(
|
|
114
153
|
self, sql: str, param_list: Any, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
|
|
115
|
-
) ->
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
154
|
+
) -> SQLResult[RowT]:
|
|
155
|
+
# Use provided connection or driver's default connection
|
|
156
|
+
conn = connection if connection is not None else self._connection(None)
|
|
157
|
+
|
|
158
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
159
|
+
# Normalize parameter list using consolidated utility
|
|
160
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
161
|
+
|
|
162
|
+
params_list: list[Union[list[Any], tuple[Any, ...]]] = []
|
|
163
|
+
if normalized_param_list and isinstance(normalized_param_list, Sequence):
|
|
164
|
+
for param_set in normalized_param_list:
|
|
165
|
+
if isinstance(param_set, (list, tuple)):
|
|
166
|
+
params_list.append(param_set)
|
|
167
|
+
elif param_set is None:
|
|
168
|
+
params_list.append([])
|
|
169
|
+
else:
|
|
170
|
+
params_list.append([param_set])
|
|
171
|
+
|
|
172
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
173
|
+
await cursor.executemany(sql, params_list)
|
|
174
|
+
return SQLResult(
|
|
175
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
176
|
+
data=[],
|
|
177
|
+
rows_affected=cursor.rowcount if cursor.rowcount != -1 else len(params_list),
|
|
178
|
+
operation_type="EXECUTE",
|
|
179
|
+
metadata={"status_message": "OK"},
|
|
180
|
+
)
|
|
136
181
|
|
|
137
182
|
async def _execute_script(
|
|
138
183
|
self, script: str, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
|
|
139
|
-
) ->
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
184
|
+
) -> SQLResult[RowT]:
|
|
185
|
+
# Use provided connection or driver's default connection
|
|
186
|
+
conn = connection if connection is not None else self._connection(None)
|
|
187
|
+
|
|
188
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
189
|
+
# AsyncMy may not support multi-statement scripts without CLIENT_MULTI_STATEMENTS flag
|
|
190
|
+
statements = self._split_script_statements(script)
|
|
191
|
+
statements_executed = 0
|
|
192
|
+
|
|
193
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
194
|
+
for statement_str in statements:
|
|
195
|
+
if statement_str:
|
|
196
|
+
await cursor.execute(statement_str)
|
|
197
|
+
statements_executed += 1
|
|
198
|
+
|
|
199
|
+
return SQLResult(
|
|
200
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
201
|
+
data=[],
|
|
202
|
+
rows_affected=0,
|
|
203
|
+
operation_type="SCRIPT",
|
|
204
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
205
|
+
total_statements=statements_executed,
|
|
206
|
+
successful_statements=statements_executed,
|
|
207
|
+
)
|
|
154
208
|
|
|
155
209
|
async def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
|
|
156
210
|
self._ensure_pyarrow_installed()
|
|
157
211
|
conn = self._connection(None)
|
|
158
|
-
|
|
159
|
-
async with self._get_cursor(conn) as cursor:
|
|
212
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
|
|
160
213
|
if mode == "replace":
|
|
161
214
|
await cursor.execute(f"TRUNCATE TABLE {table_name}")
|
|
162
215
|
elif mode == "create":
|
|
@@ -174,71 +227,6 @@ class AsyncmyDriver(
|
|
|
174
227
|
await cursor.executemany(sql, data_for_ingest)
|
|
175
228
|
return cursor.rowcount if cursor.rowcount is not None else -1
|
|
176
229
|
|
|
177
|
-
async def _wrap_select_result(
|
|
178
|
-
self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
|
|
179
|
-
) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
|
|
180
|
-
data = result["data"]
|
|
181
|
-
column_names = result["column_names"]
|
|
182
|
-
rows_affected = result["rows_affected"]
|
|
183
|
-
|
|
184
|
-
if not data:
|
|
185
|
-
return SQLResult[RowT](
|
|
186
|
-
statement=statement, data=[], column_names=column_names, rows_affected=0, operation_type="SELECT"
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
rows_as_dicts = [dict(zip(column_names, row)) for row in data]
|
|
190
|
-
|
|
191
|
-
if schema_type:
|
|
192
|
-
converted_data = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
|
|
193
|
-
return SQLResult[ModelDTOT](
|
|
194
|
-
statement=statement,
|
|
195
|
-
data=list(converted_data),
|
|
196
|
-
column_names=column_names,
|
|
197
|
-
rows_affected=rows_affected,
|
|
198
|
-
operation_type="SELECT",
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
return SQLResult[RowT](
|
|
202
|
-
statement=statement,
|
|
203
|
-
data=rows_as_dicts,
|
|
204
|
-
column_names=column_names,
|
|
205
|
-
rows_affected=rows_affected,
|
|
206
|
-
operation_type="SELECT",
|
|
207
|
-
)
|
|
208
|
-
|
|
209
|
-
async def _wrap_execute_result(
|
|
210
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
211
|
-
) -> SQLResult[RowT]:
|
|
212
|
-
operation_type = "UNKNOWN"
|
|
213
|
-
if statement.expression:
|
|
214
|
-
operation_type = str(statement.expression.key).upper()
|
|
215
|
-
|
|
216
|
-
# Handle script results
|
|
217
|
-
if "statements_executed" in result:
|
|
218
|
-
script_result = cast("ScriptResultDict", result)
|
|
219
|
-
return SQLResult[RowT](
|
|
220
|
-
statement=statement,
|
|
221
|
-
data=[],
|
|
222
|
-
rows_affected=0,
|
|
223
|
-
operation_type="SCRIPT",
|
|
224
|
-
metadata={
|
|
225
|
-
"status_message": script_result.get("status_message", ""),
|
|
226
|
-
"statements_executed": script_result.get("statements_executed", -1),
|
|
227
|
-
},
|
|
228
|
-
)
|
|
229
|
-
|
|
230
|
-
# Handle DML results
|
|
231
|
-
dml_result = cast("DMLResultDict", result)
|
|
232
|
-
rows_affected = dml_result.get("rows_affected", -1)
|
|
233
|
-
status_message = dml_result.get("status_message", "")
|
|
234
|
-
return SQLResult[RowT](
|
|
235
|
-
statement=statement,
|
|
236
|
-
data=[],
|
|
237
|
-
rows_affected=rows_affected,
|
|
238
|
-
operation_type=operation_type,
|
|
239
|
-
metadata={"status_message": status_message},
|
|
240
|
-
)
|
|
241
|
-
|
|
242
230
|
def _connection(self, connection: Optional["AsyncmyConnection"] = None) -> "AsyncmyConnection":
|
|
243
231
|
"""Get the connection to use for the operation."""
|
|
244
232
|
return connection or self.connection
|
|
@@ -3,11 +3,12 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
|
|
8
7
|
|
|
9
|
-
from asyncpg import Record
|
|
8
|
+
from asyncpg import Connection, Record
|
|
10
9
|
from asyncpg import create_pool as asyncpg_create_pool
|
|
10
|
+
from asyncpg.connection import ConnectionMeta
|
|
11
|
+
from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
|
|
11
12
|
from typing_extensions import NotRequired, Unpack
|
|
12
13
|
|
|
13
14
|
from sqlspec.adapters.asyncpg.driver import AsyncpgConnection, AsyncpgDriver
|
|
@@ -19,7 +20,6 @@ from sqlspec.utils.serializers import from_json, to_json
|
|
|
19
20
|
if TYPE_CHECKING:
|
|
20
21
|
from asyncio.events import AbstractEventLoop
|
|
21
22
|
|
|
22
|
-
from asyncpg.pool import Pool
|
|
23
23
|
from sqlglot.dialects.dialect import DialectType
|
|
24
24
|
|
|
25
25
|
|
|
@@ -224,7 +224,6 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
224
224
|
|
|
225
225
|
super().__init__()
|
|
226
226
|
|
|
227
|
-
# Set pool_instance after super().__init__() to ensure it's not overridden
|
|
228
227
|
if pool_instance_from_kwargs is not None:
|
|
229
228
|
self.pool_instance = pool_instance_from_kwargs
|
|
230
229
|
|
|
@@ -241,7 +240,6 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
241
240
|
if getattr(self, field, None) is not None and getattr(self, field) is not Empty
|
|
242
241
|
}
|
|
243
242
|
|
|
244
|
-
# Add connection-specific extras (not pool-specific ones)
|
|
245
243
|
config.update(self.extras)
|
|
246
244
|
|
|
247
245
|
return config
|
|
@@ -318,15 +316,16 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
318
316
|
An AsyncpgDriver instance.
|
|
319
317
|
"""
|
|
320
318
|
async with self.provide_connection(*args, **kwargs) as connection:
|
|
321
|
-
# Create statement config with parameter style info if not already set
|
|
322
319
|
statement_config = self.statement_config
|
|
320
|
+
# Inject parameter style info if not already set
|
|
323
321
|
if statement_config is not None and statement_config.allowed_parameter_styles is None:
|
|
322
|
+
from dataclasses import replace
|
|
323
|
+
|
|
324
324
|
statement_config = replace(
|
|
325
325
|
statement_config,
|
|
326
326
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
327
327
|
target_parameter_style=self.preferred_parameter_style,
|
|
328
328
|
)
|
|
329
|
-
|
|
330
329
|
yield self.driver_type(connection=connection, config=statement_config)
|
|
331
330
|
|
|
332
331
|
async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
|
|
@@ -348,27 +347,16 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
348
347
|
Returns:
|
|
349
348
|
Dictionary mapping type names to types.
|
|
350
349
|
"""
|
|
351
|
-
# Get base types from parent
|
|
352
350
|
namespace = super().get_signature_namespace()
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
"PoolConnectionProxy": PoolConnectionProxy,
|
|
365
|
-
"PoolConnectionProxyMeta": PoolConnectionProxyMeta,
|
|
366
|
-
"ConnectionMeta": ConnectionMeta,
|
|
367
|
-
"Record": Record,
|
|
368
|
-
"AsyncpgConnection": type(AsyncpgConnection), # The Union type alias
|
|
369
|
-
}
|
|
370
|
-
)
|
|
371
|
-
except ImportError:
|
|
372
|
-
logger.warning("Failed to import AsyncPG types for signature namespace")
|
|
373
|
-
|
|
351
|
+
namespace.update(
|
|
352
|
+
{
|
|
353
|
+
"Connection": Connection,
|
|
354
|
+
"Pool": Pool,
|
|
355
|
+
"PoolConnectionProxy": PoolConnectionProxy,
|
|
356
|
+
"PoolConnectionProxyMeta": PoolConnectionProxyMeta,
|
|
357
|
+
"ConnectionMeta": ConnectionMeta,
|
|
358
|
+
"Record": Record,
|
|
359
|
+
"AsyncpgConnection": type(AsyncpgConnection),
|
|
360
|
+
}
|
|
361
|
+
)
|
|
374
362
|
return namespace
|