sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/sql.py
CHANGED
|
@@ -1,28 +1,63 @@
|
|
|
1
1
|
"""SQL statement handling with centralized parameter management."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from
|
|
3
|
+
import operator
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
5
6
|
|
|
6
7
|
import sqlglot
|
|
7
8
|
import sqlglot.expressions as exp
|
|
8
|
-
from sqlglot.dialects.dialect import DialectType
|
|
9
9
|
from sqlglot.errors import ParseError
|
|
10
|
+
from typing_extensions import TypeAlias
|
|
10
11
|
|
|
11
|
-
from sqlspec.exceptions import RiskLevel, SQLValidationError
|
|
12
|
+
from sqlspec.exceptions import RiskLevel, SQLParsingError, SQLValidationError
|
|
12
13
|
from sqlspec.statement.filters import StatementFilter
|
|
13
|
-
from sqlspec.statement.parameters import
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
14
|
+
from sqlspec.statement.parameters import (
|
|
15
|
+
SQLGLOT_INCOMPATIBLE_STYLES,
|
|
16
|
+
ParameterConverter,
|
|
17
|
+
ParameterStyle,
|
|
18
|
+
ParameterValidator,
|
|
19
|
+
)
|
|
20
|
+
from sqlspec.statement.pipelines import SQLProcessingContext, StatementPipeline
|
|
21
|
+
from sqlspec.statement.pipelines.transformers import CommentAndHintRemover, ParameterizeLiterals
|
|
17
22
|
from sqlspec.statement.pipelines.validators import DMLSafetyValidator, ParameterStyleValidator
|
|
18
|
-
from sqlspec.typing import is_dict
|
|
19
23
|
from sqlspec.utils.logging import get_logger
|
|
24
|
+
from sqlspec.utils.type_guards import (
|
|
25
|
+
can_append_to_statement,
|
|
26
|
+
can_extract_parameters,
|
|
27
|
+
has_parameter_value,
|
|
28
|
+
has_risk_level,
|
|
29
|
+
is_dict,
|
|
30
|
+
is_expression,
|
|
31
|
+
is_statement_filter,
|
|
32
|
+
supports_limit,
|
|
33
|
+
supports_offset,
|
|
34
|
+
supports_order_by,
|
|
35
|
+
supports_where,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from sqlglot.dialects.dialect import DialectType
|
|
40
|
+
|
|
41
|
+
from sqlspec.statement.parameters import ParameterNormalizationState
|
|
20
42
|
|
|
21
43
|
__all__ = ("SQL", "SQLConfig", "Statement")
|
|
22
44
|
|
|
23
45
|
logger = get_logger("sqlspec.statement")
|
|
24
46
|
|
|
25
|
-
Statement = Union[str, exp.Expression, "SQL"]
|
|
47
|
+
Statement: TypeAlias = Union[str, exp.Expression, "SQL"]
|
|
48
|
+
|
|
49
|
+
# Parameter naming constants
|
|
50
|
+
PARAM_PREFIX = "param_"
|
|
51
|
+
POS_PARAM_PREFIX = "pos_param_"
|
|
52
|
+
KW_POS_PARAM_PREFIX = "kw_pos_param_"
|
|
53
|
+
ARG_PREFIX = "arg_"
|
|
54
|
+
|
|
55
|
+
# Cache and limit constants
|
|
56
|
+
DEFAULT_CACHE_SIZE = 1000
|
|
57
|
+
|
|
58
|
+
# Oracle/Colon style parameter constants
|
|
59
|
+
COLON_PARAM_ONE = "1"
|
|
60
|
+
COLON_PARAM_MIN_INDEX = 1
|
|
26
61
|
|
|
27
62
|
|
|
28
63
|
@dataclass
|
|
@@ -39,9 +74,30 @@ class _ProcessedState:
|
|
|
39
74
|
|
|
40
75
|
@dataclass
|
|
41
76
|
class SQLConfig:
|
|
42
|
-
"""Configuration for SQL statement behavior.
|
|
77
|
+
"""Configuration for SQL statement behavior.
|
|
78
|
+
|
|
79
|
+
Uses conservative defaults that prioritize compatibility and robustness
|
|
80
|
+
over strict enforcement, making it easier to work with diverse SQL dialects
|
|
81
|
+
and complex queries.
|
|
82
|
+
|
|
83
|
+
Component Lists:
|
|
84
|
+
transformers: Optional list of SQL transformers for explicit staging
|
|
85
|
+
validators: Optional list of SQL validators for explicit staging
|
|
86
|
+
analyzers: Optional list of SQL analyzers for explicit staging
|
|
87
|
+
|
|
88
|
+
Configuration Options:
|
|
89
|
+
parameter_converter: Handles parameter style conversions
|
|
90
|
+
parameter_validator: Validates parameter usage and styles
|
|
91
|
+
analysis_cache_size: Cache size for analysis results
|
|
92
|
+
input_sql_had_placeholders: Populated by SQL.__init__ to track original SQL state
|
|
93
|
+
dialect: SQL dialect to use for parsing and generation
|
|
94
|
+
|
|
95
|
+
Parameter Style Configuration:
|
|
96
|
+
allowed_parameter_styles: Allowed parameter styles (e.g., ('qmark', 'named_colon'))
|
|
97
|
+
target_parameter_style: Target parameter style for SQL generation
|
|
98
|
+
allow_mixed_parameter_styles: Whether to allow mixing parameter styles in same query
|
|
99
|
+
"""
|
|
43
100
|
|
|
44
|
-
# Behavior flags
|
|
45
101
|
enable_parsing: bool = True
|
|
46
102
|
enable_validation: bool = True
|
|
47
103
|
enable_transformations: bool = True
|
|
@@ -49,29 +105,23 @@ class SQLConfig:
|
|
|
49
105
|
enable_normalization: bool = True
|
|
50
106
|
strict_mode: bool = False
|
|
51
107
|
cache_parsed_expression: bool = True
|
|
108
|
+
parse_errors_as_warnings: bool = True
|
|
52
109
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
analyzers: Optional[list[Any]] = None
|
|
110
|
+
transformers: "Optional[list[Any]]" = None
|
|
111
|
+
validators: "Optional[list[Any]]" = None
|
|
112
|
+
analyzers: "Optional[list[Any]]" = None
|
|
57
113
|
|
|
58
|
-
# Other configs
|
|
59
114
|
parameter_converter: ParameterConverter = field(default_factory=ParameterConverter)
|
|
60
115
|
parameter_validator: ParameterValidator = field(default_factory=ParameterValidator)
|
|
61
116
|
analysis_cache_size: int = 1000
|
|
62
|
-
input_sql_had_placeholders: bool = False
|
|
63
|
-
|
|
64
|
-
# Parameter style configuration
|
|
65
|
-
allowed_parameter_styles: Optional[tuple[str, ...]] = None
|
|
66
|
-
"""Allowed parameter styles for this SQL configuration (e.g., ('qmark', 'named_colon'))."""
|
|
67
|
-
|
|
68
|
-
target_parameter_style: Optional[str] = None
|
|
69
|
-
"""Target parameter style for SQL generation."""
|
|
117
|
+
input_sql_had_placeholders: bool = False
|
|
118
|
+
dialect: "Optional[DialectType]" = None
|
|
70
119
|
|
|
120
|
+
allowed_parameter_styles: "Optional[tuple[str, ...]]" = None
|
|
121
|
+
target_parameter_style: "Optional[str]" = None
|
|
71
122
|
allow_mixed_parameter_styles: bool = False
|
|
72
|
-
"""Whether to allow mixing named and positional parameters in same query."""
|
|
73
123
|
|
|
74
|
-
def validate_parameter_style(self, style: Union[ParameterStyle, str]) -> bool:
|
|
124
|
+
def validate_parameter_style(self, style: "Union[ParameterStyle, str]") -> bool:
|
|
75
125
|
"""Check if a parameter style is allowed.
|
|
76
126
|
|
|
77
127
|
Args:
|
|
@@ -81,7 +131,7 @@ class SQLConfig:
|
|
|
81
131
|
True if the style is allowed, False otherwise
|
|
82
132
|
"""
|
|
83
133
|
if self.allowed_parameter_styles is None:
|
|
84
|
-
return True
|
|
134
|
+
return True
|
|
85
135
|
style_str = str(style)
|
|
86
136
|
return style_str in self.allowed_parameter_styles
|
|
87
137
|
|
|
@@ -91,36 +141,23 @@ class SQLConfig:
|
|
|
91
141
|
Returns:
|
|
92
142
|
StatementPipeline configured with transformers, validators, and analyzers
|
|
93
143
|
"""
|
|
94
|
-
# Import here to avoid circular dependencies
|
|
95
|
-
|
|
96
|
-
# Create transformers based on config
|
|
97
144
|
transformers = []
|
|
98
145
|
if self.transformers is not None:
|
|
99
|
-
# Use explicit transformers if provided
|
|
100
146
|
transformers = list(self.transformers)
|
|
101
|
-
# Use default transformers
|
|
102
147
|
elif self.enable_transformations:
|
|
103
|
-
# Use target_parameter_style if available, otherwise default to "?"
|
|
104
148
|
placeholder_style = self.target_parameter_style or "?"
|
|
105
|
-
transformers = [
|
|
149
|
+
transformers = [CommentAndHintRemover(), ParameterizeLiterals(placeholder_style=placeholder_style)]
|
|
106
150
|
|
|
107
|
-
# Create validators based on config
|
|
108
151
|
validators = []
|
|
109
152
|
if self.validators is not None:
|
|
110
|
-
# Use explicit validators if provided
|
|
111
153
|
validators = list(self.validators)
|
|
112
|
-
# Use default validators
|
|
113
154
|
elif self.enable_validation:
|
|
114
155
|
validators = [ParameterStyleValidator(fail_on_violation=self.strict_mode), DMLSafetyValidator()]
|
|
115
156
|
|
|
116
|
-
# Create analyzers based on config
|
|
117
157
|
analyzers = []
|
|
118
158
|
if self.analyzers is not None:
|
|
119
|
-
# Use explicit analyzers if provided
|
|
120
159
|
analyzers = list(self.analyzers)
|
|
121
|
-
# Use default analyzers
|
|
122
160
|
elif self.enable_analysis:
|
|
123
|
-
# Currently no default analyzers
|
|
124
161
|
analyzers = []
|
|
125
162
|
|
|
126
163
|
return StatementPipeline(transformers=transformers, validators=validators, analyzers=analyzers)
|
|
@@ -139,36 +176,39 @@ class SQL:
|
|
|
139
176
|
"""
|
|
140
177
|
|
|
141
178
|
__slots__ = (
|
|
142
|
-
"_builder_result_type",
|
|
143
|
-
"_config",
|
|
144
|
-
"_dialect",
|
|
145
|
-
"_filters",
|
|
146
|
-
"_is_many",
|
|
147
|
-
"_is_script",
|
|
148
|
-
"_named_params",
|
|
149
|
-
"_original_parameters",
|
|
150
|
-
"_original_sql",
|
|
151
|
-
"
|
|
152
|
-
"
|
|
153
|
-
"
|
|
154
|
-
"
|
|
155
|
-
"
|
|
156
|
-
"
|
|
179
|
+
"_builder_result_type",
|
|
180
|
+
"_config",
|
|
181
|
+
"_dialect",
|
|
182
|
+
"_filters",
|
|
183
|
+
"_is_many",
|
|
184
|
+
"_is_script",
|
|
185
|
+
"_named_params",
|
|
186
|
+
"_original_parameters",
|
|
187
|
+
"_original_sql",
|
|
188
|
+
"_parameter_normalization_state",
|
|
189
|
+
"_placeholder_mapping",
|
|
190
|
+
"_positional_params",
|
|
191
|
+
"_processed_state",
|
|
192
|
+
"_processing_context",
|
|
193
|
+
"_raw_sql",
|
|
194
|
+
"_statement",
|
|
157
195
|
)
|
|
158
196
|
|
|
159
197
|
def __init__(
|
|
160
198
|
self,
|
|
161
|
-
statement: Union[str, exp.Expression,
|
|
162
|
-
*parameters: Union[Any, StatementFilter, list[Union[Any, StatementFilter]]],
|
|
163
|
-
_dialect: DialectType = None,
|
|
164
|
-
_config: Optional[SQLConfig] = None,
|
|
165
|
-
_builder_result_type: Optional[type] = None,
|
|
166
|
-
_existing_state: Optional[dict[str, Any]] = None,
|
|
199
|
+
statement: "Union[str, exp.Expression, 'SQL']",
|
|
200
|
+
*parameters: "Union[Any, StatementFilter, list[Union[Any, StatementFilter]]]",
|
|
201
|
+
_dialect: "DialectType" = None,
|
|
202
|
+
_config: "Optional[SQLConfig]" = None,
|
|
203
|
+
_builder_result_type: "Optional[type]" = None,
|
|
204
|
+
_existing_state: "Optional[dict[str, Any]]" = None,
|
|
167
205
|
**kwargs: Any,
|
|
168
206
|
) -> None:
|
|
169
207
|
"""Initialize SQL with centralized parameter management."""
|
|
208
|
+
if "config" in kwargs and _config is None:
|
|
209
|
+
_config = kwargs.pop("config")
|
|
170
210
|
self._config = _config or SQLConfig()
|
|
171
|
-
self._dialect = _dialect
|
|
211
|
+
self._dialect = _dialect or (self._config.dialect if self._config else None)
|
|
172
212
|
self._builder_result_type = _builder_result_type
|
|
173
213
|
self._processed_state: Optional[_ProcessedState] = None
|
|
174
214
|
self._processing_context: Optional[SQLProcessingContext] = None
|
|
@@ -180,6 +220,7 @@ class SQL:
|
|
|
180
220
|
self._original_parameters: Any = None
|
|
181
221
|
self._original_sql: str = ""
|
|
182
222
|
self._placeholder_mapping: dict[str, Union[str, int]] = {}
|
|
223
|
+
self._parameter_normalization_state: Optional[ParameterNormalizationState] = None
|
|
183
224
|
self._is_many: bool = False
|
|
184
225
|
self._is_script: bool = False
|
|
185
226
|
|
|
@@ -191,13 +232,17 @@ class SQL:
|
|
|
191
232
|
if _existing_state:
|
|
192
233
|
self._load_from_existing_state(_existing_state)
|
|
193
234
|
|
|
194
|
-
if not isinstance(statement, SQL):
|
|
235
|
+
if not isinstance(statement, SQL) and not _existing_state:
|
|
195
236
|
self._set_original_parameters(*parameters)
|
|
196
237
|
|
|
197
238
|
self._process_parameters(*parameters, **kwargs)
|
|
198
239
|
|
|
199
240
|
def _init_from_sql_object(
|
|
200
|
-
self,
|
|
241
|
+
self,
|
|
242
|
+
statement: "SQL",
|
|
243
|
+
dialect: "DialectType",
|
|
244
|
+
config: "Optional[SQLConfig]",
|
|
245
|
+
builder_result_type: "Optional[type]",
|
|
201
246
|
) -> None:
|
|
202
247
|
"""Initialize attributes from an existing SQL object."""
|
|
203
248
|
self._statement = statement._statement
|
|
@@ -210,24 +255,21 @@ class SQL:
|
|
|
210
255
|
self._original_parameters = statement._original_parameters
|
|
211
256
|
self._original_sql = statement._original_sql
|
|
212
257
|
self._placeholder_mapping = statement._placeholder_mapping.copy()
|
|
258
|
+
self._parameter_normalization_state = statement._parameter_normalization_state
|
|
213
259
|
self._positional_params.extend(statement._positional_params)
|
|
214
260
|
self._named_params.update(statement._named_params)
|
|
215
261
|
self._filters.extend(statement._filters)
|
|
216
262
|
|
|
217
|
-
def _init_from_str_or_expression(self, statement: Union[str, exp.Expression]) -> None:
|
|
263
|
+
def _init_from_str_or_expression(self, statement: "Union[str, exp.Expression]") -> None:
|
|
218
264
|
"""Initialize attributes from a SQL string or expression."""
|
|
219
265
|
if isinstance(statement, str):
|
|
220
266
|
self._raw_sql = statement
|
|
221
|
-
if self._raw_sql and not self._config.input_sql_had_placeholders:
|
|
222
|
-
param_info = self._config.parameter_validator.extract_parameters(self._raw_sql)
|
|
223
|
-
if param_info:
|
|
224
|
-
self._config = replace(self._config, input_sql_had_placeholders=True)
|
|
225
267
|
self._statement = self._to_expression(statement)
|
|
226
268
|
else:
|
|
227
269
|
self._raw_sql = statement.sql(dialect=self._dialect) # pyright: ignore
|
|
228
270
|
self._statement = statement
|
|
229
271
|
|
|
230
|
-
def _load_from_existing_state(self, existing_state: dict[str, Any]) -> None:
|
|
272
|
+
def _load_from_existing_state(self, existing_state: "dict[str, Any]") -> None:
|
|
231
273
|
"""Load state from a dictionary (used by copy)."""
|
|
232
274
|
self._positional_params = list(existing_state.get("positional_params", self._positional_params))
|
|
233
275
|
self._named_params = dict(existing_state.get("named_params", self._named_params))
|
|
@@ -235,15 +277,16 @@ class SQL:
|
|
|
235
277
|
self._is_many = existing_state.get("is_many", self._is_many)
|
|
236
278
|
self._is_script = existing_state.get("is_script", self._is_script)
|
|
237
279
|
self._raw_sql = existing_state.get("raw_sql", self._raw_sql)
|
|
280
|
+
self._original_parameters = existing_state.get("original_parameters", self._original_parameters)
|
|
238
281
|
|
|
239
282
|
def _set_original_parameters(self, *parameters: Any) -> None:
|
|
240
283
|
"""Store the original parameters for compatibility."""
|
|
241
|
-
if len(parameters) == 1 and
|
|
284
|
+
if len(parameters) == 0 or (len(parameters) == 1 and is_statement_filter(parameters[0])):
|
|
285
|
+
self._original_parameters = None
|
|
286
|
+
elif len(parameters) == 1 and isinstance(parameters[0], (list, tuple)):
|
|
242
287
|
self._original_parameters = parameters[0]
|
|
243
|
-
elif len(parameters) > 1:
|
|
244
|
-
self._original_parameters = parameters
|
|
245
288
|
else:
|
|
246
|
-
self._original_parameters =
|
|
289
|
+
self._original_parameters = parameters
|
|
247
290
|
|
|
248
291
|
def _process_parameters(self, *parameters: Any, **kwargs: Any) -> None:
|
|
249
292
|
"""Process positional and keyword arguments for parameters and filters."""
|
|
@@ -254,7 +297,7 @@ class SQL:
|
|
|
254
297
|
param_value = kwargs.pop("parameters")
|
|
255
298
|
if isinstance(param_value, (list, tuple)):
|
|
256
299
|
self._positional_params.extend(param_value)
|
|
257
|
-
elif
|
|
300
|
+
elif is_dict(param_value):
|
|
258
301
|
self._named_params.update(param_value)
|
|
259
302
|
else:
|
|
260
303
|
self._positional_params.append(param_value)
|
|
@@ -265,7 +308,7 @@ class SQL:
|
|
|
265
308
|
|
|
266
309
|
def _process_parameter_item(self, item: Any) -> None:
|
|
267
310
|
"""Process a single item from the parameters list."""
|
|
268
|
-
if
|
|
311
|
+
if is_statement_filter(item):
|
|
269
312
|
self._filters.append(item)
|
|
270
313
|
pos_params, named_params = self._extract_filter_parameters(item)
|
|
271
314
|
self._positional_params.extend(pos_params)
|
|
@@ -273,7 +316,7 @@ class SQL:
|
|
|
273
316
|
elif isinstance(item, list):
|
|
274
317
|
for sub_item in item:
|
|
275
318
|
self._process_parameter_item(sub_item)
|
|
276
|
-
elif
|
|
319
|
+
elif is_dict(item):
|
|
277
320
|
self._named_params.update(item)
|
|
278
321
|
elif isinstance(item, tuple):
|
|
279
322
|
self._positional_params.extend(item)
|
|
@@ -289,120 +332,255 @@ class SQL:
|
|
|
289
332
|
if self._processed_state is not None:
|
|
290
333
|
return
|
|
291
334
|
|
|
292
|
-
# Get the final expression and parameters after filters
|
|
293
335
|
final_expr, final_params = self._build_final_state()
|
|
336
|
+
has_placeholders = self._detect_placeholders()
|
|
337
|
+
initial_sql_for_context, final_params = self._prepare_context_sql(final_expr, final_params)
|
|
338
|
+
|
|
339
|
+
context = self._create_processing_context(initial_sql_for_context, final_expr, final_params, has_placeholders)
|
|
340
|
+
result = self._run_pipeline(context)
|
|
341
|
+
|
|
342
|
+
processed_sql, merged_params = self._process_pipeline_result(result, final_params, context)
|
|
343
|
+
|
|
344
|
+
self._finalize_processed_state(result, processed_sql, merged_params)
|
|
294
345
|
|
|
295
|
-
|
|
346
|
+
def _detect_placeholders(self) -> bool:
|
|
347
|
+
"""Detect if the raw SQL has placeholders."""
|
|
296
348
|
if self._raw_sql:
|
|
297
349
|
validator = self._config.parameter_validator
|
|
298
350
|
raw_param_info = validator.extract_parameters(self._raw_sql)
|
|
299
351
|
has_placeholders = bool(raw_param_info)
|
|
300
|
-
|
|
301
|
-
|
|
352
|
+
if has_placeholders:
|
|
353
|
+
self._config.input_sql_had_placeholders = True
|
|
354
|
+
return has_placeholders
|
|
355
|
+
return self._config.input_sql_had_placeholders
|
|
356
|
+
|
|
357
|
+
def _prepare_context_sql(self, final_expr: exp.Expression, final_params: Any) -> tuple[str, Any]:
|
|
358
|
+
"""Prepare SQL string and parameters for context."""
|
|
359
|
+
initial_sql_for_context = self._raw_sql or final_expr.sql(dialect=self._dialect or self._config.dialect)
|
|
360
|
+
|
|
361
|
+
if is_expression(final_expr) and self._placeholder_mapping:
|
|
362
|
+
initial_sql_for_context = final_expr.sql(dialect=self._dialect or self._config.dialect)
|
|
363
|
+
if self._placeholder_mapping:
|
|
364
|
+
final_params = self._normalize_parameters(final_params)
|
|
365
|
+
|
|
366
|
+
return initial_sql_for_context, final_params
|
|
367
|
+
|
|
368
|
+
def _normalize_parameters(self, final_params: Any) -> Any:
|
|
369
|
+
"""Normalize parameters based on placeholder mapping."""
|
|
370
|
+
if is_dict(final_params):
|
|
371
|
+
normalized_params = {}
|
|
372
|
+
for placeholder_key, original_name in self._placeholder_mapping.items():
|
|
373
|
+
if str(original_name) in final_params:
|
|
374
|
+
normalized_params[placeholder_key] = final_params[str(original_name)]
|
|
375
|
+
non_oracle_params = {
|
|
376
|
+
key: value
|
|
377
|
+
for key, value in final_params.items()
|
|
378
|
+
if key not in {str(name) for name in self._placeholder_mapping.values()}
|
|
379
|
+
}
|
|
380
|
+
normalized_params.update(non_oracle_params)
|
|
381
|
+
return normalized_params
|
|
382
|
+
if isinstance(final_params, (list, tuple)):
|
|
383
|
+
validator = self._config.parameter_validator
|
|
384
|
+
param_info = validator.extract_parameters(self._raw_sql)
|
|
385
|
+
|
|
386
|
+
all_numeric = all(p.name and p.name.isdigit() for p in param_info)
|
|
387
|
+
|
|
388
|
+
if all_numeric:
|
|
389
|
+
normalized_params = {}
|
|
302
390
|
|
|
303
|
-
|
|
304
|
-
if has_placeholders and not self._config.input_sql_had_placeholders:
|
|
305
|
-
self._config = replace(self._config, input_sql_had_placeholders=True)
|
|
391
|
+
min_param_num = min(int(p.name) for p in param_info if p.name)
|
|
306
392
|
|
|
307
|
-
|
|
393
|
+
for i, param in enumerate(final_params):
|
|
394
|
+
param_num = str(i + min_param_num)
|
|
395
|
+
normalized_params[param_num] = param
|
|
396
|
+
|
|
397
|
+
return normalized_params
|
|
398
|
+
normalized_params = {}
|
|
399
|
+
for i, param in enumerate(final_params):
|
|
400
|
+
if i < len(param_info):
|
|
401
|
+
placeholder_key = f"{PARAM_PREFIX}{param_info[i].ordinal}"
|
|
402
|
+
normalized_params[placeholder_key] = param
|
|
403
|
+
return normalized_params
|
|
404
|
+
return final_params
|
|
405
|
+
|
|
406
|
+
def _create_processing_context(
|
|
407
|
+
self, initial_sql_for_context: str, final_expr: exp.Expression, final_params: Any, has_placeholders: bool
|
|
408
|
+
) -> SQLProcessingContext:
|
|
409
|
+
"""Create SQL processing context."""
|
|
308
410
|
context = SQLProcessingContext(
|
|
309
|
-
initial_sql_string=
|
|
310
|
-
dialect=self._dialect,
|
|
411
|
+
initial_sql_string=initial_sql_for_context,
|
|
412
|
+
dialect=self._dialect or self._config.dialect,
|
|
311
413
|
config=self._config,
|
|
312
|
-
current_expression=final_expr,
|
|
313
414
|
initial_expression=final_expr,
|
|
415
|
+
current_expression=final_expr,
|
|
314
416
|
merged_parameters=final_params,
|
|
315
|
-
input_sql_had_placeholders=has_placeholders,
|
|
417
|
+
input_sql_had_placeholders=has_placeholders or self._config.input_sql_had_placeholders,
|
|
316
418
|
)
|
|
317
419
|
|
|
318
|
-
|
|
420
|
+
if self._placeholder_mapping:
|
|
421
|
+
context.extra_info["placeholder_map"] = self._placeholder_mapping
|
|
422
|
+
|
|
423
|
+
# Set normalization state if available
|
|
424
|
+
if self._parameter_normalization_state:
|
|
425
|
+
context.parameter_normalization = self._parameter_normalization_state
|
|
426
|
+
|
|
319
427
|
validator = self._config.parameter_validator
|
|
320
428
|
context.parameter_info = validator.extract_parameters(context.initial_sql_string)
|
|
321
429
|
|
|
322
|
-
|
|
430
|
+
return context
|
|
431
|
+
|
|
432
|
+
def _run_pipeline(self, context: SQLProcessingContext) -> Any:
|
|
433
|
+
"""Run the SQL processing pipeline."""
|
|
323
434
|
pipeline = self._config.get_statement_pipeline()
|
|
324
435
|
result = pipeline.execute_pipeline(context)
|
|
325
|
-
|
|
326
|
-
# Store the processing context for later use
|
|
327
436
|
self._processing_context = result.context
|
|
437
|
+
return result
|
|
328
438
|
|
|
329
|
-
|
|
439
|
+
def _process_pipeline_result(
|
|
440
|
+
self, result: Any, final_params: Any, context: SQLProcessingContext
|
|
441
|
+
) -> tuple[str, Any]:
|
|
442
|
+
"""Process the result from the pipeline."""
|
|
330
443
|
processed_expr = result.expression
|
|
444
|
+
|
|
331
445
|
if isinstance(processed_expr, exp.Anonymous):
|
|
332
446
|
processed_sql = self._raw_sql or context.initial_sql_string
|
|
333
447
|
else:
|
|
334
|
-
processed_sql = processed_expr.sql(dialect=self._dialect, comments=False)
|
|
448
|
+
processed_sql = processed_expr.sql(dialect=self._dialect or self._config.dialect, comments=False)
|
|
335
449
|
logger.debug("Processed expression SQL: '%s'", processed_sql)
|
|
336
450
|
|
|
337
|
-
# Check if we need to denormalize pyformat placeholders
|
|
338
451
|
if self._placeholder_mapping and self._original_sql:
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
452
|
+
processed_sql, result = self._denormalize_sql(processed_sql, result)
|
|
453
|
+
|
|
454
|
+
merged_params = self._merge_pipeline_parameters(result, final_params)
|
|
455
|
+
|
|
456
|
+
return processed_sql, merged_params
|
|
457
|
+
|
|
458
|
+
def _denormalize_sql(self, processed_sql: str, result: Any) -> tuple[str, Any]:
|
|
459
|
+
"""Denormalize SQL back to original parameter style."""
|
|
460
|
+
|
|
461
|
+
original_sql = self._original_sql
|
|
462
|
+
param_info = self._config.parameter_validator.extract_parameters(original_sql)
|
|
463
|
+
target_styles = {p.style for p in param_info}
|
|
464
|
+
|
|
465
|
+
logger.debug(
|
|
466
|
+
"Denormalizing SQL: before='%s', original='%s', styles=%s", processed_sql, original_sql, target_styles
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
if ParameterStyle.POSITIONAL_PYFORMAT in target_styles:
|
|
470
|
+
processed_sql = self._config.parameter_converter._convert_sql_placeholders(
|
|
471
|
+
processed_sql, param_info, ParameterStyle.POSITIONAL_PYFORMAT
|
|
472
|
+
)
|
|
473
|
+
logger.debug("Denormalized SQL to: '%s'", processed_sql)
|
|
474
|
+
elif ParameterStyle.NAMED_PYFORMAT in target_styles:
|
|
475
|
+
processed_sql = self._config.parameter_converter._convert_sql_placeholders(
|
|
476
|
+
processed_sql, param_info, ParameterStyle.NAMED_PYFORMAT
|
|
477
|
+
)
|
|
478
|
+
logger.debug("Denormalized SQL to: '%s'", processed_sql)
|
|
479
|
+
# Also denormalize the parameters back to their original names
|
|
480
|
+
if (
|
|
481
|
+
self._placeholder_mapping
|
|
482
|
+
and result.context.merged_parameters
|
|
483
|
+
and is_dict(result.context.merged_parameters)
|
|
484
|
+
):
|
|
485
|
+
result.context.merged_parameters = self._denormalize_pyformat_params(result.context.merged_parameters)
|
|
486
|
+
elif ParameterStyle.POSITIONAL_COLON in target_styles:
|
|
487
|
+
processed_param_info = self._config.parameter_validator.extract_parameters(processed_sql)
|
|
488
|
+
has_param_placeholders = any(p.name and p.name.startswith(PARAM_PREFIX) for p in processed_param_info)
|
|
489
|
+
|
|
490
|
+
if has_param_placeholders:
|
|
491
|
+
logger.debug("Skipping denormalization for param_N placeholders")
|
|
366
492
|
else:
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
bool(self._placeholder_mapping),
|
|
370
|
-
bool(self._original_sql),
|
|
493
|
+
processed_sql = self._config.parameter_converter._convert_sql_placeholders(
|
|
494
|
+
processed_sql, param_info, ParameterStyle.POSITIONAL_COLON
|
|
371
495
|
)
|
|
496
|
+
logger.debug("Denormalized SQL to: '%s'", processed_sql)
|
|
497
|
+
if (
|
|
498
|
+
self._placeholder_mapping
|
|
499
|
+
and result.context.merged_parameters
|
|
500
|
+
and is_dict(result.context.merged_parameters)
|
|
501
|
+
):
|
|
502
|
+
result.context.merged_parameters = self._denormalize_colon_params(result.context.merged_parameters)
|
|
503
|
+
else:
|
|
504
|
+
logger.debug(
|
|
505
|
+
"No denormalization needed: mapping=%s, original=%s",
|
|
506
|
+
bool(self._placeholder_mapping),
|
|
507
|
+
bool(self._original_sql),
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
return processed_sql, result
|
|
372
511
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
#
|
|
376
|
-
#
|
|
377
|
-
if
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
512
|
+
def _denormalize_colon_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
|
|
513
|
+
"""Denormalize colon-style parameters back to numeric format."""
|
|
514
|
+
# For positional colon style, all params should have numeric keys
|
|
515
|
+
# Just return the params as-is if they already have the right format
|
|
516
|
+
if all(key.isdigit() for key in params):
|
|
517
|
+
return params
|
|
518
|
+
|
|
519
|
+
# For positional colon, we need ALL parameters in the final result
|
|
520
|
+
# This includes both user parameters and extracted literals
|
|
521
|
+
# We should NOT filter out extracted parameters (param_0, param_1, etc)
|
|
522
|
+
# because they need to be included in the final parameter conversion
|
|
523
|
+
return params
|
|
524
|
+
|
|
525
|
+
def _denormalize_pyformat_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
|
|
526
|
+
"""Denormalize pyformat parameters back to their original names."""
|
|
527
|
+
denormalized_params = {}
|
|
528
|
+
for placeholder_key, original_name in self._placeholder_mapping.items():
|
|
529
|
+
if placeholder_key in params:
|
|
530
|
+
# For pyformat, the original_name is the actual parameter name (e.g., 'max_value')
|
|
531
|
+
denormalized_params[str(original_name)] = params[placeholder_key]
|
|
532
|
+
# Include any parameters that weren't normalized
|
|
533
|
+
non_normalized_params = {key: value for key, value in params.items() if not key.startswith(PARAM_PREFIX)}
|
|
534
|
+
denormalized_params.update(non_normalized_params)
|
|
535
|
+
return denormalized_params
|
|
536
|
+
|
|
537
|
+
def _merge_pipeline_parameters(self, result: Any, final_params: Any) -> Any:
|
|
538
|
+
"""Merge parameters from the pipeline processing."""
|
|
539
|
+
merged_params = result.context.merged_parameters
|
|
540
|
+
|
|
541
|
+
# If we have extracted parameters from the pipeline, only merge them if:
|
|
542
|
+
# 1. We don't already have parameters in merged_params, OR
|
|
543
|
+
# 2. The original params were None and we need to use the extracted ones
|
|
544
|
+
if result.context.extracted_parameters_from_pipeline:
|
|
545
|
+
if merged_params is None:
|
|
546
|
+
# No existing parameters - use the extracted ones
|
|
385
547
|
merged_params = result.context.extracted_parameters_from_pipeline
|
|
386
|
-
|
|
387
|
-
#
|
|
388
|
-
merged_params =
|
|
548
|
+
elif merged_params == final_params and final_params is None:
|
|
549
|
+
# Both are None, use extracted parameters
|
|
550
|
+
merged_params = result.context.extracted_parameters_from_pipeline
|
|
551
|
+
elif merged_params != result.context.extracted_parameters_from_pipeline:
|
|
552
|
+
# Only merge if the extracted parameters are different from what we already have
|
|
553
|
+
# This prevents the duplication issue where the same parameters get added twice
|
|
554
|
+
if is_dict(merged_params):
|
|
555
|
+
for i, param in enumerate(result.context.extracted_parameters_from_pipeline):
|
|
556
|
+
param_name = f"{PARAM_PREFIX}{i}"
|
|
557
|
+
merged_params[param_name] = param
|
|
558
|
+
elif isinstance(merged_params, (list, tuple)):
|
|
559
|
+
# Only extend if we don't already have these parameters
|
|
560
|
+
# Convert to list and extend with extracted parameters
|
|
561
|
+
if isinstance(merged_params, tuple):
|
|
562
|
+
merged_params = list(merged_params)
|
|
563
|
+
merged_params.extend(result.context.extracted_parameters_from_pipeline)
|
|
564
|
+
else:
|
|
565
|
+
# Single parameter case - convert to list with original + extracted
|
|
566
|
+
merged_params = [merged_params, *list(result.context.extracted_parameters_from_pipeline)]
|
|
389
567
|
|
|
390
|
-
|
|
568
|
+
return merged_params
|
|
569
|
+
|
|
570
|
+
def _finalize_processed_state(self, result: Any, processed_sql: str, merged_params: Any) -> None:
|
|
571
|
+
"""Finalize the processed state."""
|
|
391
572
|
self._processed_state = _ProcessedState(
|
|
392
|
-
processed_expression=
|
|
573
|
+
processed_expression=result.expression,
|
|
393
574
|
processed_sql=processed_sql,
|
|
394
575
|
merged_parameters=merged_params,
|
|
395
576
|
validation_errors=list(result.context.validation_errors),
|
|
396
|
-
analysis_results={},
|
|
397
|
-
transformation_results={},
|
|
577
|
+
analysis_results={},
|
|
578
|
+
transformation_results={},
|
|
398
579
|
)
|
|
399
580
|
|
|
400
|
-
# Check strict mode
|
|
401
581
|
if self._config.strict_mode and self._processed_state.validation_errors:
|
|
402
|
-
# Find the highest risk error
|
|
403
582
|
highest_risk_error = max(
|
|
404
|
-
self._processed_state.validation_errors,
|
|
405
|
-
key=lambda e: e.risk_level.value if hasattr(e, "risk_level") else 0,
|
|
583
|
+
self._processed_state.validation_errors, key=lambda e: e.risk_level.value if has_risk_level(e) else 0
|
|
406
584
|
)
|
|
407
585
|
raise SQLValidationError(
|
|
408
586
|
message=highest_risk_error.message,
|
|
@@ -410,81 +588,85 @@ class SQL:
|
|
|
410
588
|
risk_level=getattr(highest_risk_error, "risk_level", RiskLevel.HIGH),
|
|
411
589
|
)
|
|
412
590
|
|
|
413
|
-
def _to_expression(self, statement: Union[str, exp.Expression]) -> exp.Expression:
|
|
591
|
+
def _to_expression(self, statement: "Union[str, exp.Expression]") -> exp.Expression:
|
|
414
592
|
"""Convert string to sqlglot expression."""
|
|
415
|
-
if
|
|
593
|
+
if is_expression(statement):
|
|
416
594
|
return statement
|
|
417
595
|
|
|
418
|
-
|
|
419
|
-
if not statement or not statement.strip():
|
|
420
|
-
# Return an empty select instead of Anonymous for empty strings
|
|
596
|
+
if not statement or (isinstance(statement, str) and not statement.strip()):
|
|
421
597
|
return exp.Select()
|
|
422
598
|
|
|
423
|
-
# Check if parsing is disabled
|
|
424
599
|
if not self._config.enable_parsing:
|
|
425
|
-
# Return an anonymous expression that preserves the raw SQL
|
|
426
600
|
return exp.Anonymous(this=statement)
|
|
427
601
|
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
602
|
+
if not isinstance(statement, str):
|
|
603
|
+
return exp.Anonymous(this="")
|
|
431
604
|
validator = self._config.parameter_validator
|
|
432
605
|
param_info = validator.extract_parameters(statement)
|
|
433
606
|
|
|
434
|
-
# Check if
|
|
435
|
-
|
|
436
|
-
p.style in {ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT} for p in param_info
|
|
437
|
-
)
|
|
607
|
+
# Check if normalization is needed
|
|
608
|
+
needs_normalization = any(p.style in SQLGLOT_INCOMPATIBLE_STYLES for p in param_info)
|
|
438
609
|
|
|
439
610
|
normalized_sql = statement
|
|
440
611
|
placeholder_mapping: dict[str, Any] = {}
|
|
441
612
|
|
|
442
|
-
if
|
|
443
|
-
# Normalize pyformat placeholders to named placeholders for SQLGlot
|
|
613
|
+
if needs_normalization:
|
|
444
614
|
converter = self._config.parameter_converter
|
|
445
615
|
normalized_sql, placeholder_mapping = converter._transform_sql_for_parsing(statement, param_info)
|
|
446
|
-
# Store the original SQL before normalization
|
|
447
616
|
self._original_sql = statement
|
|
448
617
|
self._placeholder_mapping = placeholder_mapping
|
|
449
618
|
|
|
619
|
+
# Create normalization state
|
|
620
|
+
from sqlspec.statement.parameters import ParameterNormalizationState
|
|
621
|
+
|
|
622
|
+
self._parameter_normalization_state = ParameterNormalizationState(
|
|
623
|
+
was_normalized=True,
|
|
624
|
+
original_styles=list({p.style for p in param_info}),
|
|
625
|
+
normalized_style=ParameterStyle.NAMED_COLON,
|
|
626
|
+
placeholder_map=placeholder_mapping,
|
|
627
|
+
original_param_info=param_info,
|
|
628
|
+
)
|
|
629
|
+
else:
|
|
630
|
+
self._parameter_normalization_state = None
|
|
631
|
+
|
|
450
632
|
try:
|
|
451
|
-
# Parse with sqlglot
|
|
452
633
|
expressions = sqlglot.parse(normalized_sql, dialect=self._dialect) # pyright: ignore
|
|
453
634
|
if not expressions:
|
|
454
|
-
# Empty statement
|
|
455
635
|
return exp.Anonymous(this=statement)
|
|
456
636
|
first_expr = expressions[0]
|
|
457
637
|
if first_expr is None:
|
|
458
|
-
# Could not parse
|
|
459
638
|
return exp.Anonymous(this=statement)
|
|
460
639
|
|
|
461
640
|
except ParseError as e:
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
641
|
+
if getattr(self._config, "parse_errors_as_warnings", False):
|
|
642
|
+
logger.warning(
|
|
643
|
+
"Failed to parse SQL, returning Anonymous expression.", extra={"sql": statement, "error": str(e)}
|
|
644
|
+
)
|
|
645
|
+
return exp.Anonymous(this=statement)
|
|
646
|
+
|
|
647
|
+
msg = f"Failed to parse SQL: {statement}"
|
|
648
|
+
raise SQLParsingError(msg) from e
|
|
465
649
|
return first_expr
|
|
466
650
|
|
|
467
651
|
@staticmethod
|
|
468
652
|
def _extract_filter_parameters(filter_obj: StatementFilter) -> tuple[list[Any], dict[str, Any]]:
|
|
469
653
|
"""Extract parameters from a filter object."""
|
|
470
|
-
if
|
|
654
|
+
if can_extract_parameters(filter_obj):
|
|
471
655
|
return filter_obj.extract_parameters()
|
|
472
|
-
# Fallback for filters that don't implement the new method yet
|
|
473
656
|
return [], {}
|
|
474
657
|
|
|
475
658
|
def copy(
|
|
476
659
|
self,
|
|
477
|
-
statement: Optional[Union[str, exp.Expression]] = None,
|
|
478
|
-
parameters: Optional[Any] = None,
|
|
479
|
-
dialect: DialectType = None,
|
|
480
|
-
config: Optional[SQLConfig] = None,
|
|
660
|
+
statement: "Optional[Union[str, exp.Expression]]" = None,
|
|
661
|
+
parameters: "Optional[Any]" = None,
|
|
662
|
+
dialect: "DialectType" = None,
|
|
663
|
+
config: "Optional[SQLConfig]" = None,
|
|
481
664
|
**kwargs: Any,
|
|
482
665
|
) -> "SQL":
|
|
483
666
|
"""Create a copy with optional modifications.
|
|
484
667
|
|
|
485
668
|
This is the primary method for creating modified SQL objects.
|
|
486
669
|
"""
|
|
487
|
-
# Prepare existing state
|
|
488
670
|
existing_state = {
|
|
489
671
|
"positional_params": list(self._positional_params),
|
|
490
672
|
"named_params": dict(self._named_params),
|
|
@@ -493,25 +675,22 @@ class SQL:
|
|
|
493
675
|
"is_script": self._is_script,
|
|
494
676
|
"raw_sql": self._raw_sql,
|
|
495
677
|
}
|
|
678
|
+
existing_state["original_parameters"] = self._original_parameters
|
|
496
679
|
|
|
497
|
-
# Create new instance
|
|
498
680
|
new_statement = statement if statement is not None else self._statement
|
|
499
681
|
new_dialect = dialect if dialect is not None else self._dialect
|
|
500
682
|
new_config = config if config is not None else self._config
|
|
501
683
|
|
|
502
|
-
# If parameters are explicitly provided, they replace existing ones
|
|
503
684
|
if parameters is not None:
|
|
504
|
-
# Clear existing state so only new parameters are used
|
|
505
685
|
existing_state["positional_params"] = []
|
|
506
686
|
existing_state["named_params"] = {}
|
|
507
|
-
# Pass parameters through normal processing
|
|
508
687
|
return SQL(
|
|
509
688
|
new_statement,
|
|
510
689
|
parameters,
|
|
511
690
|
_dialect=new_dialect,
|
|
512
691
|
_config=new_config,
|
|
513
692
|
_builder_result_type=self._builder_result_type,
|
|
514
|
-
_existing_state=None,
|
|
693
|
+
_existing_state=None,
|
|
515
694
|
**kwargs,
|
|
516
695
|
)
|
|
517
696
|
|
|
@@ -524,14 +703,14 @@ class SQL:
|
|
|
524
703
|
**kwargs,
|
|
525
704
|
)
|
|
526
705
|
|
|
527
|
-
def add_named_parameter(self, name: str, value: Any) -> "SQL":
|
|
706
|
+
def add_named_parameter(self, name: "str", value: Any) -> "SQL":
|
|
528
707
|
"""Add a named parameter and return a new SQL instance."""
|
|
529
708
|
new_obj = self.copy()
|
|
530
709
|
new_obj._named_params[name] = value
|
|
531
710
|
return new_obj
|
|
532
711
|
|
|
533
712
|
def get_unique_parameter_name(
|
|
534
|
-
self, base_name: str, namespace: Optional[str] = None, preserve_original: bool = False
|
|
713
|
+
self, base_name: "str", namespace: "Optional[str]" = None, preserve_original: bool = False
|
|
535
714
|
) -> str:
|
|
536
715
|
"""Generate a unique parameter name.
|
|
537
716
|
|
|
@@ -543,21 +722,16 @@ class SQL:
|
|
|
543
722
|
Returns:
|
|
544
723
|
A unique parameter name
|
|
545
724
|
"""
|
|
546
|
-
# Check both positional and named params
|
|
547
725
|
all_param_names = set(self._named_params.keys())
|
|
548
726
|
|
|
549
|
-
# Build the candidate name
|
|
550
727
|
candidate = f"{namespace}_{base_name}" if namespace else base_name
|
|
551
728
|
|
|
552
|
-
# If preserve_original and the name is unique, use it
|
|
553
729
|
if preserve_original and candidate not in all_param_names:
|
|
554
730
|
return candidate
|
|
555
731
|
|
|
556
|
-
# If not preserving or name exists, generate unique name
|
|
557
732
|
if candidate not in all_param_names:
|
|
558
733
|
return candidate
|
|
559
734
|
|
|
560
|
-
# Generate unique name with counter
|
|
561
735
|
counter = 1
|
|
562
736
|
while True:
|
|
563
737
|
new_candidate = f"{candidate}_{counter}"
|
|
@@ -567,24 +741,19 @@ class SQL:
|
|
|
567
741
|
|
|
568
742
|
def where(self, condition: "Union[str, exp.Expression, exp.Condition]") -> "SQL":
|
|
569
743
|
"""Apply WHERE clause and return new SQL instance."""
|
|
570
|
-
# Convert condition to expression
|
|
571
744
|
condition_expr = self._to_expression(condition) if isinstance(condition, str) else condition
|
|
572
745
|
|
|
573
|
-
|
|
574
|
-
if hasattr(self._statement, "where"):
|
|
746
|
+
if supports_where(self._statement):
|
|
575
747
|
new_statement = self._statement.where(condition_expr) # pyright: ignore
|
|
576
748
|
else:
|
|
577
|
-
# Wrap in SELECT if needed
|
|
578
749
|
new_statement = exp.Select().from_(self._statement).where(condition_expr) # pyright: ignore
|
|
579
750
|
|
|
580
751
|
return self.copy(statement=new_statement)
|
|
581
752
|
|
|
582
753
|
def filter(self, filter_obj: StatementFilter) -> "SQL":
|
|
583
754
|
"""Apply a filter and return a new SQL instance."""
|
|
584
|
-
# Create a new SQL object with the filter added
|
|
585
755
|
new_obj = self.copy()
|
|
586
756
|
new_obj._filters.append(filter_obj)
|
|
587
|
-
# Extract filter parameters
|
|
588
757
|
pos_params, named_params = self._extract_filter_parameters(filter_obj)
|
|
589
758
|
new_obj._positional_params.extend(pos_params)
|
|
590
759
|
new_obj._named_params.update(named_params)
|
|
@@ -595,10 +764,9 @@ class SQL:
|
|
|
595
764
|
new_obj = self.copy()
|
|
596
765
|
new_obj._is_many = True
|
|
597
766
|
if parameters is not None:
|
|
598
|
-
# Replace parameters for executemany
|
|
599
767
|
new_obj._positional_params = []
|
|
600
768
|
new_obj._named_params = {}
|
|
601
|
-
new_obj.
|
|
769
|
+
new_obj._original_parameters = parameters
|
|
602
770
|
return new_obj
|
|
603
771
|
|
|
604
772
|
def as_script(self) -> "SQL":
|
|
@@ -609,77 +777,82 @@ class SQL:
|
|
|
609
777
|
|
|
610
778
|
def _build_final_state(self) -> tuple[exp.Expression, Any]:
|
|
611
779
|
"""Build final expression and parameters after applying filters."""
|
|
612
|
-
# Start with current statement
|
|
613
780
|
final_expr = self._statement
|
|
614
781
|
|
|
615
|
-
# Apply all filters to the expression
|
|
616
782
|
for filter_obj in self._filters:
|
|
617
|
-
if
|
|
783
|
+
if can_append_to_statement(filter_obj):
|
|
618
784
|
temp_sql = SQL(final_expr, config=self._config, dialect=self._dialect)
|
|
619
785
|
temp_sql._positional_params = list(self._positional_params)
|
|
620
786
|
temp_sql._named_params = dict(self._named_params)
|
|
621
787
|
result = filter_obj.append_to_statement(temp_sql)
|
|
622
788
|
final_expr = result._statement if isinstance(result, SQL) else result
|
|
623
789
|
|
|
624
|
-
# Determine final parameters format
|
|
625
790
|
final_params: Any
|
|
626
791
|
if self._named_params and not self._positional_params:
|
|
627
|
-
# Only named params
|
|
628
792
|
final_params = dict(self._named_params)
|
|
629
793
|
elif self._positional_params and not self._named_params:
|
|
630
|
-
# Always return a list for positional params to maintain sequence type
|
|
631
794
|
final_params = list(self._positional_params)
|
|
632
795
|
elif self._positional_params and self._named_params:
|
|
633
|
-
# Mixed - merge into dict
|
|
634
796
|
final_params = dict(self._named_params)
|
|
635
|
-
# Add positional params with generated names
|
|
636
797
|
for i, param in enumerate(self._positional_params):
|
|
637
798
|
param_name = f"arg_{i}"
|
|
638
799
|
while param_name in final_params:
|
|
639
800
|
param_name = f"arg_{i}_{id(param)}"
|
|
640
801
|
final_params[param_name] = param
|
|
641
802
|
else:
|
|
642
|
-
# No parameters
|
|
643
803
|
final_params = None
|
|
644
804
|
|
|
645
805
|
return final_expr, final_params
|
|
646
806
|
|
|
647
|
-
# Properties for compatibility
|
|
648
807
|
@property
|
|
649
808
|
def sql(self) -> str:
|
|
650
809
|
"""Get SQL string."""
|
|
651
|
-
# Handle empty string case
|
|
652
810
|
if not self._raw_sql or (self._raw_sql and not self._raw_sql.strip()):
|
|
653
811
|
return ""
|
|
654
812
|
|
|
655
|
-
# For scripts, always return the raw SQL to preserve multi-statement scripts
|
|
656
813
|
if self._is_script and self._raw_sql:
|
|
657
814
|
return self._raw_sql
|
|
658
|
-
# If parsing is disabled, return the raw SQL
|
|
659
815
|
if not self._config.enable_parsing and self._raw_sql:
|
|
660
816
|
return self._raw_sql
|
|
661
817
|
|
|
662
|
-
# Ensure processed
|
|
663
818
|
self._ensure_processed()
|
|
664
|
-
|
|
819
|
+
if self._processed_state is None:
|
|
820
|
+
msg = "Failed to process SQL statement"
|
|
821
|
+
raise RuntimeError(msg)
|
|
665
822
|
return self._processed_state.processed_sql
|
|
666
823
|
|
|
667
824
|
@property
|
|
668
|
-
def expression(self) -> Optional[exp.Expression]:
|
|
825
|
+
def expression(self) -> "Optional[exp.Expression]":
|
|
669
826
|
"""Get the final expression."""
|
|
670
|
-
# Return None if parsing is disabled
|
|
671
827
|
if not self._config.enable_parsing:
|
|
672
828
|
return None
|
|
673
829
|
self._ensure_processed()
|
|
674
|
-
|
|
830
|
+
if self._processed_state is None:
|
|
831
|
+
msg = "Failed to process SQL statement"
|
|
832
|
+
raise RuntimeError(msg)
|
|
675
833
|
return self._processed_state.processed_expression
|
|
676
834
|
|
|
677
835
|
@property
|
|
678
836
|
def parameters(self) -> Any:
|
|
679
837
|
"""Get merged parameters."""
|
|
838
|
+
if self._is_many and self._original_parameters is not None:
|
|
839
|
+
return self._original_parameters
|
|
840
|
+
|
|
841
|
+
if (
|
|
842
|
+
self._original_parameters is not None
|
|
843
|
+
and isinstance(self._original_parameters, tuple)
|
|
844
|
+
and not self._named_params
|
|
845
|
+
):
|
|
846
|
+
return self._original_parameters
|
|
847
|
+
|
|
680
848
|
self._ensure_processed()
|
|
681
|
-
|
|
682
|
-
|
|
849
|
+
if self._processed_state is None:
|
|
850
|
+
msg = "Failed to process SQL statement"
|
|
851
|
+
raise RuntimeError(msg)
|
|
852
|
+
params = self._processed_state.merged_parameters
|
|
853
|
+
if params is None:
|
|
854
|
+
return {}
|
|
855
|
+
return params
|
|
683
856
|
|
|
684
857
|
@property
|
|
685
858
|
def is_many(self) -> bool:
|
|
@@ -691,56 +864,173 @@ class SQL:
|
|
|
691
864
|
"""Check if this is a script."""
|
|
692
865
|
return self._is_script
|
|
693
866
|
|
|
694
|
-
|
|
867
|
+
@property
|
|
868
|
+
def dialect(self) -> "Optional[DialectType]":
|
|
869
|
+
"""Get the SQL dialect."""
|
|
870
|
+
return self._dialect
|
|
871
|
+
|
|
872
|
+
def to_sql(self, placeholder_style: "Optional[str]" = None) -> "str":
|
|
695
873
|
"""Convert to SQL string with given placeholder style."""
|
|
696
874
|
if self._is_script:
|
|
697
875
|
return self.sql
|
|
698
876
|
sql, _ = self.compile(placeholder_style=placeholder_style)
|
|
699
877
|
return sql
|
|
700
878
|
|
|
701
|
-
def get_parameters(self, style: Optional[str] = None) -> Any:
|
|
879
|
+
def get_parameters(self, style: "Optional[str]" = None) -> Any:
|
|
702
880
|
"""Get parameters in the requested style."""
|
|
703
|
-
# Get compiled parameters with style
|
|
704
881
|
_, params = self.compile(placeholder_style=style)
|
|
705
882
|
return params
|
|
706
883
|
|
|
707
|
-
def
|
|
884
|
+
def _compile_execute_many(self, placeholder_style: "Optional[str]") -> "tuple[str, Any]":
|
|
885
|
+
"""Handle compilation for execute_many operations."""
|
|
886
|
+
sql = self.sql
|
|
887
|
+
|
|
888
|
+
self._ensure_processed()
|
|
889
|
+
|
|
890
|
+
params = self._original_parameters
|
|
891
|
+
|
|
892
|
+
extracted_params = self._get_extracted_parameters()
|
|
893
|
+
|
|
894
|
+
if extracted_params:
|
|
895
|
+
params = self._merge_extracted_params_with_sets(params, extracted_params)
|
|
896
|
+
|
|
897
|
+
if placeholder_style:
|
|
898
|
+
sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
|
|
899
|
+
|
|
900
|
+
return sql, params
|
|
901
|
+
|
|
902
|
+
def _get_extracted_parameters(self) -> "list[Any]":
|
|
903
|
+
"""Get extracted parameters from pipeline processing."""
|
|
904
|
+
extracted_params = []
|
|
905
|
+
if self._processed_state and self._processed_state.merged_parameters:
|
|
906
|
+
merged = self._processed_state.merged_parameters
|
|
907
|
+
if isinstance(merged, list):
|
|
908
|
+
if merged and not isinstance(merged[0], (tuple, list)):
|
|
909
|
+
extracted_params = merged
|
|
910
|
+
elif self._processing_context and self._processing_context.extracted_parameters_from_pipeline:
|
|
911
|
+
extracted_params = self._processing_context.extracted_parameters_from_pipeline
|
|
912
|
+
return extracted_params
|
|
913
|
+
|
|
914
|
+
def _merge_extracted_params_with_sets(self, params: Any, extracted_params: "list[Any]") -> "list[tuple[Any, ...]]":
|
|
915
|
+
"""Merge extracted parameters with each parameter set."""
|
|
916
|
+
enhanced_params = []
|
|
917
|
+
for param_set in params:
|
|
918
|
+
if isinstance(param_set, (list, tuple)):
|
|
919
|
+
extracted_values = []
|
|
920
|
+
for extracted in extracted_params:
|
|
921
|
+
if has_parameter_value(extracted):
|
|
922
|
+
extracted_values.append(extracted.value)
|
|
923
|
+
else:
|
|
924
|
+
extracted_values.append(extracted)
|
|
925
|
+
enhanced_set = list(param_set) + extracted_values
|
|
926
|
+
enhanced_params.append(tuple(enhanced_set))
|
|
927
|
+
else:
|
|
928
|
+
extracted_values = []
|
|
929
|
+
for extracted in extracted_params:
|
|
930
|
+
if has_parameter_value(extracted):
|
|
931
|
+
extracted_values.append(extracted.value)
|
|
932
|
+
else:
|
|
933
|
+
extracted_values.append(extracted)
|
|
934
|
+
enhanced_params.append((param_set, *extracted_values))
|
|
935
|
+
return enhanced_params
|
|
936
|
+
|
|
937
|
+
def compile(self, placeholder_style: "Optional[str]" = None) -> "tuple[str, Any]":
|
|
708
938
|
"""Compile to SQL and parameters."""
|
|
709
|
-
# For scripts, return raw SQL directly without processing
|
|
710
939
|
if self._is_script:
|
|
711
940
|
return self.sql, None
|
|
712
941
|
|
|
713
|
-
|
|
942
|
+
if self._is_many and self._original_parameters is not None:
|
|
943
|
+
return self._compile_execute_many(placeholder_style)
|
|
944
|
+
|
|
714
945
|
if not self._config.enable_parsing and self._raw_sql:
|
|
715
946
|
return self._raw_sql, self._raw_parameters
|
|
716
947
|
|
|
717
|
-
# Ensure processed
|
|
718
948
|
self._ensure_processed()
|
|
719
949
|
|
|
720
|
-
|
|
721
|
-
|
|
950
|
+
if self._processed_state is None:
|
|
951
|
+
msg = "Failed to process SQL statement"
|
|
952
|
+
raise RuntimeError(msg)
|
|
722
953
|
sql = self._processed_state.processed_sql
|
|
723
954
|
params = self._processed_state.merged_parameters
|
|
724
955
|
|
|
725
|
-
|
|
726
|
-
if params is not None and hasattr(self, "_processing_context") and self._processing_context:
|
|
956
|
+
if params is not None and self._processing_context:
|
|
727
957
|
parameter_mapping = self._processing_context.metadata.get("parameter_position_mapping")
|
|
728
958
|
if parameter_mapping:
|
|
729
|
-
# Apply parameter reordering based on the mapping
|
|
730
959
|
params = self._reorder_parameters(params, parameter_mapping)
|
|
731
960
|
|
|
732
|
-
#
|
|
961
|
+
# Handle denormalization if needed
|
|
962
|
+
if self._processing_context and self._processing_context.parameter_normalization:
|
|
963
|
+
norm_state = self._processing_context.parameter_normalization
|
|
964
|
+
|
|
965
|
+
# If original SQL had incompatible styles, denormalize back to the original style
|
|
966
|
+
# when no specific style requested OR when the requested style matches the original
|
|
967
|
+
if norm_state.was_normalized and norm_state.original_styles:
|
|
968
|
+
original_style = norm_state.original_styles[0]
|
|
969
|
+
should_denormalize = placeholder_style is None or (
|
|
970
|
+
placeholder_style and ParameterStyle(placeholder_style) == original_style
|
|
971
|
+
)
|
|
972
|
+
|
|
973
|
+
if should_denormalize and original_style in SQLGLOT_INCOMPATIBLE_STYLES:
|
|
974
|
+
# Denormalize SQL back to original style
|
|
975
|
+
sql = self._config.parameter_converter._convert_sql_placeholders(
|
|
976
|
+
sql, norm_state.original_param_info, original_style
|
|
977
|
+
)
|
|
978
|
+
# Also denormalize parameters if needed
|
|
979
|
+
if original_style == ParameterStyle.POSITIONAL_COLON and is_dict(params):
|
|
980
|
+
params = self._denormalize_colon_params(params)
|
|
981
|
+
|
|
982
|
+
params = self._unwrap_typed_parameters(params)
|
|
983
|
+
|
|
733
984
|
if placeholder_style is None:
|
|
734
985
|
return sql, params
|
|
735
986
|
|
|
736
|
-
# Convert to requested placeholder style
|
|
737
987
|
if placeholder_style:
|
|
738
|
-
sql, params = self.
|
|
988
|
+
sql, params = self._apply_placeholder_style(sql, params, placeholder_style)
|
|
739
989
|
|
|
740
|
-
# Debug log the final SQL
|
|
741
|
-
logger.debug("Final compiled SQL: '%s'", sql)
|
|
742
990
|
return sql, params
|
|
743
991
|
|
|
992
|
+
def _apply_placeholder_style(self, sql: "str", params: Any, placeholder_style: "str") -> "tuple[str, Any]":
|
|
993
|
+
"""Apply placeholder style conversion to SQL and parameters."""
|
|
994
|
+
# Just use the params passed in - they've already been processed
|
|
995
|
+
sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
|
|
996
|
+
return sql, params
|
|
997
|
+
|
|
998
|
+
@staticmethod
|
|
999
|
+
def _unwrap_typed_parameters(params: Any) -> Any:
|
|
1000
|
+
"""Unwrap TypedParameter objects to their actual values.
|
|
1001
|
+
|
|
1002
|
+
Args:
|
|
1003
|
+
params: Parameters that may contain TypedParameter objects
|
|
1004
|
+
|
|
1005
|
+
Returns:
|
|
1006
|
+
Parameters with TypedParameter objects unwrapped to their values
|
|
1007
|
+
"""
|
|
1008
|
+
if params is None:
|
|
1009
|
+
return None
|
|
1010
|
+
|
|
1011
|
+
if is_dict(params):
|
|
1012
|
+
unwrapped_dict = {}
|
|
1013
|
+
for key, value in params.items():
|
|
1014
|
+
if has_parameter_value(value):
|
|
1015
|
+
unwrapped_dict[key] = value.value
|
|
1016
|
+
else:
|
|
1017
|
+
unwrapped_dict[key] = value
|
|
1018
|
+
return unwrapped_dict
|
|
1019
|
+
|
|
1020
|
+
if isinstance(params, (list, tuple)):
|
|
1021
|
+
unwrapped_list = []
|
|
1022
|
+
for value in params:
|
|
1023
|
+
if has_parameter_value(value):
|
|
1024
|
+
unwrapped_list.append(value.value)
|
|
1025
|
+
else:
|
|
1026
|
+
unwrapped_list.append(value)
|
|
1027
|
+
return type(params)(unwrapped_list)
|
|
1028
|
+
|
|
1029
|
+
if has_parameter_value(params):
|
|
1030
|
+
return params.value
|
|
1031
|
+
|
|
1032
|
+
return params
|
|
1033
|
+
|
|
744
1034
|
@staticmethod
|
|
745
1035
|
def _reorder_parameters(params: Any, mapping: dict[int, int]) -> Any:
|
|
746
1036
|
"""Reorder parameters based on the position mapping.
|
|
@@ -753,43 +1043,34 @@ class SQL:
|
|
|
753
1043
|
Reordered parameters in the same format as input
|
|
754
1044
|
"""
|
|
755
1045
|
if isinstance(params, (list, tuple)):
|
|
756
|
-
# Create a new list with reordered parameters
|
|
757
1046
|
reordered_list = [None] * len(params) # pyright: ignore
|
|
758
1047
|
for new_pos, old_pos in mapping.items():
|
|
759
1048
|
if old_pos < len(params):
|
|
760
1049
|
reordered_list[new_pos] = params[old_pos] # pyright: ignore
|
|
761
1050
|
|
|
762
|
-
# Handle any unmapped positions
|
|
763
1051
|
for i, val in enumerate(reordered_list):
|
|
764
1052
|
if val is None and i < len(params) and i not in mapping:
|
|
765
|
-
# If position wasn't mapped, try to use original
|
|
766
1053
|
reordered_list[i] = params[i] # pyright: ignore
|
|
767
1054
|
|
|
768
|
-
# Return in same format as input
|
|
769
1055
|
return tuple(reordered_list) if isinstance(params, tuple) else reordered_list
|
|
770
1056
|
|
|
771
|
-
if
|
|
772
|
-
|
|
773
|
-
# If keys are like param_0, param_1, we can reorder them
|
|
774
|
-
if all(key.startswith("param_") and key[6:].isdigit() for key in params):
|
|
1057
|
+
if is_dict(params):
|
|
1058
|
+
if all(key.startswith(PARAM_PREFIX) and key[len(PARAM_PREFIX) :].isdigit() for key in params):
|
|
775
1059
|
reordered_dict: dict[str, Any] = {}
|
|
776
1060
|
for new_pos, old_pos in mapping.items():
|
|
777
|
-
old_key = f"
|
|
778
|
-
new_key = f"
|
|
1061
|
+
old_key = f"{PARAM_PREFIX}{old_pos}"
|
|
1062
|
+
new_key = f"{PARAM_PREFIX}{new_pos}"
|
|
779
1063
|
if old_key in params:
|
|
780
1064
|
reordered_dict[new_key] = params[old_key]
|
|
781
1065
|
|
|
782
|
-
# Add any unmapped parameters
|
|
783
1066
|
for key, value in params.items():
|
|
784
|
-
if key not in reordered_dict and key.startswith(
|
|
1067
|
+
if key not in reordered_dict and key.startswith(PARAM_PREFIX):
|
|
785
1068
|
idx = int(key[6:])
|
|
786
1069
|
if idx not in mapping:
|
|
787
1070
|
reordered_dict[key] = value
|
|
788
1071
|
|
|
789
1072
|
return reordered_dict
|
|
790
|
-
# Can't reorder named parameters, return as-is
|
|
791
1073
|
return params
|
|
792
|
-
# Single value or unknown format, return as-is
|
|
793
1074
|
return params
|
|
794
1075
|
|
|
795
1076
|
def _convert_placeholder_style(self, sql: str, params: Any, placeholder_style: str) -> tuple[str, Any]:
|
|
@@ -803,27 +1084,119 @@ class SQL:
|
|
|
803
1084
|
Returns:
|
|
804
1085
|
Tuple of (converted_sql, converted_params)
|
|
805
1086
|
"""
|
|
806
|
-
|
|
1087
|
+
if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
|
|
1088
|
+
converter = self._config.parameter_converter
|
|
1089
|
+
param_info = converter.validator.extract_parameters(sql)
|
|
1090
|
+
|
|
1091
|
+
if param_info:
|
|
1092
|
+
target_style = (
|
|
1093
|
+
ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
|
|
1094
|
+
)
|
|
1095
|
+
sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
|
|
1096
|
+
|
|
1097
|
+
return sql, params
|
|
1098
|
+
|
|
807
1099
|
converter = self._config.parameter_converter
|
|
808
|
-
|
|
1100
|
+
|
|
1101
|
+
# For POSITIONAL_COLON style, use original parameter info if available to preserve numeric identifiers
|
|
1102
|
+
target_style = ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
|
|
1103
|
+
if (
|
|
1104
|
+
target_style == ParameterStyle.POSITIONAL_COLON
|
|
1105
|
+
and self._processing_context
|
|
1106
|
+
and self._processing_context.parameter_normalization
|
|
1107
|
+
and self._processing_context.parameter_normalization.original_param_info
|
|
1108
|
+
):
|
|
1109
|
+
param_info = self._processing_context.parameter_normalization.original_param_info
|
|
1110
|
+
else:
|
|
1111
|
+
param_info = converter.validator.extract_parameters(sql)
|
|
1112
|
+
|
|
1113
|
+
# CRITICAL FIX: For POSITIONAL_COLON, we need to ensure param_info reflects
|
|
1114
|
+
# all placeholders in the current SQL, not just the original ones.
|
|
1115
|
+
# This handles cases where transformers (like ParameterizeLiterals) add new placeholders.
|
|
1116
|
+
if target_style == ParameterStyle.POSITIONAL_COLON and param_info:
|
|
1117
|
+
# Re-extract from current SQL to get all placeholders
|
|
1118
|
+
current_param_info = converter.validator.extract_parameters(sql)
|
|
1119
|
+
if len(current_param_info) > len(param_info):
|
|
1120
|
+
# More placeholders in current SQL means transformers added some
|
|
1121
|
+
# Use the current info to ensure all placeholders get parameters
|
|
1122
|
+
param_info = current_param_info
|
|
809
1123
|
|
|
810
1124
|
if not param_info:
|
|
811
1125
|
return sql, params
|
|
812
1126
|
|
|
813
|
-
|
|
814
|
-
|
|
1127
|
+
if target_style == ParameterStyle.STATIC:
|
|
1128
|
+
return self._embed_static_parameters(sql, params, param_info)
|
|
815
1129
|
|
|
816
|
-
|
|
1130
|
+
if param_info and all(p.style == target_style for p in param_info):
|
|
1131
|
+
converted_params = self._convert_parameters_format(params, param_info, target_style)
|
|
1132
|
+
return sql, converted_params
|
|
817
1133
|
|
|
818
|
-
# Replace placeholders in SQL
|
|
819
1134
|
sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
|
|
820
1135
|
|
|
821
|
-
# Convert parameters to appropriate format
|
|
822
1136
|
params = self._convert_parameters_format(params, param_info, target_style)
|
|
823
1137
|
|
|
824
1138
|
return sql, params
|
|
825
1139
|
|
|
826
|
-
def
|
|
1140
|
+
def _embed_static_parameters(self, sql: str, params: Any, param_info: list[Any]) -> tuple[str, Any]:
|
|
1141
|
+
"""Embed parameter values directly into SQL for STATIC style.
|
|
1142
|
+
|
|
1143
|
+
This is used for scripts and other cases where parameters need to be
|
|
1144
|
+
embedded directly in the SQL string rather than passed separately.
|
|
1145
|
+
|
|
1146
|
+
Args:
|
|
1147
|
+
sql: The SQL string with placeholders
|
|
1148
|
+
params: The parameter values
|
|
1149
|
+
param_info: List of parameter information from extraction
|
|
1150
|
+
|
|
1151
|
+
Returns:
|
|
1152
|
+
Tuple of (sql_with_embedded_values, None)
|
|
1153
|
+
"""
|
|
1154
|
+
param_list: list[Any] = []
|
|
1155
|
+
if is_dict(params):
|
|
1156
|
+
for p in param_info:
|
|
1157
|
+
if p.name and p.name in params:
|
|
1158
|
+
param_list.append(params[p.name])
|
|
1159
|
+
elif f"{PARAM_PREFIX}{p.ordinal}" in params:
|
|
1160
|
+
param_list.append(params[f"{PARAM_PREFIX}{p.ordinal}"])
|
|
1161
|
+
elif f"arg_{p.ordinal}" in params:
|
|
1162
|
+
param_list.append(params[f"arg_{p.ordinal}"])
|
|
1163
|
+
else:
|
|
1164
|
+
param_list.append(params.get(str(p.ordinal), None))
|
|
1165
|
+
elif isinstance(params, (list, tuple)):
|
|
1166
|
+
param_list = list(params)
|
|
1167
|
+
elif params is not None:
|
|
1168
|
+
param_list = [params]
|
|
1169
|
+
|
|
1170
|
+
sorted_params = sorted(param_info, key=lambda p: p.position, reverse=True)
|
|
1171
|
+
|
|
1172
|
+
for p in sorted_params:
|
|
1173
|
+
if p.ordinal < len(param_list):
|
|
1174
|
+
value = param_list[p.ordinal]
|
|
1175
|
+
|
|
1176
|
+
if has_parameter_value(value):
|
|
1177
|
+
value = value.value
|
|
1178
|
+
|
|
1179
|
+
if value is None:
|
|
1180
|
+
literal_str = "NULL"
|
|
1181
|
+
elif isinstance(value, bool):
|
|
1182
|
+
literal_str = "TRUE" if value else "FALSE"
|
|
1183
|
+
elif isinstance(value, str):
|
|
1184
|
+
literal_expr = sqlglot.exp.Literal.string(value)
|
|
1185
|
+
literal_str = literal_expr.sql(dialect=self._dialect)
|
|
1186
|
+
elif isinstance(value, (int, float)):
|
|
1187
|
+
literal_expr = sqlglot.exp.Literal.number(value)
|
|
1188
|
+
literal_str = literal_expr.sql(dialect=self._dialect)
|
|
1189
|
+
else:
|
|
1190
|
+
literal_expr = sqlglot.exp.Literal.string(str(value))
|
|
1191
|
+
literal_str = literal_expr.sql(dialect=self._dialect)
|
|
1192
|
+
|
|
1193
|
+
start = p.position
|
|
1194
|
+
end = start + len(p.placeholder_text)
|
|
1195
|
+
sql = sql[:start] + literal_str + sql[end:]
|
|
1196
|
+
|
|
1197
|
+
return sql, None
|
|
1198
|
+
|
|
1199
|
+
def _replace_placeholders_in_sql(self, sql: str, param_info: list[Any], target_style: ParameterStyle) -> str:
|
|
827
1200
|
"""Replace placeholders in SQL string with target style placeholders.
|
|
828
1201
|
|
|
829
1202
|
Args:
|
|
@@ -834,12 +1207,10 @@ class SQL:
|
|
|
834
1207
|
Returns:
|
|
835
1208
|
SQL string with replaced placeholders
|
|
836
1209
|
"""
|
|
837
|
-
# Sort by position in reverse to avoid position shifts
|
|
838
1210
|
sorted_params = sorted(param_info, key=lambda p: p.position, reverse=True)
|
|
839
1211
|
|
|
840
1212
|
for p in sorted_params:
|
|
841
1213
|
new_placeholder = self._generate_placeholder(p, target_style)
|
|
842
|
-
# Replace the placeholder in SQL
|
|
843
1214
|
start = p.position
|
|
844
1215
|
end = start + len(p.placeholder_text)
|
|
845
1216
|
sql = sql[:start] + new_placeholder + sql[end:]
|
|
@@ -847,7 +1218,7 @@ class SQL:
|
|
|
847
1218
|
return sql
|
|
848
1219
|
|
|
849
1220
|
@staticmethod
|
|
850
|
-
def _generate_placeholder(param: Any, target_style:
|
|
1221
|
+
def _generate_placeholder(param: Any, target_style: ParameterStyle) -> str:
|
|
851
1222
|
"""Generate a placeholder string for the given parameter style.
|
|
852
1223
|
|
|
853
1224
|
Args:
|
|
@@ -857,36 +1228,34 @@ class SQL:
|
|
|
857
1228
|
Returns:
|
|
858
1229
|
Placeholder string
|
|
859
1230
|
"""
|
|
860
|
-
if target_style
|
|
1231
|
+
if target_style in {ParameterStyle.STATIC, ParameterStyle.QMARK}:
|
|
861
1232
|
return "?"
|
|
862
1233
|
if target_style == ParameterStyle.NUMERIC:
|
|
863
|
-
# Use 1-based numbering for numeric style
|
|
864
1234
|
return f"${param.ordinal + 1}"
|
|
865
1235
|
if target_style == ParameterStyle.NAMED_COLON:
|
|
866
|
-
# Use original name if available, otherwise generate one
|
|
867
|
-
# Oracle doesn't like underscores at the start of parameter names
|
|
868
1236
|
if param.name and not param.name.isdigit():
|
|
869
|
-
# Use the name if it's not just a number
|
|
870
1237
|
return f":{param.name}"
|
|
871
|
-
# Generate a new name for numeric placeholders or missing names
|
|
872
1238
|
return f":arg_{param.ordinal}"
|
|
873
1239
|
if target_style == ParameterStyle.NAMED_AT:
|
|
874
|
-
# Use @ prefix for BigQuery style
|
|
875
|
-
# BigQuery requires parameter names to start with a letter, not underscore
|
|
876
1240
|
return f"@{param.name or f'param_{param.ordinal}'}"
|
|
877
1241
|
if target_style == ParameterStyle.POSITIONAL_COLON:
|
|
878
|
-
#
|
|
1242
|
+
# For Oracle positional colon, preserve the original numeric identifier if it was already :N style
|
|
1243
|
+
if (
|
|
1244
|
+
hasattr(param, "style")
|
|
1245
|
+
and param.style == ParameterStyle.POSITIONAL_COLON
|
|
1246
|
+
and hasattr(param, "name")
|
|
1247
|
+
and param.name
|
|
1248
|
+
and param.name.isdigit()
|
|
1249
|
+
):
|
|
1250
|
+
return f":{param.name}"
|
|
879
1251
|
return f":{param.ordinal + 1}"
|
|
880
1252
|
if target_style == ParameterStyle.POSITIONAL_PYFORMAT:
|
|
881
|
-
# Use %s for positional pyformat
|
|
882
1253
|
return "%s"
|
|
883
1254
|
if target_style == ParameterStyle.NAMED_PYFORMAT:
|
|
884
|
-
|
|
885
|
-
return f"%({param.name or f'_arg_{param.ordinal}'})s"
|
|
886
|
-
# Keep original for unknown styles
|
|
1255
|
+
return f"%({param.name or f'arg_{param.ordinal}'})s"
|
|
887
1256
|
return str(param.placeholder_text)
|
|
888
1257
|
|
|
889
|
-
def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style:
|
|
1258
|
+
def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style: ParameterStyle) -> Any:
|
|
890
1259
|
"""Convert parameters to the appropriate format for the target style.
|
|
891
1260
|
|
|
892
1261
|
Args:
|
|
@@ -907,10 +1276,96 @@ class SQL:
|
|
|
907
1276
|
return self._convert_to_named_pyformat_format(params, param_info)
|
|
908
1277
|
return params
|
|
909
1278
|
|
|
1279
|
+
def _convert_list_to_colon_dict(
|
|
1280
|
+
self, params: "Union[list[Any], tuple[Any, ...]]", param_info: "list[Any]"
|
|
1281
|
+
) -> "dict[str, Any]":
|
|
1282
|
+
"""Convert list/tuple parameters to colon-style dict format."""
|
|
1283
|
+
result_dict: dict[str, Any] = {}
|
|
1284
|
+
|
|
1285
|
+
if param_info:
|
|
1286
|
+
all_numeric = all(p.name and p.name.isdigit() for p in param_info)
|
|
1287
|
+
if all_numeric:
|
|
1288
|
+
for i, value in enumerate(params):
|
|
1289
|
+
result_dict[str(i + 1)] = value
|
|
1290
|
+
else:
|
|
1291
|
+
for i, value in enumerate(params):
|
|
1292
|
+
if i < len(param_info):
|
|
1293
|
+
param_name = param_info[i].name or str(i + 1)
|
|
1294
|
+
result_dict[param_name] = value
|
|
1295
|
+
else:
|
|
1296
|
+
result_dict[str(i + 1)] = value
|
|
1297
|
+
else:
|
|
1298
|
+
for i, value in enumerate(params):
|
|
1299
|
+
result_dict[str(i + 1)] = value
|
|
1300
|
+
|
|
1301
|
+
return result_dict
|
|
1302
|
+
|
|
1303
|
+
def _convert_single_value_to_colon_dict(self, params: Any, param_info: "list[Any]") -> "dict[str, Any]":
|
|
1304
|
+
"""Convert single value parameter to colon-style dict format."""
|
|
1305
|
+
result_dict: dict[str, Any] = {}
|
|
1306
|
+
if param_info and param_info[0].name and param_info[0].name.isdigit():
|
|
1307
|
+
result_dict[param_info[0].name] = params
|
|
1308
|
+
else:
|
|
1309
|
+
result_dict["1"] = params
|
|
1310
|
+
return result_dict
|
|
1311
|
+
|
|
1312
|
+
def _process_mixed_colon_params(self, params: "dict[str, Any]", param_info: "list[Any]") -> "dict[str, Any]":
|
|
1313
|
+
"""Process mixed colon-style numeric and normalized parameters."""
|
|
1314
|
+
result_dict: dict[str, Any] = {}
|
|
1315
|
+
|
|
1316
|
+
# When we have mixed parameters (extracted literals + user oracle params),
|
|
1317
|
+
# we need to be careful about the ordering. The extracted literals should
|
|
1318
|
+
# fill positions based on where they appear in the SQL, not based on
|
|
1319
|
+
# matching parameter names.
|
|
1320
|
+
|
|
1321
|
+
# Separate extracted parameters and user oracle parameters
|
|
1322
|
+
extracted_params = []
|
|
1323
|
+
user_oracle_params = {}
|
|
1324
|
+
extracted_keys_sorted = []
|
|
1325
|
+
|
|
1326
|
+
for key, value in params.items():
|
|
1327
|
+
if has_parameter_value(value):
|
|
1328
|
+
extracted_params.append((key, value))
|
|
1329
|
+
elif key.isdigit():
|
|
1330
|
+
user_oracle_params[key] = value
|
|
1331
|
+
elif key.startswith("param_") and key[6:].isdigit():
|
|
1332
|
+
param_idx = int(key[6:])
|
|
1333
|
+
oracle_key = str(param_idx + 1)
|
|
1334
|
+
if oracle_key not in user_oracle_params:
|
|
1335
|
+
extracted_keys_sorted.append((param_idx, key, value))
|
|
1336
|
+
else:
|
|
1337
|
+
extracted_params.append((key, value))
|
|
1338
|
+
|
|
1339
|
+
extracted_keys_sorted.sort(key=operator.itemgetter(0))
|
|
1340
|
+
for _, key, value in extracted_keys_sorted:
|
|
1341
|
+
extracted_params.append((key, value))
|
|
1342
|
+
|
|
1343
|
+
# Build lists of parameter values in order
|
|
1344
|
+
extracted_values = []
|
|
1345
|
+
for _, value in extracted_params:
|
|
1346
|
+
if has_parameter_value(value):
|
|
1347
|
+
extracted_values.append(value.value)
|
|
1348
|
+
else:
|
|
1349
|
+
extracted_values.append(value)
|
|
1350
|
+
|
|
1351
|
+
user_values = [user_oracle_params[key] for key in sorted(user_oracle_params.keys(), key=int)]
|
|
1352
|
+
|
|
1353
|
+
# Now assign parameters based on position
|
|
1354
|
+
# Extracted parameters go first (they were literals in original positions)
|
|
1355
|
+
# User parameters follow
|
|
1356
|
+
all_values = extracted_values + user_values
|
|
1357
|
+
|
|
1358
|
+
for i, p in enumerate(sorted(param_info, key=lambda x: x.ordinal)):
|
|
1359
|
+
oracle_key = str(p.ordinal + 1)
|
|
1360
|
+
if i < len(all_values):
|
|
1361
|
+
result_dict[oracle_key] = all_values[i]
|
|
1362
|
+
|
|
1363
|
+
return result_dict
|
|
1364
|
+
|
|
910
1365
|
def _convert_to_positional_colon_format(self, params: Any, param_info: list[Any]) -> Any:
|
|
911
|
-
"""Convert to dict format for
|
|
1366
|
+
"""Convert to dict format for positional colon style.
|
|
912
1367
|
|
|
913
|
-
|
|
1368
|
+
Positional colon style uses :1, :2, etc. placeholders and expects
|
|
914
1369
|
parameters as a dict with string keys "1", "2", etc.
|
|
915
1370
|
|
|
916
1371
|
For execute_many operations, returns a list of parameter sets.
|
|
@@ -922,68 +1377,76 @@ class SQL:
|
|
|
922
1377
|
Returns:
|
|
923
1378
|
Dict of parameters with string keys "1", "2", etc., or list for execute_many
|
|
924
1379
|
"""
|
|
925
|
-
# Special handling for execute_many
|
|
926
1380
|
if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
|
|
927
|
-
# This is execute_many - keep as list but process each item
|
|
928
1381
|
return params
|
|
929
1382
|
|
|
930
|
-
result_dict: dict[str, Any] = {}
|
|
931
|
-
|
|
932
1383
|
if isinstance(params, (list, tuple)):
|
|
933
|
-
|
|
934
|
-
if param_info:
|
|
935
|
-
# Check if all param names are numeric (positional colon style)
|
|
936
|
-
all_numeric = all(p.name and p.name.isdigit() for p in param_info)
|
|
937
|
-
if all_numeric:
|
|
938
|
-
# Sort param_info by numeric name to match list order
|
|
939
|
-
sorted_params = sorted(param_info, key=lambda p: int(p.name))
|
|
940
|
-
for i, value in enumerate(params):
|
|
941
|
-
if i < len(sorted_params):
|
|
942
|
-
# Map based on numeric order, not SQL appearance order
|
|
943
|
-
param_name = sorted_params[i].name
|
|
944
|
-
result_dict[param_name] = value
|
|
945
|
-
else:
|
|
946
|
-
# Extra parameters
|
|
947
|
-
result_dict[str(i + 1)] = value
|
|
948
|
-
else:
|
|
949
|
-
# Non-numeric names, map by ordinal
|
|
950
|
-
for i, value in enumerate(params):
|
|
951
|
-
if i < len(param_info):
|
|
952
|
-
param_name = param_info[i].name or str(i + 1)
|
|
953
|
-
result_dict[param_name] = value
|
|
954
|
-
else:
|
|
955
|
-
result_dict[str(i + 1)] = value
|
|
956
|
-
else:
|
|
957
|
-
# No param_info, default to 1-based indexing
|
|
958
|
-
for i, value in enumerate(params):
|
|
959
|
-
result_dict[str(i + 1)] = value
|
|
960
|
-
return result_dict
|
|
1384
|
+
return self._convert_list_to_colon_dict(params, param_info)
|
|
961
1385
|
|
|
962
1386
|
if not is_dict(params) and param_info:
|
|
963
|
-
|
|
964
|
-
if param_info and param_info[0].name and param_info[0].name.isdigit():
|
|
965
|
-
# Use the actual parameter name from SQL (e.g., "0")
|
|
966
|
-
result_dict[param_info[0].name] = params
|
|
967
|
-
else:
|
|
968
|
-
# Default to "1"
|
|
969
|
-
result_dict["1"] = params
|
|
970
|
-
return result_dict
|
|
1387
|
+
return self._convert_single_value_to_colon_dict(params, param_info)
|
|
971
1388
|
|
|
972
|
-
if
|
|
973
|
-
# Check if already in correct format (keys are "1", "2", etc.)
|
|
1389
|
+
if is_dict(params):
|
|
974
1390
|
if all(key.isdigit() for key in params):
|
|
975
1391
|
return params
|
|
976
1392
|
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
1393
|
+
if all(key.startswith("param_") for key in params):
|
|
1394
|
+
param_result_dict: dict[str, Any] = {}
|
|
1395
|
+
for p in sorted(param_info, key=lambda x: x.ordinal):
|
|
1396
|
+
# Use the parameter's ordinal to find the normalized key
|
|
1397
|
+
normalized_key = f"param_{p.ordinal}"
|
|
1398
|
+
if normalized_key in params:
|
|
1399
|
+
if p.name and p.name.isdigit():
|
|
1400
|
+
# For Oracle numeric parameters, preserve the original number
|
|
1401
|
+
param_result_dict[p.name] = params[normalized_key]
|
|
1402
|
+
else:
|
|
1403
|
+
# For other cases, use sequential numbering
|
|
1404
|
+
param_result_dict[str(p.ordinal + 1)] = params[normalized_key]
|
|
1405
|
+
return param_result_dict
|
|
1406
|
+
|
|
1407
|
+
has_oracle_numeric = any(key.isdigit() for key in params)
|
|
1408
|
+
has_param_normalized = any(key.startswith("param_") for key in params)
|
|
1409
|
+
has_typed_params = any(has_parameter_value(v) for v in params.values())
|
|
1410
|
+
|
|
1411
|
+
if (has_oracle_numeric and has_param_normalized) or has_typed_params:
|
|
1412
|
+
return self._process_mixed_colon_params(params, param_info)
|
|
1413
|
+
|
|
1414
|
+
result_dict: dict[str, Any] = {}
|
|
1415
|
+
|
|
1416
|
+
if param_info:
|
|
1417
|
+
# Process all parameters in order of their ordinals
|
|
1418
|
+
for p in sorted(param_info, key=lambda x: x.ordinal):
|
|
1419
|
+
oracle_key = str(p.ordinal + 1)
|
|
1420
|
+
value = None
|
|
1421
|
+
|
|
1422
|
+
# Try different ways to find the parameter value
|
|
1423
|
+
if p.name and (
|
|
1424
|
+
p.name in params
|
|
1425
|
+
or (p.name.isdigit() and p.name in params)
|
|
1426
|
+
or (p.name.startswith("param_") and p.name in params)
|
|
1427
|
+
):
|
|
1428
|
+
value = params[p.name]
|
|
1429
|
+
|
|
1430
|
+
# If not found by name, try by ordinal-based keys
|
|
1431
|
+
if value is None:
|
|
1432
|
+
# Try param_N format (common for pipeline parameters)
|
|
1433
|
+
param_key = f"param_{p.ordinal}"
|
|
1434
|
+
if param_key in params:
|
|
1435
|
+
value = params[param_key]
|
|
1436
|
+
# Try arg_N format
|
|
1437
|
+
elif f"arg_{p.ordinal}" in params:
|
|
1438
|
+
value = params[f"arg_{p.ordinal}"]
|
|
1439
|
+
# For positional colon, also check if there's a numeric key
|
|
1440
|
+
# that matches the ordinal position
|
|
1441
|
+
elif str(p.ordinal + 1) in params:
|
|
1442
|
+
value = params[str(p.ordinal + 1)]
|
|
1443
|
+
|
|
1444
|
+
# Unwrap TypedParameter if needed
|
|
1445
|
+
if value is not None:
|
|
1446
|
+
if has_parameter_value(value):
|
|
1447
|
+
value = value.value
|
|
1448
|
+
result_dict[oracle_key] = value
|
|
1449
|
+
|
|
987
1450
|
return result_dict
|
|
988
1451
|
|
|
989
1452
|
return params
|
|
@@ -1001,33 +1464,79 @@ class SQL:
|
|
|
1001
1464
|
"""
|
|
1002
1465
|
result_list: list[Any] = []
|
|
1003
1466
|
if is_dict(params):
|
|
1467
|
+
param_values_by_ordinal: dict[int, Any] = {}
|
|
1468
|
+
|
|
1004
1469
|
for p in param_info:
|
|
1005
1470
|
if p.name and p.name in params:
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1471
|
+
param_values_by_ordinal[p.ordinal] = params[p.name]
|
|
1472
|
+
|
|
1473
|
+
for p in param_info:
|
|
1474
|
+
if p.name is None and p.ordinal not in param_values_by_ordinal:
|
|
1475
|
+
arg_key = f"arg_{p.ordinal}"
|
|
1476
|
+
param_key = f"param_{p.ordinal}"
|
|
1477
|
+
if arg_key in params:
|
|
1478
|
+
param_values_by_ordinal[p.ordinal] = params[arg_key]
|
|
1479
|
+
elif param_key in params:
|
|
1480
|
+
param_values_by_ordinal[p.ordinal] = params[param_key]
|
|
1481
|
+
|
|
1482
|
+
remaining_params = {
|
|
1483
|
+
k: v
|
|
1484
|
+
for k, v in params.items()
|
|
1485
|
+
if k not in {p.name for p in param_info if p.name} and not k.startswith(("arg_", "param_"))
|
|
1486
|
+
}
|
|
1487
|
+
|
|
1488
|
+
unmatched_ordinals = [p.ordinal for p in param_info if p.ordinal not in param_values_by_ordinal]
|
|
1489
|
+
|
|
1490
|
+
for ordinal, (_, value) in zip(unmatched_ordinals, remaining_params.items()):
|
|
1491
|
+
param_values_by_ordinal[ordinal] = value
|
|
1492
|
+
|
|
1493
|
+
for p in param_info:
|
|
1494
|
+
val = param_values_by_ordinal.get(p.ordinal)
|
|
1495
|
+
if val is not None:
|
|
1496
|
+
if has_parameter_value(val):
|
|
1009
1497
|
result_list.append(val.value)
|
|
1010
1498
|
else:
|
|
1011
1499
|
result_list.append(val)
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1500
|
+
else:
|
|
1501
|
+
result_list.append(None)
|
|
1502
|
+
|
|
1503
|
+
return result_list
|
|
1504
|
+
if isinstance(params, (list, tuple)):
|
|
1505
|
+
# Special case: if params is empty, preserve it (don't create None values)
|
|
1506
|
+
# This is important for execute_many with empty parameter lists
|
|
1507
|
+
if not params:
|
|
1508
|
+
return params
|
|
1509
|
+
|
|
1510
|
+
# Handle mixed parameter styles correctly
|
|
1511
|
+
# For mixed styles, assign parameters in order of appearance, not by numeric reference
|
|
1512
|
+
if param_info and any(p.style == ParameterStyle.NUMERIC for p in param_info):
|
|
1513
|
+
# Create mapping from ordinal to parameter value
|
|
1514
|
+
param_mapping: dict[int, Any] = {}
|
|
1515
|
+
|
|
1516
|
+
# Sort parameter info by position to get order of appearance
|
|
1517
|
+
sorted_params = sorted(param_info, key=lambda p: p.position)
|
|
1518
|
+
|
|
1519
|
+
# Assign parameters sequentially in order of appearance
|
|
1520
|
+
for i, param_info_item in enumerate(sorted_params):
|
|
1521
|
+
if i < len(params):
|
|
1522
|
+
param_mapping[param_info_item.ordinal] = params[i]
|
|
1523
|
+
|
|
1524
|
+
# Build result list ordered by original ordinal values
|
|
1525
|
+
for i in range(len(param_info)):
|
|
1526
|
+
val = param_mapping.get(i)
|
|
1527
|
+
if val is not None:
|
|
1528
|
+
if has_parameter_value(val):
|
|
1019
1529
|
result_list.append(val.value)
|
|
1020
1530
|
else:
|
|
1021
1531
|
result_list.append(val)
|
|
1022
1532
|
else:
|
|
1023
1533
|
result_list.append(None)
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
if isinstance(params, (list, tuple)):
|
|
1534
|
+
|
|
1535
|
+
return result_list
|
|
1536
|
+
|
|
1537
|
+
# Standard conversion for non-mixed styles
|
|
1029
1538
|
for param in params:
|
|
1030
|
-
if
|
|
1539
|
+
if has_parameter_value(param):
|
|
1031
1540
|
result_list.append(param.value)
|
|
1032
1541
|
else:
|
|
1033
1542
|
result_list.append(param)
|
|
@@ -1047,28 +1556,26 @@ class SQL:
|
|
|
1047
1556
|
"""
|
|
1048
1557
|
result_dict: dict[str, Any] = {}
|
|
1049
1558
|
if is_dict(params):
|
|
1050
|
-
# For dict params with matching parameter names, return as-is
|
|
1051
|
-
# Otherwise, remap to match the expected names
|
|
1052
1559
|
if all(p.name in params for p in param_info if p.name):
|
|
1053
1560
|
return params
|
|
1054
1561
|
for p in param_info:
|
|
1055
1562
|
if p.name and p.name in params:
|
|
1056
1563
|
result_dict[p.name] = params[p.name]
|
|
1057
1564
|
elif f"param_{p.ordinal}" in params:
|
|
1058
|
-
# Handle param_N style names
|
|
1059
|
-
# Oracle doesn't like underscores at the start of parameter names
|
|
1060
1565
|
result_dict[p.name or f"arg_{p.ordinal}"] = params[f"param_{p.ordinal}"]
|
|
1061
1566
|
return result_dict
|
|
1062
1567
|
if isinstance(params, (list, tuple)):
|
|
1063
|
-
# Convert list/tuple to dict with parameter names from param_info
|
|
1064
|
-
|
|
1065
1568
|
for i, value in enumerate(params):
|
|
1569
|
+
if has_parameter_value(value):
|
|
1570
|
+
value = value.value
|
|
1571
|
+
|
|
1066
1572
|
if i < len(param_info):
|
|
1067
1573
|
p = param_info[i]
|
|
1068
|
-
# Use the actual parameter name if available
|
|
1069
|
-
# Oracle doesn't like underscores at the start of parameter names
|
|
1070
1574
|
param_name = p.name or f"arg_{i}"
|
|
1071
1575
|
result_dict[param_name] = value
|
|
1576
|
+
else:
|
|
1577
|
+
param_name = f"arg_{i}"
|
|
1578
|
+
result_dict[param_name] = value
|
|
1072
1579
|
return result_dict
|
|
1073
1580
|
return params
|
|
1074
1581
|
|
|
@@ -1084,7 +1591,6 @@ class SQL:
|
|
|
1084
1591
|
Dict of parameters with names
|
|
1085
1592
|
"""
|
|
1086
1593
|
if isinstance(params, (list, tuple)):
|
|
1087
|
-
# Convert list to dict with generated names
|
|
1088
1594
|
result_dict: dict[str, Any] = {}
|
|
1089
1595
|
for i, p in enumerate(param_info):
|
|
1090
1596
|
if i < len(params):
|
|
@@ -1093,14 +1599,15 @@ class SQL:
|
|
|
1093
1599
|
return result_dict
|
|
1094
1600
|
return params
|
|
1095
1601
|
|
|
1096
|
-
# Validation properties for compatibility
|
|
1097
1602
|
@property
|
|
1098
1603
|
def validation_errors(self) -> list[Any]:
|
|
1099
1604
|
"""Get validation errors."""
|
|
1100
1605
|
if not self._config.enable_validation:
|
|
1101
1606
|
return []
|
|
1102
1607
|
self._ensure_processed()
|
|
1103
|
-
|
|
1608
|
+
if not self._processed_state:
|
|
1609
|
+
msg = "Failed to process SQL statement"
|
|
1610
|
+
raise RuntimeError(msg)
|
|
1104
1611
|
return self._processed_state.validation_errors
|
|
1105
1612
|
|
|
1106
1613
|
@property
|
|
@@ -1113,25 +1620,30 @@ class SQL:
|
|
|
1113
1620
|
"""Check if statement is safe."""
|
|
1114
1621
|
return not self.has_errors
|
|
1115
1622
|
|
|
1116
|
-
# Additional compatibility methods
|
|
1117
1623
|
def validate(self) -> list[Any]:
|
|
1118
1624
|
"""Validate the SQL statement and return validation errors."""
|
|
1119
1625
|
return self.validation_errors
|
|
1120
1626
|
|
|
1121
1627
|
@property
|
|
1122
1628
|
def parameter_info(self) -> list[Any]:
|
|
1123
|
-
"""Get parameter information from the SQL statement.
|
|
1629
|
+
"""Get parameter information from the SQL statement.
|
|
1630
|
+
|
|
1631
|
+
Returns the original parameter info before any normalization.
|
|
1632
|
+
"""
|
|
1124
1633
|
validator = self._config.parameter_validator
|
|
1125
|
-
if self.
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1634
|
+
if self._raw_sql:
|
|
1635
|
+
return validator.extract_parameters(self._raw_sql)
|
|
1636
|
+
|
|
1637
|
+
self._ensure_processed()
|
|
1638
|
+
|
|
1639
|
+
if self._processing_context:
|
|
1640
|
+
return self._processing_context.parameter_info
|
|
1641
|
+
|
|
1642
|
+
return []
|
|
1130
1643
|
|
|
1131
1644
|
@property
|
|
1132
1645
|
def _raw_parameters(self) -> Any:
|
|
1133
1646
|
"""Get raw parameters for compatibility."""
|
|
1134
|
-
# Return the original parameters as passed in
|
|
1135
1647
|
return self._original_parameters
|
|
1136
1648
|
|
|
1137
1649
|
@property
|
|
@@ -1140,7 +1652,7 @@ class SQL:
|
|
|
1140
1652
|
return self.sql
|
|
1141
1653
|
|
|
1142
1654
|
@property
|
|
1143
|
-
def _expression(self) -> Optional[exp.Expression]:
|
|
1655
|
+
def _expression(self) -> "Optional[exp.Expression]":
|
|
1144
1656
|
"""Get expression for compatibility."""
|
|
1145
1657
|
return self.expression
|
|
1146
1658
|
|
|
@@ -1152,18 +1664,15 @@ class SQL:
|
|
|
1152
1664
|
def limit(self, count: int, use_parameter: bool = False) -> "SQL":
|
|
1153
1665
|
"""Add LIMIT clause."""
|
|
1154
1666
|
if use_parameter:
|
|
1155
|
-
# Create a unique parameter name
|
|
1156
1667
|
param_name = self.get_unique_parameter_name("limit")
|
|
1157
|
-
# Add parameter to the SQL object
|
|
1158
1668
|
result = self
|
|
1159
1669
|
result = result.add_named_parameter(param_name, count)
|
|
1160
|
-
|
|
1161
|
-
if hasattr(result._statement, "limit"):
|
|
1670
|
+
if supports_limit(result._statement):
|
|
1162
1671
|
new_statement = result._statement.limit(exp.Placeholder(this=param_name)) # pyright: ignore
|
|
1163
1672
|
else:
|
|
1164
1673
|
new_statement = exp.Select().from_(result._statement).limit(exp.Placeholder(this=param_name)) # pyright: ignore
|
|
1165
1674
|
return result.copy(statement=new_statement)
|
|
1166
|
-
if
|
|
1675
|
+
if supports_limit(self._statement):
|
|
1167
1676
|
new_statement = self._statement.limit(count) # pyright: ignore
|
|
1168
1677
|
else:
|
|
1169
1678
|
new_statement = exp.Select().from_(self._statement).limit(count) # pyright: ignore
|
|
@@ -1172,18 +1681,15 @@ class SQL:
|
|
|
1172
1681
|
def offset(self, count: int, use_parameter: bool = False) -> "SQL":
|
|
1173
1682
|
"""Add OFFSET clause."""
|
|
1174
1683
|
if use_parameter:
|
|
1175
|
-
# Create a unique parameter name
|
|
1176
1684
|
param_name = self.get_unique_parameter_name("offset")
|
|
1177
|
-
# Add parameter to the SQL object
|
|
1178
1685
|
result = self
|
|
1179
1686
|
result = result.add_named_parameter(param_name, count)
|
|
1180
|
-
|
|
1181
|
-
if hasattr(result._statement, "offset"):
|
|
1687
|
+
if supports_offset(result._statement):
|
|
1182
1688
|
new_statement = result._statement.offset(exp.Placeholder(this=param_name)) # pyright: ignore
|
|
1183
1689
|
else:
|
|
1184
1690
|
new_statement = exp.Select().from_(result._statement).offset(exp.Placeholder(this=param_name)) # pyright: ignore
|
|
1185
1691
|
return result.copy(statement=new_statement)
|
|
1186
|
-
if
|
|
1692
|
+
if supports_offset(self._statement):
|
|
1187
1693
|
new_statement = self._statement.offset(count) # pyright: ignore
|
|
1188
1694
|
else:
|
|
1189
1695
|
new_statement = exp.Select().from_(self._statement).offset(count) # pyright: ignore
|
|
@@ -1191,7 +1697,7 @@ class SQL:
|
|
|
1191
1697
|
|
|
1192
1698
|
def order_by(self, expression: exp.Expression) -> "SQL":
|
|
1193
1699
|
"""Add ORDER BY clause."""
|
|
1194
|
-
if
|
|
1700
|
+
if supports_order_by(self._statement):
|
|
1195
1701
|
new_statement = self._statement.order_by(expression) # pyright: ignore
|
|
1196
1702
|
else:
|
|
1197
1703
|
new_statement = exp.Select().from_(self._statement).order_by(expression) # pyright: ignore
|