sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -644
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -462
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +217 -451
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +418 -498
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +592 -634
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +393 -436
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +549 -942
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -550
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +732 -733
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +243 -426
- sqlspec/base.py +220 -825
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -330
- sqlspec/mixins.py +0 -306
- sqlspec/statement.py +0 -378
- sqlspec-0.11.0.dist-info/RECORD +0 -69
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,541 +1,461 @@
|
|
|
1
|
-
import logging
|
|
2
1
|
import re
|
|
3
|
-
from
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Optional, Union,
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
5
4
|
|
|
6
|
-
from asyncpg import Connection
|
|
7
|
-
from
|
|
5
|
+
from asyncpg import Connection as AsyncpgNativeConnection
|
|
6
|
+
from asyncpg import Record
|
|
8
7
|
from typing_extensions import TypeAlias
|
|
9
8
|
|
|
10
|
-
from sqlspec.
|
|
11
|
-
from sqlspec.mixins import
|
|
12
|
-
|
|
9
|
+
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
10
|
+
from sqlspec.driver.mixins import (
|
|
11
|
+
AsyncPipelinedExecutionMixin,
|
|
12
|
+
AsyncStorageMixin,
|
|
13
|
+
SQLTranslatorMixin,
|
|
14
|
+
ToSchemaMixin,
|
|
15
|
+
TypeCoercionMixin,
|
|
16
|
+
)
|
|
17
|
+
from sqlspec.statement.parameters import ParameterStyle
|
|
18
|
+
from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
|
|
19
|
+
from sqlspec.statement.sql import SQL, SQLConfig
|
|
20
|
+
from sqlspec.typing import DictRow, ModelDTOT, RowT
|
|
21
|
+
from sqlspec.utils.logging import get_logger
|
|
13
22
|
|
|
14
23
|
if TYPE_CHECKING:
|
|
15
|
-
from collections.abc import Sequence
|
|
16
|
-
|
|
17
|
-
from asyncpg import Record
|
|
18
|
-
from asyncpg.connection import Connection
|
|
19
24
|
from asyncpg.pool import PoolConnectionProxy
|
|
20
|
-
|
|
21
|
-
from sqlspec.filters import StatementFilter
|
|
22
|
-
from sqlspec.typing import ModelDTOT, StatementParameterType, T
|
|
25
|
+
from sqlglot.dialects.dialect import DialectType
|
|
23
26
|
|
|
24
27
|
__all__ = ("AsyncpgConnection", "AsyncpgDriver")
|
|
25
28
|
|
|
26
|
-
logger =
|
|
29
|
+
logger = get_logger("adapters.asyncpg")
|
|
27
30
|
|
|
28
31
|
if TYPE_CHECKING:
|
|
29
|
-
AsyncpgConnection: TypeAlias = Union[
|
|
32
|
+
AsyncpgConnection: TypeAlias = Union[AsyncpgNativeConnection[Record], PoolConnectionProxy[Record]]
|
|
30
33
|
else:
|
|
31
|
-
AsyncpgConnection: TypeAlias =
|
|
32
|
-
|
|
33
|
-
#
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
#
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
#
|
|
40
|
-
|
|
41
|
-
# 4. Multi-line comments (/* to */)
|
|
42
|
-
# 5. Only question marks outside of these contexts are considered parameters
|
|
43
|
-
QUESTION_MARK_PATTERN = re.compile(
|
|
44
|
-
r"""
|
|
45
|
-
(?:'[^']*(?:''[^']*)*') | # Skip single-quoted strings (with '' escapes)
|
|
46
|
-
(?:"[^"]*(?:""[^"]*)*") | # Skip double-quoted strings (with "" escapes)
|
|
47
|
-
(?:--.*?(?:\n|$)) | # Skip single-line comments
|
|
48
|
-
(?:/\*(?:[^*]|\*(?!/))*\*/) | # Skip multi-line comments
|
|
49
|
-
(\?) # Capture only question marks outside of these contexts
|
|
50
|
-
""",
|
|
51
|
-
re.VERBOSE | re.DOTALL,
|
|
52
|
-
)
|
|
34
|
+
AsyncpgConnection: TypeAlias = Union[AsyncpgNativeConnection, Any]
|
|
35
|
+
|
|
36
|
+
# Compiled regex to parse asyncpg status messages like "INSERT 0 1" or "UPDATE 1"
|
|
37
|
+
# Group 1: Command Tag (e.g., INSERT, UPDATE)
|
|
38
|
+
# Group 2: (Optional) OID count for INSERT (we ignore this)
|
|
39
|
+
# Group 3: Rows affected
|
|
40
|
+
ASYNC_PG_STATUS_REGEX = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE)
|
|
41
|
+
|
|
42
|
+
# Expected number of groups in the regex match for row count extraction
|
|
43
|
+
EXPECTED_REGEX_GROUPS = 3
|
|
53
44
|
|
|
54
45
|
|
|
55
46
|
class AsyncpgDriver(
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
47
|
+
AsyncDriverAdapterProtocol[AsyncpgConnection, RowT],
|
|
48
|
+
SQLTranslatorMixin,
|
|
49
|
+
TypeCoercionMixin,
|
|
50
|
+
AsyncStorageMixin,
|
|
51
|
+
AsyncPipelinedExecutionMixin,
|
|
52
|
+
ToSchemaMixin,
|
|
59
53
|
):
|
|
60
|
-
"""AsyncPG
|
|
61
|
-
|
|
62
|
-
connection: "AsyncpgConnection"
|
|
63
|
-
dialect: str = "postgres"
|
|
54
|
+
"""AsyncPG PostgreSQL Driver Adapter. Modern protocol implementation."""
|
|
64
55
|
|
|
65
|
-
|
|
66
|
-
|
|
56
|
+
dialect: "DialectType" = "postgres"
|
|
57
|
+
supported_parameter_styles: "tuple[ParameterStyle, ...]" = (ParameterStyle.NUMERIC,)
|
|
58
|
+
default_parameter_style: ParameterStyle = ParameterStyle.NUMERIC
|
|
59
|
+
__slots__ = ()
|
|
67
60
|
|
|
68
|
-
def
|
|
61
|
+
def __init__(
|
|
69
62
|
self,
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
#
|
|
95
|
-
|
|
63
|
+
connection: "AsyncpgConnection",
|
|
64
|
+
config: "Optional[SQLConfig]" = None,
|
|
65
|
+
default_row_type: "type[DictRow]" = dict[str, Any],
|
|
66
|
+
) -> None:
|
|
67
|
+
super().__init__(connection=connection, config=config, default_row_type=default_row_type)
|
|
68
|
+
|
|
69
|
+
# AsyncPG-specific type coercion overrides (PostgreSQL has rich native types)
|
|
70
|
+
def _coerce_boolean(self, value: Any) -> Any:
|
|
71
|
+
"""AsyncPG/PostgreSQL has native boolean support."""
|
|
72
|
+
# Keep booleans as-is, AsyncPG handles them natively
|
|
73
|
+
return value
|
|
74
|
+
|
|
75
|
+
def _coerce_decimal(self, value: Any) -> Any:
|
|
76
|
+
"""AsyncPG/PostgreSQL has native decimal/numeric support."""
|
|
77
|
+
# Keep decimals as-is, AsyncPG handles them natively
|
|
78
|
+
return value
|
|
79
|
+
|
|
80
|
+
def _coerce_json(self, value: Any) -> Any:
|
|
81
|
+
"""AsyncPG/PostgreSQL has native JSON/JSONB support."""
|
|
82
|
+
# AsyncPG can handle dict/list directly for JSON columns
|
|
83
|
+
return value
|
|
84
|
+
|
|
85
|
+
def _coerce_array(self, value: Any) -> Any:
|
|
86
|
+
"""AsyncPG/PostgreSQL has native array support."""
|
|
87
|
+
# Convert tuples to lists for consistency
|
|
88
|
+
if isinstance(value, tuple):
|
|
89
|
+
return list(value)
|
|
90
|
+
# Keep other arrays as-is, AsyncPG handles them natively
|
|
91
|
+
return value
|
|
92
|
+
|
|
93
|
+
async def _execute_statement(
|
|
94
|
+
self, statement: SQL, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
95
|
+
) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
|
|
96
|
+
if statement.is_script:
|
|
97
|
+
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
98
|
+
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
99
|
+
|
|
100
|
+
detected_styles = {p.style for p in statement.parameter_info}
|
|
101
|
+
target_style = self.default_parameter_style
|
|
102
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
103
|
+
if unsupported_styles:
|
|
104
|
+
target_style = self.default_parameter_style
|
|
105
|
+
elif detected_styles:
|
|
106
|
+
for style in detected_styles:
|
|
107
|
+
if style in self.supported_parameter_styles:
|
|
108
|
+
target_style = style
|
|
109
|
+
break
|
|
110
|
+
|
|
111
|
+
if statement.is_many:
|
|
112
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
113
|
+
return await self._execute_many(sql, params, connection=connection, **kwargs)
|
|
114
|
+
|
|
115
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
116
|
+
return await self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
117
|
+
|
|
118
|
+
async def _execute(
|
|
119
|
+
self, sql: str, parameters: Any, statement: SQL, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
120
|
+
) -> Union[SelectResultDict, DMLResultDict]:
|
|
121
|
+
conn = self._connection(connection)
|
|
122
|
+
# Process parameters to handle TypedParameter objects
|
|
123
|
+
parameters = self._process_parameters(parameters)
|
|
124
|
+
|
|
125
|
+
# Check if this is actually a many operation that was misrouted
|
|
126
|
+
if statement.is_many:
|
|
127
|
+
# This should have gone to _execute_many, redirect it
|
|
128
|
+
return await self._execute_many(sql, parameters, connection=connection, **kwargs)
|
|
129
|
+
|
|
130
|
+
# AsyncPG expects parameters as *args, not a single list
|
|
131
|
+
args_for_driver: list[Any] = []
|
|
132
|
+
|
|
133
|
+
if parameters is not None:
|
|
134
|
+
if isinstance(parameters, (list, tuple)):
|
|
135
|
+
args_for_driver.extend(parameters)
|
|
136
|
+
else:
|
|
137
|
+
args_for_driver.append(parameters)
|
|
138
|
+
|
|
139
|
+
if self.returns_rows(statement.expression):
|
|
140
|
+
records = await conn.fetch(sql, *args_for_driver)
|
|
141
|
+
# Convert asyncpg Records to dicts
|
|
142
|
+
data = [dict(record) for record in records]
|
|
143
|
+
# Get column names from first record or empty list
|
|
144
|
+
column_names = list(records[0].keys()) if records else []
|
|
145
|
+
result: SelectResultDict = {"data": data, "column_names": column_names, "rows_affected": len(records)}
|
|
146
|
+
return result
|
|
96
147
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
async def
|
|
175
|
-
self,
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
148
|
+
status = await conn.execute(sql, *args_for_driver)
|
|
149
|
+
# Parse row count from status string
|
|
150
|
+
rows_affected = 0
|
|
151
|
+
if status and isinstance(status, str):
|
|
152
|
+
match = ASYNC_PG_STATUS_REGEX.match(status)
|
|
153
|
+
if match and len(match.groups()) >= EXPECTED_REGEX_GROUPS:
|
|
154
|
+
rows_affected = int(match.group(3))
|
|
155
|
+
|
|
156
|
+
dml_result: DMLResultDict = {"rows_affected": rows_affected, "status_message": status or "OK"}
|
|
157
|
+
return dml_result
|
|
158
|
+
|
|
159
|
+
async def _execute_many(
|
|
160
|
+
self, sql: str, param_list: Any, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
161
|
+
) -> DMLResultDict:
|
|
162
|
+
conn = self._connection(connection)
|
|
163
|
+
# Process parameters to handle TypedParameter objects
|
|
164
|
+
param_list = self._process_parameters(param_list)
|
|
165
|
+
|
|
166
|
+
params_list: list[tuple[Any, ...]] = []
|
|
167
|
+
rows_affected = 0
|
|
168
|
+
if param_list and isinstance(param_list, Sequence):
|
|
169
|
+
for param_set in param_list:
|
|
170
|
+
if isinstance(param_set, (list, tuple)):
|
|
171
|
+
params_list.append(tuple(param_set))
|
|
172
|
+
elif param_set is None:
|
|
173
|
+
params_list.append(())
|
|
174
|
+
else:
|
|
175
|
+
params_list.append((param_set,))
|
|
176
|
+
|
|
177
|
+
await conn.executemany(sql, params_list)
|
|
178
|
+
# AsyncPG's executemany returns None, not a status string
|
|
179
|
+
# We need to use the number of parameter sets as the row count
|
|
180
|
+
rows_affected = len(params_list)
|
|
181
|
+
|
|
182
|
+
dml_result: DMLResultDict = {"rows_affected": rows_affected, "status_message": "OK"}
|
|
183
|
+
return dml_result
|
|
184
|
+
|
|
185
|
+
async def _execute_script(
|
|
186
|
+
self, script: str, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
187
|
+
) -> ScriptResultDict:
|
|
188
|
+
conn = self._connection(connection)
|
|
189
|
+
status = await conn.execute(script)
|
|
190
|
+
|
|
191
|
+
result: ScriptResultDict = {
|
|
192
|
+
"statements_executed": -1, # AsyncPG doesn't provide statement count
|
|
193
|
+
"status_message": status or "SCRIPT EXECUTED",
|
|
194
|
+
}
|
|
195
|
+
return result
|
|
196
|
+
|
|
197
|
+
async def _wrap_select_result(
|
|
198
|
+
self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
|
|
199
|
+
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
200
|
+
records = cast("list[Record]", result["data"])
|
|
201
|
+
column_names = result["column_names"]
|
|
202
|
+
rows_affected = result["rows_affected"]
|
|
203
|
+
|
|
204
|
+
rows_as_dicts: list[dict[str, Any]] = [dict(record) for record in records]
|
|
205
|
+
|
|
206
|
+
if schema_type:
|
|
207
|
+
converted_data_seq = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
|
|
208
|
+
converted_data_list = list(converted_data_seq) if converted_data_seq is not None else []
|
|
209
|
+
return SQLResult[ModelDTOT](
|
|
210
|
+
statement=statement,
|
|
211
|
+
data=converted_data_list,
|
|
212
|
+
column_names=column_names,
|
|
213
|
+
rows_affected=rows_affected,
|
|
214
|
+
operation_type="SELECT",
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
return SQLResult[RowT](
|
|
218
|
+
statement=statement,
|
|
219
|
+
data=cast("list[RowT]", rows_as_dicts),
|
|
220
|
+
column_names=column_names,
|
|
221
|
+
rows_affected=rows_affected,
|
|
222
|
+
operation_type="SELECT",
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
async def _wrap_execute_result(
|
|
226
|
+
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
227
|
+
) -> SQLResult[RowT]:
|
|
228
|
+
operation_type = "UNKNOWN"
|
|
229
|
+
if statement.expression:
|
|
230
|
+
operation_type = str(statement.expression.key).upper()
|
|
231
|
+
|
|
232
|
+
# Handle script results
|
|
233
|
+
if "statements_executed" in result:
|
|
234
|
+
return SQLResult[RowT](
|
|
235
|
+
statement=statement,
|
|
236
|
+
data=cast("list[RowT]", []),
|
|
237
|
+
rows_affected=0,
|
|
238
|
+
operation_type="SCRIPT",
|
|
239
|
+
metadata={
|
|
240
|
+
"status_message": result.get("status_message", ""),
|
|
241
|
+
"statements_executed": result.get("statements_executed", -1),
|
|
242
|
+
},
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
# Handle DML results
|
|
246
|
+
rows_affected = cast("int", result.get("rows_affected", -1))
|
|
247
|
+
status_message = result.get("status_message", "")
|
|
248
|
+
|
|
249
|
+
return SQLResult[RowT](
|
|
250
|
+
statement=statement,
|
|
251
|
+
data=cast("list[RowT]", []),
|
|
252
|
+
rows_affected=rows_affected,
|
|
253
|
+
operation_type=operation_type,
|
|
254
|
+
metadata={"status_message": status_message},
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
def _connection(self, connection: Optional[AsyncpgConnection] = None) -> AsyncpgConnection:
|
|
258
|
+
"""Get the connection to use for the operation."""
|
|
259
|
+
return connection or self.connection
|
|
260
|
+
|
|
261
|
+
async def _execute_pipeline_native(self, operations: "list[Any]", **options: Any) -> "list[SQLResult[RowT]]":
|
|
262
|
+
"""Native pipeline execution using AsyncPG's efficient batch handling.
|
|
263
|
+
|
|
264
|
+
Note: AsyncPG doesn't have explicit pipeline support like Psycopg, but we can
|
|
265
|
+
achieve similar performance benefits through careful batching and transaction
|
|
266
|
+
management.
|
|
195
267
|
|
|
196
268
|
Args:
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
parameters: Query parameters.
|
|
200
|
-
connection: Optional connection to use.
|
|
201
|
-
schema_type: Optional schema class for the result.
|
|
202
|
-
**kwargs: Additional keyword arguments.
|
|
269
|
+
operations: List of PipelineOperation objects
|
|
270
|
+
**options: Pipeline configuration options
|
|
203
271
|
|
|
204
272
|
Returns:
|
|
205
|
-
List of
|
|
273
|
+
List of SQLResult objects from all operations
|
|
206
274
|
"""
|
|
207
|
-
connection = self._connection(connection)
|
|
208
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
209
|
-
parameters = parameters if parameters is not None else ()
|
|
210
275
|
|
|
211
|
-
results =
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
276
|
+
results: list[Any] = []
|
|
277
|
+
connection = self._connection()
|
|
278
|
+
|
|
279
|
+
# Use a single transaction for all operations
|
|
280
|
+
async with connection.transaction():
|
|
281
|
+
for i, op in enumerate(operations):
|
|
282
|
+
await self._execute_pipeline_operation(connection, i, op, options, results)
|
|
283
|
+
|
|
284
|
+
return results
|
|
285
|
+
|
|
286
|
+
async def _execute_pipeline_operation(
|
|
287
|
+
self, connection: Any, i: int, op: Any, options: dict[str, Any], results: list[Any]
|
|
288
|
+
) -> None:
|
|
289
|
+
"""Execute a single pipeline operation with error handling."""
|
|
290
|
+
from sqlspec.exceptions import PipelineExecutionError
|
|
291
|
+
|
|
292
|
+
try:
|
|
293
|
+
# Convert parameters to positional for AsyncPG (requires $1, $2, etc.)
|
|
294
|
+
sql_str = op.sql.to_sql(placeholder_style=ParameterStyle.NUMERIC)
|
|
295
|
+
params = self._convert_to_positional_params(op.sql.parameters)
|
|
296
|
+
|
|
297
|
+
# Apply operation-specific filters
|
|
298
|
+
filtered_sql = self._apply_operation_filters(op.sql, op.filters)
|
|
299
|
+
if filtered_sql != op.sql:
|
|
300
|
+
sql_str = filtered_sql.to_sql(placeholder_style=ParameterStyle.NUMERIC)
|
|
301
|
+
params = self._convert_to_positional_params(filtered_sql.parameters)
|
|
302
|
+
|
|
303
|
+
# Execute based on operation type
|
|
304
|
+
if op.operation_type == "execute_many":
|
|
305
|
+
# AsyncPG has native executemany support
|
|
306
|
+
status = await connection.executemany(sql_str, params)
|
|
307
|
+
# Parse row count from status (e.g., "INSERT 0 5")
|
|
308
|
+
rows_affected = self._parse_asyncpg_status(status)
|
|
309
|
+
result = SQLResult[RowT](
|
|
310
|
+
statement=op.sql,
|
|
311
|
+
data=cast("list[RowT]", []),
|
|
312
|
+
rows_affected=rows_affected,
|
|
313
|
+
operation_type="execute_many",
|
|
314
|
+
metadata={"status_message": status},
|
|
315
|
+
)
|
|
316
|
+
elif op.operation_type == "select":
|
|
317
|
+
# Use fetch for SELECT statements
|
|
318
|
+
rows = await connection.fetch(sql_str, *params)
|
|
319
|
+
# Convert AsyncPG Records to dictionaries
|
|
320
|
+
data = [dict(record) for record in rows] if rows else []
|
|
321
|
+
result = SQLResult[RowT](
|
|
322
|
+
statement=op.sql,
|
|
323
|
+
data=cast("list[RowT]", data),
|
|
324
|
+
rows_affected=len(data),
|
|
325
|
+
operation_type="select",
|
|
326
|
+
metadata={"column_names": list(rows[0].keys()) if rows else []},
|
|
327
|
+
)
|
|
328
|
+
elif op.operation_type == "execute_script":
|
|
329
|
+
# For scripts, split and execute each statement
|
|
330
|
+
script_statements = self._split_script_statements(op.sql.to_sql())
|
|
331
|
+
total_affected = 0
|
|
332
|
+
last_status = ""
|
|
333
|
+
|
|
334
|
+
for stmt in script_statements:
|
|
335
|
+
if stmt.strip():
|
|
336
|
+
status = await connection.execute(stmt)
|
|
337
|
+
total_affected += self._parse_asyncpg_status(status)
|
|
338
|
+
last_status = status
|
|
339
|
+
|
|
340
|
+
result = SQLResult[RowT](
|
|
341
|
+
statement=op.sql,
|
|
342
|
+
data=cast("list[RowT]", []),
|
|
343
|
+
rows_affected=total_affected,
|
|
344
|
+
operation_type="execute_script",
|
|
345
|
+
metadata={"status_message": last_status, "statements_executed": len(script_statements)},
|
|
346
|
+
)
|
|
347
|
+
else:
|
|
348
|
+
status = await connection.execute(sql_str, *params)
|
|
349
|
+
rows_affected = self._parse_asyncpg_status(status)
|
|
350
|
+
result = SQLResult[RowT](
|
|
351
|
+
statement=op.sql,
|
|
352
|
+
data=cast("list[RowT]", []),
|
|
353
|
+
rows_affected=rows_affected,
|
|
354
|
+
operation_type="execute",
|
|
355
|
+
metadata={"status_message": status},
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
# Add operation context
|
|
359
|
+
result.operation_index = i
|
|
360
|
+
result.pipeline_sql = op.sql
|
|
361
|
+
results.append(result)
|
|
362
|
+
|
|
363
|
+
except Exception as e:
|
|
364
|
+
if options.get("continue_on_error"):
|
|
365
|
+
# Create error result
|
|
366
|
+
error_result = SQLResult[RowT](
|
|
367
|
+
statement=op.sql, error=e, operation_index=i, parameters=op.original_params, data=[]
|
|
368
|
+
)
|
|
369
|
+
results.append(error_result)
|
|
370
|
+
else:
|
|
371
|
+
# Transaction will be rolled back automatically
|
|
372
|
+
msg = f"AsyncPG pipeline failed at operation {i}: {e}"
|
|
373
|
+
raise PipelineExecutionError(
|
|
374
|
+
msg, operation_index=i, partial_results=results, failed_operation=op
|
|
375
|
+
) from e
|
|
376
|
+
|
|
377
|
+
def _convert_to_positional_params(self, params: Any) -> "tuple[Any, ...]":
|
|
378
|
+
"""Convert parameters to positional format for AsyncPG.
|
|
379
|
+
|
|
380
|
+
AsyncPG requires parameters as positional arguments for $1, $2, etc.
|
|
249
381
|
|
|
250
382
|
Args:
|
|
251
|
-
|
|
252
|
-
sql: SQL statement.
|
|
253
|
-
parameters: Query parameters.
|
|
254
|
-
connection: Optional connection to use.
|
|
255
|
-
schema_type: Optional schema class for the result.
|
|
256
|
-
**kwargs: Additional keyword arguments.
|
|
383
|
+
params: Parameters in various formats
|
|
257
384
|
|
|
258
385
|
Returns:
|
|
259
|
-
|
|
386
|
+
Tuple of positional parameters
|
|
260
387
|
"""
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
"""Fetch one row from the database.
|
|
301
|
-
|
|
302
|
-
Args:
|
|
303
|
-
*filters: Statement filters to apply.
|
|
304
|
-
sql: SQL statement.
|
|
305
|
-
parameters: Query parameters.
|
|
306
|
-
connection: Optional connection to use.
|
|
307
|
-
schema_type: Optional schema class for the result.
|
|
308
|
-
**kwargs: Additional keyword arguments.
|
|
309
|
-
|
|
310
|
-
Returns:
|
|
311
|
-
The first row of the query results.
|
|
312
|
-
"""
|
|
313
|
-
connection = self._connection(connection)
|
|
314
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
315
|
-
parameters = parameters if parameters is not None else ()
|
|
316
|
-
result = await connection.fetchrow(sql, *parameters) # pyright: ignore
|
|
317
|
-
if result is None:
|
|
318
|
-
return None
|
|
319
|
-
return self.to_schema(dict(result.items()), schema_type=schema_type)
|
|
320
|
-
|
|
321
|
-
@overload
|
|
322
|
-
async def select_value(
|
|
323
|
-
self,
|
|
324
|
-
sql: str,
|
|
325
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
326
|
-
/,
|
|
327
|
-
*filters: "StatementFilter",
|
|
328
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
329
|
-
schema_type: None = None,
|
|
330
|
-
**kwargs: Any,
|
|
331
|
-
) -> "Any": ...
|
|
332
|
-
@overload
|
|
333
|
-
async def select_value(
|
|
334
|
-
self,
|
|
335
|
-
sql: str,
|
|
336
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
337
|
-
/,
|
|
338
|
-
*filters: "StatementFilter",
|
|
339
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
340
|
-
schema_type: "type[T]",
|
|
341
|
-
**kwargs: Any,
|
|
342
|
-
) -> "T": ...
|
|
343
|
-
async def select_value(
|
|
344
|
-
self,
|
|
345
|
-
sql: str,
|
|
346
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
347
|
-
/,
|
|
348
|
-
*filters: "StatementFilter",
|
|
349
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
350
|
-
schema_type: "Optional[type[T]]" = None,
|
|
351
|
-
**kwargs: Any,
|
|
352
|
-
) -> "Union[T, Any]":
|
|
353
|
-
"""Fetch a single value from the database.
|
|
388
|
+
if params is None:
|
|
389
|
+
return ()
|
|
390
|
+
if isinstance(params, dict):
|
|
391
|
+
if not params:
|
|
392
|
+
return ()
|
|
393
|
+
# Convert dict to positional based on $1, $2, etc. order
|
|
394
|
+
# This assumes the SQL was compiled with NUMERIC style
|
|
395
|
+
max_param = 0
|
|
396
|
+
for key in params:
|
|
397
|
+
if isinstance(key, str) and key.startswith("param_"):
|
|
398
|
+
try:
|
|
399
|
+
param_num = int(key[6:]) # Extract number from "param_N"
|
|
400
|
+
max_param = max(max_param, param_num)
|
|
401
|
+
except ValueError:
|
|
402
|
+
continue
|
|
403
|
+
|
|
404
|
+
if max_param > 0:
|
|
405
|
+
# Rebuild positional args from param_0, param_1, etc.
|
|
406
|
+
positional = []
|
|
407
|
+
for i in range(max_param + 1):
|
|
408
|
+
param_key = f"param_{i}"
|
|
409
|
+
if param_key in params:
|
|
410
|
+
positional.append(params[param_key])
|
|
411
|
+
return tuple(positional)
|
|
412
|
+
# Fall back to dict values in arbitrary order
|
|
413
|
+
return tuple(params.values())
|
|
414
|
+
if isinstance(params, (list, tuple)):
|
|
415
|
+
return tuple(params)
|
|
416
|
+
return (params,)
|
|
417
|
+
|
|
418
|
+
def _apply_operation_filters(self, sql: "SQL", filters: "list[Any]") -> "SQL":
|
|
419
|
+
"""Apply filters to a SQL object for pipeline operations."""
|
|
420
|
+
if not filters:
|
|
421
|
+
return sql
|
|
422
|
+
|
|
423
|
+
result_sql = sql
|
|
424
|
+
for filter_obj in filters:
|
|
425
|
+
if hasattr(filter_obj, "apply"):
|
|
426
|
+
result_sql = filter_obj.apply(result_sql)
|
|
354
427
|
|
|
355
|
-
|
|
356
|
-
*filters: Statement filters to apply.
|
|
357
|
-
sql: SQL statement.
|
|
358
|
-
parameters: Query parameters.
|
|
359
|
-
connection: Optional connection to use.
|
|
360
|
-
schema_type: Optional schema class for the result.
|
|
361
|
-
**kwargs: Additional keyword arguments.
|
|
428
|
+
return result_sql
|
|
362
429
|
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
parameters = parameters if parameters is not None else ()
|
|
369
|
-
result = await connection.fetchval(sql, *parameters) # pyright: ignore
|
|
370
|
-
result = self.check_not_found(result)
|
|
371
|
-
if schema_type is None:
|
|
372
|
-
return result
|
|
373
|
-
return schema_type(result) # type: ignore[call-arg]
|
|
430
|
+
def _split_script_statements(self, script: str, strip_trailing_semicolon: bool = False) -> "list[str]":
|
|
431
|
+
"""Split a SQL script into individual statements."""
|
|
432
|
+
# Simple splitting on semicolon - could be enhanced with proper SQL parsing
|
|
433
|
+
statements = [stmt.strip() for stmt in script.split(";")]
|
|
434
|
+
return [stmt for stmt in statements if stmt]
|
|
374
435
|
|
|
375
|
-
@
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
sql: str,
|
|
379
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
380
|
-
/,
|
|
381
|
-
*filters: "StatementFilter",
|
|
382
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
383
|
-
schema_type: None = None,
|
|
384
|
-
**kwargs: Any,
|
|
385
|
-
) -> "Optional[Any]": ...
|
|
386
|
-
@overload
|
|
387
|
-
async def select_value_or_none(
|
|
388
|
-
self,
|
|
389
|
-
sql: str,
|
|
390
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
391
|
-
/,
|
|
392
|
-
*filters: "StatementFilter",
|
|
393
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
394
|
-
schema_type: "type[T]",
|
|
395
|
-
**kwargs: Any,
|
|
396
|
-
) -> "Optional[T]": ...
|
|
397
|
-
async def select_value_or_none(
|
|
398
|
-
self,
|
|
399
|
-
sql: str,
|
|
400
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
401
|
-
/,
|
|
402
|
-
*filters: "StatementFilter",
|
|
403
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
404
|
-
schema_type: "Optional[type[T]]" = None,
|
|
405
|
-
**kwargs: Any,
|
|
406
|
-
) -> "Optional[Union[T, Any]]":
|
|
407
|
-
"""Fetch a single value from the database.
|
|
436
|
+
@staticmethod
|
|
437
|
+
def _parse_asyncpg_status(status: str) -> int:
|
|
438
|
+
"""Parse AsyncPG status string to extract row count.
|
|
408
439
|
|
|
409
440
|
Args:
|
|
410
|
-
|
|
411
|
-
sql: SQL statement.
|
|
412
|
-
parameters: Query parameters.
|
|
413
|
-
connection: Optional connection to use.
|
|
414
|
-
schema_type: Optional schema class for the result.
|
|
415
|
-
**kwargs: Additional keyword arguments.
|
|
441
|
+
status: Status string like "INSERT 0 1", "UPDATE 3", "DELETE 2"
|
|
416
442
|
|
|
417
443
|
Returns:
|
|
418
|
-
|
|
444
|
+
Number of affected rows, or 0 if cannot parse
|
|
419
445
|
"""
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
parameters = parameters if parameters is not None else ()
|
|
423
|
-
result = await connection.fetchval(sql, *parameters) # pyright: ignore
|
|
424
|
-
if result is None:
|
|
425
|
-
return None
|
|
426
|
-
if schema_type is None:
|
|
427
|
-
return result
|
|
428
|
-
return schema_type(result) # type: ignore[call-arg]
|
|
446
|
+
if not status:
|
|
447
|
+
return 0
|
|
429
448
|
|
|
430
|
-
|
|
431
|
-
self,
|
|
432
|
-
sql: str,
|
|
433
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
434
|
-
/,
|
|
435
|
-
*filters: "StatementFilter",
|
|
436
|
-
connection: Optional["AsyncpgConnection"] = None,
|
|
437
|
-
**kwargs: Any,
|
|
438
|
-
) -> int:
|
|
439
|
-
"""Insert, update, or delete data from the database.
|
|
440
|
-
|
|
441
|
-
Args:
|
|
442
|
-
*filters: Statement filters to apply.
|
|
443
|
-
sql: SQL statement.
|
|
444
|
-
parameters: Query parameters.
|
|
445
|
-
connection: Optional connection to use.
|
|
446
|
-
**kwargs: Additional keyword arguments.
|
|
447
|
-
|
|
448
|
-
Returns:
|
|
449
|
-
Row count affected by the operation.
|
|
450
|
-
"""
|
|
451
|
-
connection = self._connection(connection)
|
|
452
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
453
|
-
parameters = parameters if parameters is not None else ()
|
|
454
|
-
result = await connection.execute(sql, *parameters) # pyright: ignore
|
|
455
|
-
# asyncpg returns e.g. 'INSERT 0 1', 'UPDATE 0 2', etc.
|
|
456
|
-
match = ROWCOUNT_REGEX.match(result)
|
|
449
|
+
match = ASYNC_PG_STATUS_REGEX.match(status.strip())
|
|
457
450
|
if match:
|
|
458
|
-
|
|
459
|
-
|
|
451
|
+
# For INSERT: "INSERT 0 5" -> groups: (INSERT, 0, 5)
|
|
452
|
+
# For UPDATE/DELETE: "UPDATE 3" -> groups: (UPDATE, None, 3)
|
|
453
|
+
groups = match.groups()
|
|
454
|
+
if len(groups) >= EXPECTED_REGEX_GROUPS:
|
|
455
|
+
try:
|
|
456
|
+
# The last group is always the row count
|
|
457
|
+
return int(groups[-1])
|
|
458
|
+
except (ValueError, IndexError):
|
|
459
|
+
pass
|
|
460
460
|
|
|
461
|
-
|
|
462
|
-
async def insert_update_delete_returning(
|
|
463
|
-
self,
|
|
464
|
-
sql: str,
|
|
465
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
466
|
-
/,
|
|
467
|
-
*filters: "StatementFilter",
|
|
468
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
469
|
-
schema_type: None = None,
|
|
470
|
-
**kwargs: Any,
|
|
471
|
-
) -> "dict[str, Any]": ...
|
|
472
|
-
@overload
|
|
473
|
-
async def insert_update_delete_returning(
|
|
474
|
-
self,
|
|
475
|
-
sql: str,
|
|
476
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
477
|
-
/,
|
|
478
|
-
*filters: "StatementFilter",
|
|
479
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
480
|
-
schema_type: "type[ModelDTOT]",
|
|
481
|
-
**kwargs: Any,
|
|
482
|
-
) -> "ModelDTOT": ...
|
|
483
|
-
async def insert_update_delete_returning(
|
|
484
|
-
self,
|
|
485
|
-
sql: str,
|
|
486
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
487
|
-
/,
|
|
488
|
-
*filters: "StatementFilter",
|
|
489
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
490
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
491
|
-
**kwargs: Any,
|
|
492
|
-
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
|
|
493
|
-
"""Insert, update, or delete data from the database and return the affected row.
|
|
494
|
-
|
|
495
|
-
Args:
|
|
496
|
-
*filters: Statement filters to apply.
|
|
497
|
-
sql: SQL statement.
|
|
498
|
-
parameters: Query parameters.
|
|
499
|
-
connection: Optional connection to use.
|
|
500
|
-
schema_type: Optional schema class for the result.
|
|
501
|
-
**kwargs: Additional keyword arguments.
|
|
502
|
-
|
|
503
|
-
Returns:
|
|
504
|
-
The affected row data as either a model instance or dictionary.
|
|
505
|
-
"""
|
|
506
|
-
connection = self._connection(connection)
|
|
507
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
508
|
-
parameters = parameters if parameters is not None else ()
|
|
509
|
-
result = await connection.fetchrow(sql, *parameters) # pyright: ignore
|
|
510
|
-
if result is None:
|
|
511
|
-
return None
|
|
512
|
-
|
|
513
|
-
return self.to_schema(dict(result.items()), schema_type=schema_type)
|
|
514
|
-
|
|
515
|
-
async def execute_script(
|
|
516
|
-
self,
|
|
517
|
-
sql: str,
|
|
518
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
519
|
-
/,
|
|
520
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
521
|
-
**kwargs: Any,
|
|
522
|
-
) -> str:
|
|
523
|
-
"""Execute a script.
|
|
524
|
-
|
|
525
|
-
Args:
|
|
526
|
-
sql: SQL statement.
|
|
527
|
-
parameters: Query parameters.
|
|
528
|
-
connection: Optional connection to use.
|
|
529
|
-
**kwargs: Additional keyword arguments.
|
|
530
|
-
|
|
531
|
-
Returns:
|
|
532
|
-
Status message for the operation.
|
|
533
|
-
"""
|
|
534
|
-
connection = self._connection(connection)
|
|
535
|
-
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
|
|
536
|
-
parameters = parameters if parameters is not None else ()
|
|
537
|
-
return await connection.execute(sql, *parameters) # pyright: ignore
|
|
538
|
-
|
|
539
|
-
def _connection(self, connection: "Optional[AsyncpgConnection]" = None) -> "AsyncpgConnection":
|
|
540
|
-
"""Return the connection to use. If None, use the default connection."""
|
|
541
|
-
return connection if connection is not None else self.connection
|
|
461
|
+
return 0
|