sqlspec 0.12.2__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +16 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +17 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +17 -29
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +32 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +18 -9
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +44 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/METADATA +1 -1
- sqlspec-0.13.1.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import re
|
|
2
|
-
from collections.abc import Sequence
|
|
3
2
|
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
4
3
|
|
|
5
4
|
from asyncpg import Connection as AsyncpgNativeConnection
|
|
@@ -7,6 +6,7 @@ from asyncpg import Record
|
|
|
7
6
|
from typing_extensions import TypeAlias
|
|
8
7
|
|
|
9
8
|
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
9
|
+
from sqlspec.driver.connection import managed_transaction_async
|
|
10
10
|
from sqlspec.driver.mixins import (
|
|
11
11
|
AsyncPipelinedExecutionMixin,
|
|
12
12
|
AsyncStorageMixin,
|
|
@@ -14,10 +14,11 @@ from sqlspec.driver.mixins import (
|
|
|
14
14
|
ToSchemaMixin,
|
|
15
15
|
TypeCoercionMixin,
|
|
16
16
|
)
|
|
17
|
-
from sqlspec.
|
|
18
|
-
from sqlspec.statement.
|
|
17
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
18
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
19
|
+
from sqlspec.statement.result import SQLResult
|
|
19
20
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
20
|
-
from sqlspec.typing import DictRow,
|
|
21
|
+
from sqlspec.typing import DictRow, RowT
|
|
21
22
|
from sqlspec.utils.logging import get_logger
|
|
22
23
|
|
|
23
24
|
if TYPE_CHECKING:
|
|
@@ -69,12 +70,10 @@ class AsyncpgDriver(
|
|
|
69
70
|
# AsyncPG-specific type coercion overrides (PostgreSQL has rich native types)
|
|
70
71
|
def _coerce_boolean(self, value: Any) -> Any:
|
|
71
72
|
"""AsyncPG/PostgreSQL has native boolean support."""
|
|
72
|
-
# Keep booleans as-is, AsyncPG handles them natively
|
|
73
73
|
return value
|
|
74
74
|
|
|
75
75
|
def _coerce_decimal(self, value: Any) -> Any:
|
|
76
76
|
"""AsyncPG/PostgreSQL has native decimal/numeric support."""
|
|
77
|
-
# Keep decimals as-is, AsyncPG handles them natively
|
|
78
77
|
return value
|
|
79
78
|
|
|
80
79
|
def _coerce_json(self, value: Any) -> Any:
|
|
@@ -84,20 +83,24 @@ class AsyncpgDriver(
|
|
|
84
83
|
|
|
85
84
|
def _coerce_array(self, value: Any) -> Any:
|
|
86
85
|
"""AsyncPG/PostgreSQL has native array support."""
|
|
87
|
-
# Convert tuples to lists for consistency
|
|
88
86
|
if isinstance(value, tuple):
|
|
89
87
|
return list(value)
|
|
90
|
-
# Keep other arrays as-is, AsyncPG handles them natively
|
|
91
88
|
return value
|
|
92
89
|
|
|
93
90
|
async def _execute_statement(
|
|
94
91
|
self, statement: SQL, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
95
|
-
) ->
|
|
92
|
+
) -> SQLResult[RowT]:
|
|
96
93
|
if statement.is_script:
|
|
97
94
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
98
95
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
99
96
|
|
|
100
|
-
detected_styles =
|
|
97
|
+
detected_styles = set()
|
|
98
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
99
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
100
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
101
|
+
if param_infos:
|
|
102
|
+
detected_styles = {p.style for p in param_infos}
|
|
103
|
+
|
|
101
104
|
target_style = self.default_parameter_style
|
|
102
105
|
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
103
106
|
if unsupported_styles:
|
|
@@ -117,143 +120,104 @@ class AsyncpgDriver(
|
|
|
117
120
|
|
|
118
121
|
async def _execute(
|
|
119
122
|
self, sql: str, parameters: Any, statement: SQL, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
120
|
-
) ->
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
parameters = self._process_parameters(parameters)
|
|
123
|
+
) -> SQLResult[RowT]:
|
|
124
|
+
# Use provided connection or driver's default connection
|
|
125
|
+
conn = connection if connection is not None else self._connection(None)
|
|
124
126
|
|
|
125
|
-
# Check if this is actually a many operation that was misrouted
|
|
126
127
|
if statement.is_many:
|
|
127
128
|
# This should have gone to _execute_many, redirect it
|
|
128
129
|
return await self._execute_many(sql, parameters, connection=connection, **kwargs)
|
|
129
130
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
args_for_driver
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
rows_affected = 0
|
|
151
|
-
if status and isinstance(status, str):
|
|
152
|
-
match = ASYNC_PG_STATUS_REGEX.match(status)
|
|
153
|
-
if match and len(match.groups()) >= EXPECTED_REGEX_GROUPS:
|
|
154
|
-
rows_affected = int(match.group(3))
|
|
155
|
-
|
|
156
|
-
dml_result: DMLResultDict = {"rows_affected": rows_affected, "status_message": status or "OK"}
|
|
157
|
-
return dml_result
|
|
131
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
132
|
+
# Normalize parameters using consolidated utility
|
|
133
|
+
normalized_params = normalize_parameter_sequence(parameters)
|
|
134
|
+
# AsyncPG expects parameters as *args, not a single list
|
|
135
|
+
args_for_driver: list[Any] = []
|
|
136
|
+
if normalized_params:
|
|
137
|
+
# normalized_params is already a list, just use it directly
|
|
138
|
+
args_for_driver = normalized_params
|
|
139
|
+
|
|
140
|
+
if self.returns_rows(statement.expression):
|
|
141
|
+
records = await txn_conn.fetch(sql, *args_for_driver)
|
|
142
|
+
data = [dict(record) for record in records]
|
|
143
|
+
column_names = list(records[0].keys()) if records else []
|
|
144
|
+
return SQLResult(
|
|
145
|
+
statement=statement,
|
|
146
|
+
data=cast("list[RowT]", data),
|
|
147
|
+
column_names=column_names,
|
|
148
|
+
rows_affected=len(records),
|
|
149
|
+
operation_type="SELECT",
|
|
150
|
+
)
|
|
158
151
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
params_list: list[tuple[Any, ...]] = []
|
|
167
|
-
rows_affected = 0
|
|
168
|
-
if param_list and isinstance(param_list, Sequence):
|
|
169
|
-
for param_set in param_list:
|
|
170
|
-
if isinstance(param_set, (list, tuple)):
|
|
171
|
-
params_list.append(tuple(param_set))
|
|
172
|
-
elif param_set is None:
|
|
173
|
-
params_list.append(())
|
|
174
|
-
else:
|
|
175
|
-
params_list.append((param_set,))
|
|
176
|
-
|
|
177
|
-
await conn.executemany(sql, params_list)
|
|
178
|
-
# AsyncPG's executemany returns None, not a status string
|
|
179
|
-
# We need to use the number of parameter sets as the row count
|
|
180
|
-
rows_affected = len(params_list)
|
|
181
|
-
|
|
182
|
-
dml_result: DMLResultDict = {"rows_affected": rows_affected, "status_message": "OK"}
|
|
183
|
-
return dml_result
|
|
152
|
+
status = await txn_conn.execute(sql, *args_for_driver)
|
|
153
|
+
# Parse row count from status string
|
|
154
|
+
rows_affected = 0
|
|
155
|
+
if status and isinstance(status, str):
|
|
156
|
+
match = ASYNC_PG_STATUS_REGEX.match(status)
|
|
157
|
+
if match and len(match.groups()) >= EXPECTED_REGEX_GROUPS:
|
|
158
|
+
rows_affected = int(match.group(3))
|
|
184
159
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
) -> ScriptResultDict:
|
|
188
|
-
conn = self._connection(connection)
|
|
189
|
-
status = await conn.execute(script)
|
|
190
|
-
|
|
191
|
-
result: ScriptResultDict = {
|
|
192
|
-
"statements_executed": -1, # AsyncPG doesn't provide statement count
|
|
193
|
-
"status_message": status or "SCRIPT EXECUTED",
|
|
194
|
-
}
|
|
195
|
-
return result
|
|
196
|
-
|
|
197
|
-
async def _wrap_select_result(
|
|
198
|
-
self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
|
|
199
|
-
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
200
|
-
records = cast("list[Record]", result["data"])
|
|
201
|
-
column_names = result["column_names"]
|
|
202
|
-
rows_affected = result["rows_affected"]
|
|
203
|
-
|
|
204
|
-
rows_as_dicts: list[dict[str, Any]] = [dict(record) for record in records]
|
|
205
|
-
|
|
206
|
-
if schema_type:
|
|
207
|
-
converted_data_seq = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
|
|
208
|
-
converted_data_list = list(converted_data_seq) if converted_data_seq is not None else []
|
|
209
|
-
return SQLResult[ModelDTOT](
|
|
160
|
+
operation_type = self._determine_operation_type(statement)
|
|
161
|
+
return SQLResult(
|
|
210
162
|
statement=statement,
|
|
211
|
-
data=
|
|
212
|
-
column_names=column_names,
|
|
163
|
+
data=cast("list[RowT]", []),
|
|
213
164
|
rows_affected=rows_affected,
|
|
214
|
-
operation_type=
|
|
165
|
+
operation_type=operation_type,
|
|
166
|
+
metadata={"status_message": status or "OK"},
|
|
215
167
|
)
|
|
216
168
|
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
data=cast("list[RowT]", rows_as_dicts),
|
|
220
|
-
column_names=column_names,
|
|
221
|
-
rows_affected=rows_affected,
|
|
222
|
-
operation_type="SELECT",
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
async def _wrap_execute_result(
|
|
226
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
169
|
+
async def _execute_many(
|
|
170
|
+
self, sql: str, param_list: Any, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
227
171
|
) -> SQLResult[RowT]:
|
|
228
|
-
|
|
229
|
-
if
|
|
230
|
-
|
|
172
|
+
# Use provided connection or driver's default connection
|
|
173
|
+
conn = connection if connection is not None else self._connection(None)
|
|
174
|
+
|
|
175
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
176
|
+
# Normalize parameter list using consolidated utility
|
|
177
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
178
|
+
|
|
179
|
+
params_list: list[tuple[Any, ...]] = []
|
|
180
|
+
rows_affected = 0
|
|
181
|
+
if normalized_param_list:
|
|
182
|
+
for param_set in normalized_param_list:
|
|
183
|
+
if isinstance(param_set, (list, tuple)):
|
|
184
|
+
params_list.append(tuple(param_set))
|
|
185
|
+
elif param_set is None:
|
|
186
|
+
params_list.append(())
|
|
187
|
+
else:
|
|
188
|
+
params_list.append((param_set,))
|
|
189
|
+
|
|
190
|
+
await txn_conn.executemany(sql, params_list)
|
|
191
|
+
# AsyncPG's executemany returns None, not a status string
|
|
192
|
+
# We need to use the number of parameter sets as the row count
|
|
193
|
+
rows_affected = len(params_list)
|
|
194
|
+
|
|
195
|
+
return SQLResult(
|
|
196
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
197
|
+
data=[],
|
|
198
|
+
rows_affected=rows_affected,
|
|
199
|
+
operation_type="EXECUTE",
|
|
200
|
+
metadata={"status_message": "OK"},
|
|
201
|
+
)
|
|
231
202
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
203
|
+
async def _execute_script(
|
|
204
|
+
self, script: str, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
205
|
+
) -> SQLResult[RowT]:
|
|
206
|
+
# Use provided connection or driver's default connection
|
|
207
|
+
conn = connection if connection is not None else self._connection(None)
|
|
208
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
209
|
+
status = await txn_conn.execute(script)
|
|
210
|
+
|
|
211
|
+
return SQLResult(
|
|
212
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
213
|
+
data=[],
|
|
237
214
|
rows_affected=0,
|
|
238
215
|
operation_type="SCRIPT",
|
|
239
|
-
metadata={
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
},
|
|
216
|
+
metadata={"status_message": status or "SCRIPT EXECUTED"},
|
|
217
|
+
total_statements=1,
|
|
218
|
+
successful_statements=1,
|
|
243
219
|
)
|
|
244
220
|
|
|
245
|
-
# Handle DML results
|
|
246
|
-
rows_affected = cast("int", result.get("rows_affected", -1))
|
|
247
|
-
status_message = result.get("status_message", "")
|
|
248
|
-
|
|
249
|
-
return SQLResult[RowT](
|
|
250
|
-
statement=statement,
|
|
251
|
-
data=cast("list[RowT]", []),
|
|
252
|
-
rows_affected=rows_affected,
|
|
253
|
-
operation_type=operation_type,
|
|
254
|
-
metadata={"status_message": status_message},
|
|
255
|
-
)
|
|
256
|
-
|
|
257
221
|
def _connection(self, connection: Optional[AsyncpgConnection] = None) -> AsyncpgConnection:
|
|
258
222
|
"""Get the connection to use for the operation."""
|
|
259
223
|
return connection or self.connection
|
|
@@ -290,11 +254,9 @@ class AsyncpgDriver(
|
|
|
290
254
|
from sqlspec.exceptions import PipelineExecutionError
|
|
291
255
|
|
|
292
256
|
try:
|
|
293
|
-
# Convert parameters to positional for AsyncPG (requires $1, $2, etc.)
|
|
294
257
|
sql_str = op.sql.to_sql(placeholder_style=ParameterStyle.NUMERIC)
|
|
295
258
|
params = self._convert_to_positional_params(op.sql.parameters)
|
|
296
259
|
|
|
297
|
-
# Apply operation-specific filters
|
|
298
260
|
filtered_sql = self._apply_operation_filters(op.sql, op.filters)
|
|
299
261
|
if filtered_sql != op.sql:
|
|
300
262
|
sql_str = filtered_sql.to_sql(placeholder_style=ParameterStyle.NUMERIC)
|
|
@@ -310,19 +272,18 @@ class AsyncpgDriver(
|
|
|
310
272
|
statement=op.sql,
|
|
311
273
|
data=cast("list[RowT]", []),
|
|
312
274
|
rows_affected=rows_affected,
|
|
313
|
-
operation_type="
|
|
275
|
+
operation_type="EXECUTE",
|
|
314
276
|
metadata={"status_message": status},
|
|
315
277
|
)
|
|
316
278
|
elif op.operation_type == "select":
|
|
317
279
|
# Use fetch for SELECT statements
|
|
318
280
|
rows = await connection.fetch(sql_str, *params)
|
|
319
|
-
# Convert AsyncPG Records to dictionaries
|
|
320
281
|
data = [dict(record) for record in rows] if rows else []
|
|
321
282
|
result = SQLResult[RowT](
|
|
322
283
|
statement=op.sql,
|
|
323
284
|
data=cast("list[RowT]", data),
|
|
324
285
|
rows_affected=len(data),
|
|
325
|
-
operation_type="
|
|
286
|
+
operation_type="SELECT",
|
|
326
287
|
metadata={"column_names": list(rows[0].keys()) if rows else []},
|
|
327
288
|
)
|
|
328
289
|
elif op.operation_type == "execute_script":
|
|
@@ -341,7 +302,7 @@ class AsyncpgDriver(
|
|
|
341
302
|
statement=op.sql,
|
|
342
303
|
data=cast("list[RowT]", []),
|
|
343
304
|
rows_affected=total_affected,
|
|
344
|
-
operation_type="
|
|
305
|
+
operation_type="SCRIPT",
|
|
345
306
|
metadata={"status_message": last_status, "statements_executed": len(script_statements)},
|
|
346
307
|
)
|
|
347
308
|
else:
|
|
@@ -351,18 +312,16 @@ class AsyncpgDriver(
|
|
|
351
312
|
statement=op.sql,
|
|
352
313
|
data=cast("list[RowT]", []),
|
|
353
314
|
rows_affected=rows_affected,
|
|
354
|
-
operation_type="
|
|
315
|
+
operation_type="EXECUTE",
|
|
355
316
|
metadata={"status_message": status},
|
|
356
317
|
)
|
|
357
318
|
|
|
358
|
-
# Add operation context
|
|
359
319
|
result.operation_index = i
|
|
360
320
|
result.pipeline_sql = op.sql
|
|
361
321
|
results.append(result)
|
|
362
322
|
|
|
363
323
|
except Exception as e:
|
|
364
324
|
if options.get("continue_on_error"):
|
|
365
|
-
# Create error result
|
|
366
325
|
error_result = SQLResult[RowT](
|
|
367
326
|
statement=op.sql, error=e, operation_index=i, parameters=op.original_params, data=[]
|
|
368
327
|
)
|
|
@@ -390,7 +349,6 @@ class AsyncpgDriver(
|
|
|
390
349
|
if isinstance(params, dict):
|
|
391
350
|
if not params:
|
|
392
351
|
return ()
|
|
393
|
-
# Convert dict to positional based on $1, $2, etc. order
|
|
394
352
|
# This assumes the SQL was compiled with NUMERIC style
|
|
395
353
|
max_param = 0
|
|
396
354
|
for key in params:
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
import contextlib
|
|
4
4
|
import logging
|
|
5
|
-
from dataclasses import replace
|
|
6
5
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional
|
|
7
6
|
|
|
8
7
|
from google.cloud.bigquery import LoadJobConfig, QueryJobConfig
|
|
@@ -273,13 +272,12 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]):
|
|
|
273
272
|
self.extras = kwargs or {}
|
|
274
273
|
|
|
275
274
|
# Store other config
|
|
276
|
-
self.statement_config = statement_config or SQLConfig()
|
|
275
|
+
self.statement_config = statement_config or SQLConfig(dialect="bigquery")
|
|
277
276
|
self.default_row_type = default_row_type
|
|
278
277
|
self.on_connection_create = on_connection_create
|
|
279
278
|
self.on_job_start = on_job_start
|
|
280
279
|
self.on_job_complete = on_job_complete
|
|
281
280
|
|
|
282
|
-
# Set up default query job config if not provided
|
|
283
281
|
if self.default_query_job_config is None:
|
|
284
282
|
self._setup_default_job_config()
|
|
285
283
|
|
|
@@ -385,15 +383,16 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]):
|
|
|
385
383
|
@contextlib.contextmanager
|
|
386
384
|
def session_manager() -> "Generator[BigQueryDriver, None, None]":
|
|
387
385
|
with self.provide_connection(*args, **kwargs) as connection:
|
|
388
|
-
# Create statement config with parameter style info if not already set
|
|
389
386
|
statement_config = self.statement_config
|
|
387
|
+
# Inject parameter style info if not already set
|
|
390
388
|
if statement_config.allowed_parameter_styles is None:
|
|
389
|
+
from dataclasses import replace
|
|
390
|
+
|
|
391
391
|
statement_config = replace(
|
|
392
392
|
statement_config,
|
|
393
393
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
394
394
|
target_parameter_style=self.preferred_parameter_style,
|
|
395
395
|
)
|
|
396
|
-
|
|
397
396
|
driver = self.driver_type(
|
|
398
397
|
connection=connection,
|
|
399
398
|
config=statement_config,
|