sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -621
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -431
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +218 -436
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +417 -487
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +600 -553
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +392 -406
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +548 -921
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -533
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +725 -0
- sqlspec/adapters/psycopg/driver.py +734 -694
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +242 -405
- sqlspec/base.py +220 -784
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
- sqlspec-0.12.1.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -331
- sqlspec/mixins.py +0 -305
- sqlspec/statement.py +0 -378
- sqlspec-0.11.1.dist-info/RECORD +0 -69
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,531 +1,461 @@
|
|
|
1
|
-
import logging
|
|
2
1
|
import re
|
|
3
|
-
from
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Optional, Union,
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
5
4
|
|
|
6
|
-
from asyncpg import Connection
|
|
7
|
-
from
|
|
5
|
+
from asyncpg import Connection as AsyncpgNativeConnection
|
|
6
|
+
from asyncpg import Record
|
|
8
7
|
from typing_extensions import TypeAlias
|
|
9
8
|
|
|
10
|
-
from sqlspec.
|
|
11
|
-
from sqlspec.
|
|
12
|
-
|
|
13
|
-
|
|
9
|
+
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
10
|
+
from sqlspec.driver.mixins import (
|
|
11
|
+
AsyncPipelinedExecutionMixin,
|
|
12
|
+
AsyncStorageMixin,
|
|
13
|
+
SQLTranslatorMixin,
|
|
14
|
+
ToSchemaMixin,
|
|
15
|
+
TypeCoercionMixin,
|
|
16
|
+
)
|
|
17
|
+
from sqlspec.statement.parameters import ParameterStyle
|
|
18
|
+
from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
|
|
19
|
+
from sqlspec.statement.sql import SQL, SQLConfig
|
|
20
|
+
from sqlspec.typing import DictRow, ModelDTOT, RowT
|
|
21
|
+
from sqlspec.utils.logging import get_logger
|
|
14
22
|
|
|
15
23
|
if TYPE_CHECKING:
|
|
16
|
-
from collections.abc import Mapping, Sequence
|
|
17
|
-
|
|
18
|
-
from asyncpg import Record
|
|
19
|
-
from asyncpg.connection import Connection
|
|
20
24
|
from asyncpg.pool import PoolConnectionProxy
|
|
21
|
-
|
|
22
|
-
from sqlspec.typing import ModelDTOT, StatementParameterType, T
|
|
25
|
+
from sqlglot.dialects.dialect import DialectType
|
|
23
26
|
|
|
24
27
|
__all__ = ("AsyncpgConnection", "AsyncpgDriver")
|
|
25
28
|
|
|
26
|
-
logger =
|
|
29
|
+
logger = get_logger("adapters.asyncpg")
|
|
27
30
|
|
|
28
31
|
if TYPE_CHECKING:
|
|
29
|
-
AsyncpgConnection: TypeAlias = Union[
|
|
32
|
+
AsyncpgConnection: TypeAlias = Union[AsyncpgNativeConnection[Record], PoolConnectionProxy[Record]]
|
|
30
33
|
else:
|
|
31
|
-
AsyncpgConnection: TypeAlias =
|
|
32
|
-
|
|
33
|
-
#
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
#
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
#
|
|
40
|
-
|
|
41
|
-
# 4. Multi-line comments (/* to */)
|
|
42
|
-
# 5. Only question marks outside of these contexts are considered parameters
|
|
43
|
-
QUESTION_MARK_PATTERN = re.compile(
|
|
44
|
-
r"""
|
|
45
|
-
(?:'[^']*(?:''[^']*)*') | # Skip single-quoted strings (with '' escapes)
|
|
46
|
-
(?:"[^"]*(?:""[^"]*)*") | # Skip double-quoted strings (with "" escapes)
|
|
47
|
-
(?:--.*?(?:\n|$)) | # Skip single-line comments
|
|
48
|
-
(?:/\*(?:[^*]|\*(?!/))*\*/) | # Skip multi-line comments
|
|
49
|
-
(\?) # Capture only question marks outside of these contexts
|
|
50
|
-
""",
|
|
51
|
-
re.VERBOSE | re.DOTALL,
|
|
52
|
-
)
|
|
34
|
+
AsyncpgConnection: TypeAlias = Union[AsyncpgNativeConnection, Any]
|
|
35
|
+
|
|
36
|
+
# Compiled regex to parse asyncpg status messages like "INSERT 0 1" or "UPDATE 1"
|
|
37
|
+
# Group 1: Command Tag (e.g., INSERT, UPDATE)
|
|
38
|
+
# Group 2: (Optional) OID count for INSERT (we ignore this)
|
|
39
|
+
# Group 3: Rows affected
|
|
40
|
+
ASYNC_PG_STATUS_REGEX = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE)
|
|
41
|
+
|
|
42
|
+
# Expected number of groups in the regex match for row count extraction
|
|
43
|
+
EXPECTED_REGEX_GROUPS = 3
|
|
53
44
|
|
|
54
45
|
|
|
55
46
|
class AsyncpgDriver(
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
47
|
+
AsyncDriverAdapterProtocol[AsyncpgConnection, RowT],
|
|
48
|
+
SQLTranslatorMixin,
|
|
49
|
+
TypeCoercionMixin,
|
|
50
|
+
AsyncStorageMixin,
|
|
51
|
+
AsyncPipelinedExecutionMixin,
|
|
52
|
+
ToSchemaMixin,
|
|
59
53
|
):
|
|
60
|
-
"""AsyncPG
|
|
61
|
-
|
|
62
|
-
connection: "AsyncpgConnection"
|
|
63
|
-
dialect: str = "postgres"
|
|
54
|
+
"""AsyncPG PostgreSQL Driver Adapter. Modern protocol implementation."""
|
|
64
55
|
|
|
65
|
-
|
|
66
|
-
|
|
56
|
+
dialect: "DialectType" = "postgres"
|
|
57
|
+
supported_parameter_styles: "tuple[ParameterStyle, ...]" = (ParameterStyle.NUMERIC,)
|
|
58
|
+
default_parameter_style: ParameterStyle = ParameterStyle.NUMERIC
|
|
59
|
+
__slots__ = ()
|
|
67
60
|
|
|
68
|
-
def
|
|
61
|
+
def __init__(
|
|
69
62
|
self,
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
"""
|
|
89
|
-
|
|
90
|
-
|
|
63
|
+
connection: "AsyncpgConnection",
|
|
64
|
+
config: "Optional[SQLConfig]" = None,
|
|
65
|
+
default_row_type: "type[DictRow]" = dict[str, Any],
|
|
66
|
+
) -> None:
|
|
67
|
+
super().__init__(connection=connection, config=config, default_row_type=default_row_type)
|
|
68
|
+
|
|
69
|
+
# AsyncPG-specific type coercion overrides (PostgreSQL has rich native types)
|
|
70
|
+
def _coerce_boolean(self, value: Any) -> Any:
|
|
71
|
+
"""AsyncPG/PostgreSQL has native boolean support."""
|
|
72
|
+
# Keep booleans as-is, AsyncPG handles them natively
|
|
73
|
+
return value
|
|
74
|
+
|
|
75
|
+
def _coerce_decimal(self, value: Any) -> Any:
|
|
76
|
+
"""AsyncPG/PostgreSQL has native decimal/numeric support."""
|
|
77
|
+
# Keep decimals as-is, AsyncPG handles them natively
|
|
78
|
+
return value
|
|
79
|
+
|
|
80
|
+
def _coerce_json(self, value: Any) -> Any:
|
|
81
|
+
"""AsyncPG/PostgreSQL has native JSON/JSONB support."""
|
|
82
|
+
# AsyncPG can handle dict/list directly for JSON columns
|
|
83
|
+
return value
|
|
84
|
+
|
|
85
|
+
def _coerce_array(self, value: Any) -> Any:
|
|
86
|
+
"""AsyncPG/PostgreSQL has native array support."""
|
|
87
|
+
# Convert tuples to lists for consistency
|
|
88
|
+
if isinstance(value, tuple):
|
|
89
|
+
return list(value)
|
|
90
|
+
# Keep other arrays as-is, AsyncPG handles them natively
|
|
91
|
+
return value
|
|
92
|
+
|
|
93
|
+
async def _execute_statement(
|
|
94
|
+
self, statement: SQL, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
95
|
+
) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
|
|
96
|
+
if statement.is_script:
|
|
97
|
+
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
98
|
+
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
99
|
+
|
|
100
|
+
detected_styles = {p.style for p in statement.parameter_info}
|
|
101
|
+
target_style = self.default_parameter_style
|
|
102
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
103
|
+
if unsupported_styles:
|
|
104
|
+
target_style = self.default_parameter_style
|
|
105
|
+
elif detected_styles:
|
|
106
|
+
for style in detected_styles:
|
|
107
|
+
if style in self.supported_parameter_styles:
|
|
108
|
+
target_style = style
|
|
109
|
+
break
|
|
110
|
+
|
|
111
|
+
if statement.is_many:
|
|
112
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
113
|
+
return await self._execute_many(sql, params, connection=connection, **kwargs)
|
|
114
|
+
|
|
115
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
116
|
+
return await self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
117
|
+
|
|
118
|
+
async def _execute(
|
|
119
|
+
self, sql: str, parameters: Any, statement: SQL, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
120
|
+
) -> Union[SelectResultDict, DMLResultDict]:
|
|
121
|
+
conn = self._connection(connection)
|
|
122
|
+
# Process parameters to handle TypedParameter objects
|
|
123
|
+
parameters = self._process_parameters(parameters)
|
|
124
|
+
|
|
125
|
+
# Check if this is actually a many operation that was misrouted
|
|
126
|
+
if statement.is_many:
|
|
127
|
+
# This should have gone to _execute_many, redirect it
|
|
128
|
+
return await self._execute_many(sql, parameters, connection=connection, **kwargs)
|
|
129
|
+
|
|
130
|
+
# AsyncPG expects parameters as *args, not a single list
|
|
131
|
+
args_for_driver: list[Any] = []
|
|
91
132
|
|
|
92
133
|
if parameters is not None:
|
|
93
|
-
if isinstance(parameters,
|
|
94
|
-
|
|
95
|
-
# data_params_for_statement remains None
|
|
134
|
+
if isinstance(parameters, (list, tuple)):
|
|
135
|
+
args_for_driver.extend(parameters)
|
|
96
136
|
else:
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
# Apply any filters from the combined list
|
|
108
|
-
for filter_obj in combined_filters_list:
|
|
109
|
-
statement = statement.apply_filter(filter_obj)
|
|
110
|
-
|
|
111
|
-
# Process the statement
|
|
112
|
-
processed_sql, processed_params, parsed_expr = statement.process()
|
|
113
|
-
|
|
114
|
-
if processed_params is None:
|
|
115
|
-
return processed_sql, ()
|
|
116
|
-
|
|
117
|
-
# Convert question marks to PostgreSQL style $N parameters
|
|
118
|
-
if isinstance(processed_params, (list, tuple)) and "?" in processed_sql:
|
|
119
|
-
# Use a counter to generate $1, $2, etc. for each ? in the SQL that's outside strings/comments
|
|
120
|
-
param_index = 0
|
|
121
|
-
|
|
122
|
-
def replace_question_mark(match: Match[str]) -> str:
|
|
123
|
-
# Only process the match if it's not in a skipped context (string/comment)
|
|
124
|
-
if match.group(1): # This is a question mark outside string/comment
|
|
125
|
-
nonlocal param_index
|
|
126
|
-
param_index += 1
|
|
127
|
-
return f"${param_index}"
|
|
128
|
-
# Return the entire matched text unchanged for strings/comments
|
|
129
|
-
return match.group(0)
|
|
130
|
-
|
|
131
|
-
processed_sql = QUESTION_MARK_PATTERN.sub(replace_question_mark, processed_sql)
|
|
132
|
-
|
|
133
|
-
# Now handle the asyncpg-specific parameter conversion - asyncpg requires positional parameters
|
|
134
|
-
if isinstance(processed_params, dict):
|
|
135
|
-
if parsed_expr is not None:
|
|
136
|
-
# Find named parameters
|
|
137
|
-
named_params = []
|
|
138
|
-
for node in parsed_expr.find_all(exp.Parameter, exp.Placeholder):
|
|
139
|
-
if isinstance(node, exp.Parameter) and node.name and node.name in processed_params:
|
|
140
|
-
named_params.append(node.name)
|
|
141
|
-
elif (
|
|
142
|
-
isinstance(node, exp.Placeholder)
|
|
143
|
-
and isinstance(node.this, str)
|
|
144
|
-
and node.this in processed_params
|
|
145
|
-
):
|
|
146
|
-
named_params.append(node.this)
|
|
147
|
-
|
|
148
|
-
# Convert named parameters to positional
|
|
149
|
-
if named_params:
|
|
150
|
-
# Transform the SQL to use $1, $2, etc.
|
|
151
|
-
def replace_named_with_positional(node: exp.Expression) -> exp.Expression:
|
|
152
|
-
if isinstance(node, exp.Parameter) and node.name and node.name in processed_params:
|
|
153
|
-
idx = named_params.index(node.name) + 1
|
|
154
|
-
return exp.Parameter(this=str(idx))
|
|
155
|
-
if (
|
|
156
|
-
isinstance(node, exp.Placeholder)
|
|
157
|
-
and isinstance(node.this, str)
|
|
158
|
-
and node.this in processed_params
|
|
159
|
-
):
|
|
160
|
-
idx = named_params.index(node.this) + 1
|
|
161
|
-
return exp.Parameter(this=str(idx))
|
|
162
|
-
return node
|
|
163
|
-
|
|
164
|
-
return parsed_expr.transform(replace_named_with_positional, copy=True).sql(
|
|
165
|
-
dialect=self.dialect
|
|
166
|
-
), tuple(processed_params[name] for name in named_params)
|
|
167
|
-
return processed_sql, tuple(processed_params.values())
|
|
168
|
-
if isinstance(processed_params, (list, tuple)):
|
|
169
|
-
return processed_sql, tuple(processed_params)
|
|
170
|
-
return processed_sql, (processed_params,) # type: ignore[unreachable]
|
|
171
|
-
|
|
172
|
-
@overload
|
|
173
|
-
async def select(
|
|
174
|
-
self,
|
|
175
|
-
sql: str,
|
|
176
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
177
|
-
*filters: "StatementFilter",
|
|
178
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
179
|
-
schema_type: None = None,
|
|
180
|
-
**kwargs: Any,
|
|
181
|
-
) -> "Sequence[dict[str, Any]]": ...
|
|
182
|
-
@overload
|
|
183
|
-
async def select(
|
|
184
|
-
self,
|
|
185
|
-
sql: str,
|
|
186
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
187
|
-
*filters: "StatementFilter",
|
|
188
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
189
|
-
schema_type: "type[ModelDTOT]",
|
|
190
|
-
**kwargs: Any,
|
|
191
|
-
) -> "Sequence[ModelDTOT]": ...
|
|
192
|
-
async def select(
|
|
193
|
-
self,
|
|
194
|
-
sql: str,
|
|
195
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
196
|
-
*filters: "StatementFilter",
|
|
197
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
198
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
199
|
-
**kwargs: Any,
|
|
200
|
-
) -> "Sequence[Union[dict[str, Any], ModelDTOT]]":
|
|
201
|
-
"""Fetch data from the database.
|
|
202
|
-
|
|
203
|
-
Args:
|
|
204
|
-
sql: SQL statement.
|
|
205
|
-
parameters: Query parameters. Can be data or a StatementFilter.
|
|
206
|
-
*filters: Statement filters to apply.
|
|
207
|
-
connection: Optional connection to use.
|
|
208
|
-
schema_type: Optional schema class for the result.
|
|
209
|
-
**kwargs: Additional keyword arguments.
|
|
210
|
-
|
|
211
|
-
Returns:
|
|
212
|
-
List of row data as either model instances or dictionaries.
|
|
213
|
-
"""
|
|
214
|
-
connection = self._connection(connection)
|
|
215
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
216
|
-
parameters = parameters if parameters is not None else ()
|
|
217
|
-
|
|
218
|
-
results = await connection.fetch(sql, *parameters) # pyright: ignore
|
|
219
|
-
if not results:
|
|
220
|
-
return []
|
|
221
|
-
return self.to_schema([dict(row.items()) for row in results], schema_type=schema_type) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
|
|
222
|
-
|
|
223
|
-
@overload
|
|
224
|
-
async def select_one(
|
|
225
|
-
self,
|
|
226
|
-
sql: str,
|
|
227
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
228
|
-
*filters: "StatementFilter",
|
|
229
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
230
|
-
schema_type: None = None,
|
|
231
|
-
**kwargs: Any,
|
|
232
|
-
) -> "dict[str, Any]": ...
|
|
233
|
-
@overload
|
|
234
|
-
async def select_one(
|
|
235
|
-
self,
|
|
236
|
-
sql: str,
|
|
237
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
238
|
-
*filters: "StatementFilter",
|
|
239
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
240
|
-
schema_type: "type[ModelDTOT]",
|
|
241
|
-
**kwargs: Any,
|
|
242
|
-
) -> "ModelDTOT": ...
|
|
243
|
-
async def select_one(
|
|
244
|
-
self,
|
|
245
|
-
sql: str,
|
|
246
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
247
|
-
*filters: "StatementFilter",
|
|
248
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
249
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
250
|
-
**kwargs: Any,
|
|
251
|
-
) -> "Union[dict[str, Any], ModelDTOT]":
|
|
252
|
-
"""Fetch one row from the database.
|
|
253
|
-
|
|
254
|
-
Args:
|
|
255
|
-
sql: SQL statement.
|
|
256
|
-
parameters: Query parameters. Can be data or a StatementFilter.
|
|
257
|
-
*filters: Statement filters to apply.
|
|
258
|
-
connection: Optional connection to use.
|
|
259
|
-
schema_type: Optional schema class for the result.
|
|
260
|
-
**kwargs: Additional keyword arguments.
|
|
261
|
-
|
|
262
|
-
Returns:
|
|
263
|
-
The first row of the query results.
|
|
264
|
-
"""
|
|
265
|
-
connection = self._connection(connection)
|
|
266
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
267
|
-
parameters = parameters if parameters is not None else ()
|
|
268
|
-
result = await connection.fetchrow(sql, *parameters) # pyright: ignore
|
|
269
|
-
result = self.check_not_found(result)
|
|
270
|
-
return self.to_schema(dict(result.items()), schema_type=schema_type)
|
|
271
|
-
|
|
272
|
-
@overload
|
|
273
|
-
async def select_one_or_none(
|
|
274
|
-
self,
|
|
275
|
-
sql: str,
|
|
276
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
277
|
-
*filters: "StatementFilter",
|
|
278
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
279
|
-
schema_type: None = None,
|
|
280
|
-
**kwargs: Any,
|
|
281
|
-
) -> "Optional[dict[str, Any]]": ...
|
|
282
|
-
@overload
|
|
283
|
-
async def select_one_or_none(
|
|
284
|
-
self,
|
|
285
|
-
sql: str,
|
|
286
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
287
|
-
*filters: "StatementFilter",
|
|
288
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
289
|
-
schema_type: "type[ModelDTOT]",
|
|
290
|
-
**kwargs: Any,
|
|
291
|
-
) -> "Optional[ModelDTOT]": ...
|
|
292
|
-
async def select_one_or_none(
|
|
293
|
-
self,
|
|
294
|
-
sql: str,
|
|
295
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
296
|
-
*filters: "StatementFilter",
|
|
297
|
-
connection: "Optional[AsyncpgConnection]" = None,
|
|
298
|
-
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
299
|
-
**kwargs: Any,
|
|
300
|
-
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
|
|
301
|
-
"""Fetch one row from the database.
|
|
137
|
+
args_for_driver.append(parameters)
|
|
138
|
+
|
|
139
|
+
if self.returns_rows(statement.expression):
|
|
140
|
+
records = await conn.fetch(sql, *args_for_driver)
|
|
141
|
+
# Convert asyncpg Records to dicts
|
|
142
|
+
data = [dict(record) for record in records]
|
|
143
|
+
# Get column names from first record or empty list
|
|
144
|
+
column_names = list(records[0].keys()) if records else []
|
|
145
|
+
result: SelectResultDict = {"data": data, "column_names": column_names, "rows_affected": len(records)}
|
|
146
|
+
return result
|
|
302
147
|
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
self
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
148
|
+
status = await conn.execute(sql, *args_for_driver)
|
|
149
|
+
# Parse row count from status string
|
|
150
|
+
rows_affected = 0
|
|
151
|
+
if status and isinstance(status, str):
|
|
152
|
+
match = ASYNC_PG_STATUS_REGEX.match(status)
|
|
153
|
+
if match and len(match.groups()) >= EXPECTED_REGEX_GROUPS:
|
|
154
|
+
rows_affected = int(match.group(3))
|
|
155
|
+
|
|
156
|
+
dml_result: DMLResultDict = {"rows_affected": rows_affected, "status_message": status or "OK"}
|
|
157
|
+
return dml_result
|
|
158
|
+
|
|
159
|
+
async def _execute_many(
|
|
160
|
+
self, sql: str, param_list: Any, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
161
|
+
) -> DMLResultDict:
|
|
162
|
+
conn = self._connection(connection)
|
|
163
|
+
# Process parameters to handle TypedParameter objects
|
|
164
|
+
param_list = self._process_parameters(param_list)
|
|
165
|
+
|
|
166
|
+
params_list: list[tuple[Any, ...]] = []
|
|
167
|
+
rows_affected = 0
|
|
168
|
+
if param_list and isinstance(param_list, Sequence):
|
|
169
|
+
for param_set in param_list:
|
|
170
|
+
if isinstance(param_set, (list, tuple)):
|
|
171
|
+
params_list.append(tuple(param_set))
|
|
172
|
+
elif param_set is None:
|
|
173
|
+
params_list.append(())
|
|
174
|
+
else:
|
|
175
|
+
params_list.append((param_set,))
|
|
176
|
+
|
|
177
|
+
await conn.executemany(sql, params_list)
|
|
178
|
+
# AsyncPG's executemany returns None, not a status string
|
|
179
|
+
# We need to use the number of parameter sets as the row count
|
|
180
|
+
rows_affected = len(params_list)
|
|
181
|
+
|
|
182
|
+
dml_result: DMLResultDict = {"rows_affected": rows_affected, "status_message": "OK"}
|
|
183
|
+
return dml_result
|
|
184
|
+
|
|
185
|
+
async def _execute_script(
|
|
186
|
+
self, script: str, connection: Optional[AsyncpgConnection] = None, **kwargs: Any
|
|
187
|
+
) -> ScriptResultDict:
|
|
188
|
+
conn = self._connection(connection)
|
|
189
|
+
status = await conn.execute(script)
|
|
190
|
+
|
|
191
|
+
result: ScriptResultDict = {
|
|
192
|
+
"statements_executed": -1, # AsyncPG doesn't provide statement count
|
|
193
|
+
"status_message": status or "SCRIPT EXECUTED",
|
|
194
|
+
}
|
|
195
|
+
return result
|
|
196
|
+
|
|
197
|
+
async def _wrap_select_result(
|
|
198
|
+
self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
|
|
199
|
+
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
200
|
+
records = cast("list[Record]", result["data"])
|
|
201
|
+
column_names = result["column_names"]
|
|
202
|
+
rows_affected = result["rows_affected"]
|
|
203
|
+
|
|
204
|
+
rows_as_dicts: list[dict[str, Any]] = [dict(record) for record in records]
|
|
205
|
+
|
|
206
|
+
if schema_type:
|
|
207
|
+
converted_data_seq = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
|
|
208
|
+
converted_data_list = list(converted_data_seq) if converted_data_seq is not None else []
|
|
209
|
+
return SQLResult[ModelDTOT](
|
|
210
|
+
statement=statement,
|
|
211
|
+
data=converted_data_list,
|
|
212
|
+
column_names=column_names,
|
|
213
|
+
rows_affected=rows_affected,
|
|
214
|
+
operation_type="SELECT",
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
return SQLResult[RowT](
|
|
218
|
+
statement=statement,
|
|
219
|
+
data=cast("list[RowT]", rows_as_dicts),
|
|
220
|
+
column_names=column_names,
|
|
221
|
+
rows_affected=rows_affected,
|
|
222
|
+
operation_type="SELECT",
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
async def _wrap_execute_result(
|
|
226
|
+
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
227
|
+
) -> SQLResult[RowT]:
|
|
228
|
+
operation_type = "UNKNOWN"
|
|
229
|
+
if statement.expression:
|
|
230
|
+
operation_type = str(statement.expression.key).upper()
|
|
231
|
+
|
|
232
|
+
# Handle script results
|
|
233
|
+
if "statements_executed" in result:
|
|
234
|
+
return SQLResult[RowT](
|
|
235
|
+
statement=statement,
|
|
236
|
+
data=cast("list[RowT]", []),
|
|
237
|
+
rows_affected=0,
|
|
238
|
+
operation_type="SCRIPT",
|
|
239
|
+
metadata={
|
|
240
|
+
"status_message": result.get("status_message", ""),
|
|
241
|
+
"statements_executed": result.get("statements_executed", -1),
|
|
242
|
+
},
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
# Handle DML results
|
|
246
|
+
rows_affected = cast("int", result.get("rows_affected", -1))
|
|
247
|
+
status_message = result.get("status_message", "")
|
|
248
|
+
|
|
249
|
+
return SQLResult[RowT](
|
|
250
|
+
statement=statement,
|
|
251
|
+
data=cast("list[RowT]", []),
|
|
252
|
+
rows_affected=rows_affected,
|
|
253
|
+
operation_type=operation_type,
|
|
254
|
+
metadata={"status_message": status_message},
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
def _connection(self, connection: Optional[AsyncpgConnection] = None) -> AsyncpgConnection:
|
|
258
|
+
"""Get the connection to use for the operation."""
|
|
259
|
+
return connection or self.connection
|
|
260
|
+
|
|
261
|
+
async def _execute_pipeline_native(self, operations: "list[Any]", **options: Any) -> "list[SQLResult[RowT]]":
|
|
262
|
+
"""Native pipeline execution using AsyncPG's efficient batch handling.
|
|
263
|
+
|
|
264
|
+
Note: AsyncPG doesn't have explicit pipeline support like Psycopg, but we can
|
|
265
|
+
achieve similar performance benefits through careful batching and transaction
|
|
266
|
+
management.
|
|
352
267
|
|
|
353
268
|
Args:
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
*filters: Statement filters to apply.
|
|
357
|
-
connection: Optional connection to use.
|
|
358
|
-
schema_type: Optional schema class for the result.
|
|
359
|
-
**kwargs: Additional keyword arguments.
|
|
269
|
+
operations: List of PipelineOperation objects
|
|
270
|
+
**options: Pipeline configuration options
|
|
360
271
|
|
|
361
272
|
Returns:
|
|
362
|
-
|
|
273
|
+
List of SQLResult objects from all operations
|
|
363
274
|
"""
|
|
364
|
-
connection = self._connection(connection)
|
|
365
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
366
|
-
parameters = parameters if parameters is not None else ()
|
|
367
|
-
result = await connection.fetchval(sql, *parameters) # pyright: ignore
|
|
368
|
-
result = self.check_not_found(result)
|
|
369
|
-
if schema_type is None:
|
|
370
|
-
return result
|
|
371
|
-
return schema_type(result) # type: ignore[call-arg]
|
|
372
275
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
276
|
+
results: list[Any] = []
|
|
277
|
+
connection = self._connection()
|
|
278
|
+
|
|
279
|
+
# Use a single transaction for all operations
|
|
280
|
+
async with connection.transaction():
|
|
281
|
+
for i, op in enumerate(operations):
|
|
282
|
+
await self._execute_pipeline_operation(connection, i, op, options, results)
|
|
283
|
+
|
|
284
|
+
return results
|
|
285
|
+
|
|
286
|
+
    async def _execute_pipeline_operation(
        self, connection: Any, i: int, op: Any, options: dict[str, Any], results: list[Any]
    ) -> None:
        """Execute a single pipeline operation with error handling.

        Args:
            connection: Active AsyncPG connection (inside the pipeline transaction).
            i: Zero-based index of this operation within the pipeline.
            op: Pipeline operation; assumed to expose ``sql``, ``filters``,
                ``operation_type`` and ``original_params`` — TODO confirm against
                the PipelineOperation definition.
            options: Pipeline options; ``continue_on_error`` selects whether a
                failure is recorded as an error result or raised.
            results: Mutable list this method appends the operation's result to.
        """
        # Imported lazily to avoid a module-level import cycle with sqlspec.exceptions.
        from sqlspec.exceptions import PipelineExecutionError

        try:
            # Convert parameters to positional for AsyncPG (requires $1, $2, etc.)
            sql_str = op.sql.to_sql(placeholder_style=ParameterStyle.NUMERIC)
            params = self._convert_to_positional_params(op.sql.parameters)

            # Apply operation-specific filters
            filtered_sql = self._apply_operation_filters(op.sql, op.filters)
            if filtered_sql != op.sql:
                # Filters changed the statement: re-render SQL and rebuild params.
                sql_str = filtered_sql.to_sql(placeholder_style=ParameterStyle.NUMERIC)
                params = self._convert_to_positional_params(filtered_sql.parameters)

            # Execute based on operation type
            if op.operation_type == "execute_many":
                # AsyncPG has native executemany support
                status = await connection.executemany(sql_str, params)
                # Parse row count from status (e.g., "INSERT 0 5")
                rows_affected = self._parse_asyncpg_status(status)
                result = SQLResult[RowT](
                    statement=op.sql,
                    data=cast("list[RowT]", []),
                    rows_affected=rows_affected,
                    operation_type="execute_many",
                    metadata={"status_message": status},
                )
            elif op.operation_type == "select":
                # Use fetch for SELECT statements
                rows = await connection.fetch(sql_str, *params)
                # Convert AsyncPG Records to dictionaries
                data = [dict(record) for record in rows] if rows else []
                result = SQLResult[RowT](
                    statement=op.sql,
                    data=cast("list[RowT]", data),
                    rows_affected=len(data),
                    operation_type="select",
                    metadata={"column_names": list(rows[0].keys()) if rows else []},
                )
            elif op.operation_type == "execute_script":
                # For scripts, split and execute each statement
                script_statements = self._split_script_statements(op.sql.to_sql())
                total_affected = 0
                last_status = ""

                for stmt in script_statements:
                    if stmt.strip():
                        status = await connection.execute(stmt)
                        # Accumulate affected rows across all script statements.
                        total_affected += self._parse_asyncpg_status(status)
                        last_status = status

                result = SQLResult[RowT](
                    statement=op.sql,
                    data=cast("list[RowT]", []),
                    rows_affected=total_affected,
                    operation_type="execute_script",
                    metadata={"status_message": last_status, "statements_executed": len(script_statements)},
                )
            else:
                # Default path: single non-SELECT statement via execute().
                status = await connection.execute(sql_str, *params)
                rows_affected = self._parse_asyncpg_status(status)
                result = SQLResult[RowT](
                    statement=op.sql,
                    data=cast("list[RowT]", []),
                    rows_affected=rows_affected,
                    operation_type="execute",
                    metadata={"status_message": status},
                )

            # Add operation context
            result.operation_index = i
            result.pipeline_sql = op.sql
            results.append(result)

        except Exception as e:
            if options.get("continue_on_error"):
                # Create error result
                error_result = SQLResult[RowT](
                    statement=op.sql, error=e, operation_index=i, parameters=op.original_params, data=[]
                )
                results.append(error_result)
            else:
                # Transaction will be rolled back automatically
                msg = f"AsyncPG pipeline failed at operation {i}: {e}"
                raise PipelineExecutionError(
                    msg, operation_index=i, partial_results=results, failed_operation=op
                ) from e
|
|
411
376
|
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
"""
|
|
415
|
-
connection = self._connection(connection)
|
|
416
|
-
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
|
|
417
|
-
parameters = parameters if parameters is not None else ()
|
|
418
|
-
result = await connection.fetchval(sql, *parameters) # pyright: ignore
|
|
419
|
-
if result is None:
|
|
420
|
-
return None
|
|
421
|
-
if schema_type is None:
|
|
422
|
-
return result
|
|
423
|
-
return schema_type(result) # type: ignore[call-arg]
|
|
377
|
+
def _convert_to_positional_params(self, params: Any) -> "tuple[Any, ...]":
|
|
378
|
+
"""Convert parameters to positional format for AsyncPG.
|
|
424
379
|
|
|
425
|
-
|
|
426
|
-
self,
|
|
427
|
-
sql: str,
|
|
428
|
-
parameters: "Optional[StatementParameterType]" = None,
|
|
429
|
-
*filters: "StatementFilter",
|
|
430
|
-
connection: Optional["AsyncpgConnection"] = None,
|
|
431
|
-
**kwargs: Any,
|
|
432
|
-
) -> int:
|
|
433
|
-
"""Insert, update, or delete data from the database.
|
|
380
|
+
AsyncPG requires parameters as positional arguments for $1, $2, etc.
|
|
434
381
|
|
|
435
382
|
Args:
|
|
436
|
-
|
|
437
|
-
parameters: Query parameters. Can be data or a StatementFilter.
|
|
438
|
-
*filters: Statement filters to apply.
|
|
439
|
-
connection: Optional connection to use.
|
|
440
|
-
**kwargs: Additional keyword arguments.
|
|
383
|
+
params: Parameters in various formats
|
|
441
384
|
|
|
442
385
|
Returns:
|
|
443
|
-
|
|
386
|
+
Tuple of positional parameters
|
|
444
387
|
"""
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
388
|
+
if params is None:
|
|
389
|
+
return ()
|
|
390
|
+
if isinstance(params, dict):
|
|
391
|
+
if not params:
|
|
392
|
+
return ()
|
|
393
|
+
# Convert dict to positional based on $1, $2, etc. order
|
|
394
|
+
# This assumes the SQL was compiled with NUMERIC style
|
|
395
|
+
max_param = 0
|
|
396
|
+
for key in params:
|
|
397
|
+
if isinstance(key, str) and key.startswith("param_"):
|
|
398
|
+
try:
|
|
399
|
+
param_num = int(key[6:]) # Extract number from "param_N"
|
|
400
|
+
max_param = max(max_param, param_num)
|
|
401
|
+
except ValueError:
|
|
402
|
+
continue
|
|
403
|
+
|
|
404
|
+
if max_param > 0:
|
|
405
|
+
# Rebuild positional args from param_0, param_1, etc.
|
|
406
|
+
positional = []
|
|
407
|
+
for i in range(max_param + 1):
|
|
408
|
+
param_key = f"param_{i}"
|
|
409
|
+
if param_key in params:
|
|
410
|
+
positional.append(params[param_key])
|
|
411
|
+
return tuple(positional)
|
|
412
|
+
# Fall back to dict values in arbitrary order
|
|
413
|
+
return tuple(params.values())
|
|
414
|
+
if isinstance(params, (list, tuple)):
|
|
415
|
+
return tuple(params)
|
|
416
|
+
return (params,)
|
|
417
|
+
|
|
418
|
+
def _apply_operation_filters(self, sql: "SQL", filters: "list[Any]") -> "SQL":
|
|
419
|
+
"""Apply filters to a SQL object for pipeline operations."""
|
|
420
|
+
if not filters:
|
|
421
|
+
return sql
|
|
422
|
+
|
|
423
|
+
result_sql = sql
|
|
424
|
+
for filter_obj in filters:
|
|
425
|
+
if hasattr(filter_obj, "apply"):
|
|
426
|
+
result_sql = filter_obj.apply(result_sql)
|
|
427
|
+
|
|
428
|
+
return result_sql
|
|
429
|
+
|
|
430
|
+
def _split_script_statements(self, script: str, strip_trailing_semicolon: bool = False) -> "list[str]":
|
|
431
|
+
"""Split a SQL script into individual statements."""
|
|
432
|
+
# Simple splitting on semicolon - could be enhanced with proper SQL parsing
|
|
433
|
+
statements = [stmt.strip() for stmt in script.split(";")]
|
|
434
|
+
return [stmt for stmt in statements if stmt]
|
|
435
|
+
|
|
436
|
+
@staticmethod
|
|
437
|
+
def _parse_asyncpg_status(status: str) -> int:
|
|
438
|
+
"""Parse AsyncPG status string to extract row count.
|
|
485
439
|
|
|
486
440
|
Args:
|
|
487
|
-
|
|
488
|
-
parameters: Query parameters. Can be data or a StatementFilter.
|
|
489
|
-
*filters: Statement filters to apply.
|
|
490
|
-
connection: Optional connection to use.
|
|
491
|
-
schema_type: Optional schema class for the result.
|
|
492
|
-
**kwargs: Additional keyword arguments.
|
|
441
|
+
status: Status string like "INSERT 0 1", "UPDATE 3", "DELETE 2"
|
|
493
442
|
|
|
494
443
|
Returns:
|
|
495
|
-
|
|
444
|
+
Number of affected rows, or 0 if cannot parse
|
|
496
445
|
"""
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
parameters = parameters if parameters is not None else ()
|
|
500
|
-
result = await connection.fetchrow(sql, *parameters) # pyright: ignore
|
|
501
|
-
if result is None:
|
|
502
|
-
return None
|
|
446
|
+
if not status:
|
|
447
|
+
return 0
|
|
503
448
|
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
Args:
|
|
516
|
-
sql: SQL statement.
|
|
517
|
-
parameters: Query parameters.
|
|
518
|
-
connection: Optional connection to use.
|
|
519
|
-
**kwargs: Additional keyword arguments.
|
|
449
|
+
match = ASYNC_PG_STATUS_REGEX.match(status.strip())
|
|
450
|
+
if match:
|
|
451
|
+
# For INSERT: "INSERT 0 5" -> groups: (INSERT, 0, 5)
|
|
452
|
+
# For UPDATE/DELETE: "UPDATE 3" -> groups: (UPDATE, None, 3)
|
|
453
|
+
groups = match.groups()
|
|
454
|
+
if len(groups) >= EXPECTED_REGEX_GROUPS:
|
|
455
|
+
try:
|
|
456
|
+
# The last group is always the row count
|
|
457
|
+
return int(groups[-1])
|
|
458
|
+
except (ValueError, IndexError):
|
|
459
|
+
pass
|
|
520
460
|
|
|
521
|
-
|
|
522
|
-
Status message for the operation.
|
|
523
|
-
"""
|
|
524
|
-
connection = self._connection(connection)
|
|
525
|
-
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
|
|
526
|
-
parameters = parameters if parameters is not None else ()
|
|
527
|
-
return await connection.execute(sql, *parameters) # pyright: ignore
|
|
528
|
-
|
|
529
|
-
def _connection(self, connection: "Optional[AsyncpgConnection]" = None) -> "AsyncpgConnection":
|
|
530
|
-
"""Return the connection to use. If None, use the default connection."""
|
|
531
|
-
return connection if connection is not None else self.connection
|
|
461
|
+
return 0
|