sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/sql.py
CHANGED
|
@@ -1,28 +1,63 @@
|
|
|
1
1
|
"""SQL statement handling with centralized parameter management."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from
|
|
3
|
+
import operator
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
5
6
|
|
|
6
7
|
import sqlglot
|
|
7
8
|
import sqlglot.expressions as exp
|
|
8
|
-
from sqlglot.dialects.dialect import DialectType
|
|
9
9
|
from sqlglot.errors import ParseError
|
|
10
|
+
from typing_extensions import TypeAlias
|
|
10
11
|
|
|
11
|
-
from sqlspec.exceptions import RiskLevel, SQLValidationError
|
|
12
|
+
from sqlspec.exceptions import RiskLevel, SQLParsingError, SQLValidationError
|
|
12
13
|
from sqlspec.statement.filters import StatementFilter
|
|
13
|
-
from sqlspec.statement.parameters import
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
14
|
+
from sqlspec.statement.parameters import (
|
|
15
|
+
SQLGLOT_INCOMPATIBLE_STYLES,
|
|
16
|
+
ParameterConverter,
|
|
17
|
+
ParameterStyle,
|
|
18
|
+
ParameterValidator,
|
|
19
|
+
)
|
|
20
|
+
from sqlspec.statement.pipelines import SQLProcessingContext, StatementPipeline
|
|
21
|
+
from sqlspec.statement.pipelines.transformers import CommentAndHintRemover, ParameterizeLiterals
|
|
17
22
|
from sqlspec.statement.pipelines.validators import DMLSafetyValidator, ParameterStyleValidator
|
|
18
|
-
from sqlspec.typing import is_dict
|
|
19
23
|
from sqlspec.utils.logging import get_logger
|
|
24
|
+
from sqlspec.utils.type_guards import (
|
|
25
|
+
can_append_to_statement,
|
|
26
|
+
can_extract_parameters,
|
|
27
|
+
has_parameter_value,
|
|
28
|
+
has_risk_level,
|
|
29
|
+
is_dict,
|
|
30
|
+
is_expression,
|
|
31
|
+
is_statement_filter,
|
|
32
|
+
supports_limit,
|
|
33
|
+
supports_offset,
|
|
34
|
+
supports_order_by,
|
|
35
|
+
supports_where,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from sqlglot.dialects.dialect import DialectType
|
|
40
|
+
|
|
41
|
+
from sqlspec.statement.parameters import ParameterNormalizationState
|
|
20
42
|
|
|
21
43
|
__all__ = ("SQL", "SQLConfig", "Statement")
|
|
22
44
|
|
|
23
45
|
logger = get_logger("sqlspec.statement")
|
|
24
46
|
|
|
25
|
-
Statement = Union[str, exp.Expression, "SQL"]
|
|
47
|
+
Statement: TypeAlias = Union[str, exp.Expression, "SQL"]
|
|
48
|
+
|
|
49
|
+
# Parameter naming constants
|
|
50
|
+
PARAM_PREFIX = "param_"
|
|
51
|
+
POS_PARAM_PREFIX = "pos_param_"
|
|
52
|
+
KW_POS_PARAM_PREFIX = "kw_pos_param_"
|
|
53
|
+
ARG_PREFIX = "arg_"
|
|
54
|
+
|
|
55
|
+
# Cache and limit constants
|
|
56
|
+
DEFAULT_CACHE_SIZE = 1000
|
|
57
|
+
|
|
58
|
+
# Oracle/Colon style parameter constants
|
|
59
|
+
COLON_PARAM_ONE = "1"
|
|
60
|
+
COLON_PARAM_MIN_INDEX = 1
|
|
26
61
|
|
|
27
62
|
|
|
28
63
|
@dataclass
|
|
@@ -39,9 +74,30 @@ class _ProcessedState:
|
|
|
39
74
|
|
|
40
75
|
@dataclass
|
|
41
76
|
class SQLConfig:
|
|
42
|
-
"""Configuration for SQL statement behavior.
|
|
77
|
+
"""Configuration for SQL statement behavior.
|
|
78
|
+
|
|
79
|
+
Uses conservative defaults that prioritize compatibility and robustness
|
|
80
|
+
over strict enforcement, making it easier to work with diverse SQL dialects
|
|
81
|
+
and complex queries.
|
|
82
|
+
|
|
83
|
+
Component Lists:
|
|
84
|
+
transformers: Optional list of SQL transformers for explicit staging
|
|
85
|
+
validators: Optional list of SQL validators for explicit staging
|
|
86
|
+
analyzers: Optional list of SQL analyzers for explicit staging
|
|
87
|
+
|
|
88
|
+
Configuration Options:
|
|
89
|
+
parameter_converter: Handles parameter style conversions
|
|
90
|
+
parameter_validator: Validates parameter usage and styles
|
|
91
|
+
analysis_cache_size: Cache size for analysis results
|
|
92
|
+
input_sql_had_placeholders: Populated by SQL.__init__ to track original SQL state
|
|
93
|
+
dialect: SQL dialect to use for parsing and generation
|
|
94
|
+
|
|
95
|
+
Parameter Style Configuration:
|
|
96
|
+
allowed_parameter_styles: Allowed parameter styles (e.g., ('qmark', 'named_colon'))
|
|
97
|
+
target_parameter_style: Target parameter style for SQL generation
|
|
98
|
+
allow_mixed_parameter_styles: Whether to allow mixing parameter styles in same query
|
|
99
|
+
"""
|
|
43
100
|
|
|
44
|
-
# Behavior flags
|
|
45
101
|
enable_parsing: bool = True
|
|
46
102
|
enable_validation: bool = True
|
|
47
103
|
enable_transformations: bool = True
|
|
@@ -49,29 +105,23 @@ class SQLConfig:
|
|
|
49
105
|
enable_normalization: bool = True
|
|
50
106
|
strict_mode: bool = False
|
|
51
107
|
cache_parsed_expression: bool = True
|
|
108
|
+
parse_errors_as_warnings: bool = True
|
|
52
109
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
analyzers: Optional[list[Any]] = None
|
|
110
|
+
transformers: "Optional[list[Any]]" = None
|
|
111
|
+
validators: "Optional[list[Any]]" = None
|
|
112
|
+
analyzers: "Optional[list[Any]]" = None
|
|
57
113
|
|
|
58
|
-
# Other configs
|
|
59
114
|
parameter_converter: ParameterConverter = field(default_factory=ParameterConverter)
|
|
60
115
|
parameter_validator: ParameterValidator = field(default_factory=ParameterValidator)
|
|
61
116
|
analysis_cache_size: int = 1000
|
|
62
|
-
input_sql_had_placeholders: bool = False
|
|
63
|
-
|
|
64
|
-
# Parameter style configuration
|
|
65
|
-
allowed_parameter_styles: Optional[tuple[str, ...]] = None
|
|
66
|
-
"""Allowed parameter styles for this SQL configuration (e.g., ('qmark', 'named_colon'))."""
|
|
67
|
-
|
|
68
|
-
target_parameter_style: Optional[str] = None
|
|
69
|
-
"""Target parameter style for SQL generation."""
|
|
117
|
+
input_sql_had_placeholders: bool = False
|
|
118
|
+
dialect: "Optional[DialectType]" = None
|
|
70
119
|
|
|
120
|
+
allowed_parameter_styles: "Optional[tuple[str, ...]]" = None
|
|
121
|
+
target_parameter_style: "Optional[str]" = None
|
|
71
122
|
allow_mixed_parameter_styles: bool = False
|
|
72
|
-
"""Whether to allow mixing named and positional parameters in same query."""
|
|
73
123
|
|
|
74
|
-
def validate_parameter_style(self, style: Union[ParameterStyle, str]) -> bool:
|
|
124
|
+
def validate_parameter_style(self, style: "Union[ParameterStyle, str]") -> bool:
|
|
75
125
|
"""Check if a parameter style is allowed.
|
|
76
126
|
|
|
77
127
|
Args:
|
|
@@ -81,7 +131,7 @@ class SQLConfig:
|
|
|
81
131
|
True if the style is allowed, False otherwise
|
|
82
132
|
"""
|
|
83
133
|
if self.allowed_parameter_styles is None:
|
|
84
|
-
return True
|
|
134
|
+
return True
|
|
85
135
|
style_str = str(style)
|
|
86
136
|
return style_str in self.allowed_parameter_styles
|
|
87
137
|
|
|
@@ -91,36 +141,23 @@ class SQLConfig:
|
|
|
91
141
|
Returns:
|
|
92
142
|
StatementPipeline configured with transformers, validators, and analyzers
|
|
93
143
|
"""
|
|
94
|
-
# Import here to avoid circular dependencies
|
|
95
|
-
|
|
96
|
-
# Create transformers based on config
|
|
97
144
|
transformers = []
|
|
98
145
|
if self.transformers is not None:
|
|
99
|
-
# Use explicit transformers if provided
|
|
100
146
|
transformers = list(self.transformers)
|
|
101
|
-
# Use default transformers
|
|
102
147
|
elif self.enable_transformations:
|
|
103
|
-
# Use target_parameter_style if available, otherwise default to "?"
|
|
104
148
|
placeholder_style = self.target_parameter_style or "?"
|
|
105
|
-
transformers = [
|
|
149
|
+
transformers = [CommentAndHintRemover(), ParameterizeLiterals(placeholder_style=placeholder_style)]
|
|
106
150
|
|
|
107
|
-
# Create validators based on config
|
|
108
151
|
validators = []
|
|
109
152
|
if self.validators is not None:
|
|
110
|
-
# Use explicit validators if provided
|
|
111
153
|
validators = list(self.validators)
|
|
112
|
-
# Use default validators
|
|
113
154
|
elif self.enable_validation:
|
|
114
155
|
validators = [ParameterStyleValidator(fail_on_violation=self.strict_mode), DMLSafetyValidator()]
|
|
115
156
|
|
|
116
|
-
# Create analyzers based on config
|
|
117
157
|
analyzers = []
|
|
118
158
|
if self.analyzers is not None:
|
|
119
|
-
# Use explicit analyzers if provided
|
|
120
159
|
analyzers = list(self.analyzers)
|
|
121
|
-
# Use default analyzers
|
|
122
160
|
elif self.enable_analysis:
|
|
123
|
-
# Currently no default analyzers
|
|
124
161
|
analyzers = []
|
|
125
162
|
|
|
126
163
|
return StatementPipeline(transformers=transformers, validators=validators, analyzers=analyzers)
|
|
@@ -139,36 +176,39 @@ class SQL:
|
|
|
139
176
|
"""
|
|
140
177
|
|
|
141
178
|
__slots__ = (
|
|
142
|
-
"_builder_result_type",
|
|
143
|
-
"_config",
|
|
144
|
-
"_dialect",
|
|
145
|
-
"_filters",
|
|
146
|
-
"_is_many",
|
|
147
|
-
"_is_script",
|
|
148
|
-
"_named_params",
|
|
149
|
-
"_original_parameters",
|
|
150
|
-
"_original_sql",
|
|
151
|
-
"
|
|
152
|
-
"
|
|
153
|
-
"
|
|
154
|
-
"
|
|
155
|
-
"
|
|
156
|
-
"
|
|
179
|
+
"_builder_result_type",
|
|
180
|
+
"_config",
|
|
181
|
+
"_dialect",
|
|
182
|
+
"_filters",
|
|
183
|
+
"_is_many",
|
|
184
|
+
"_is_script",
|
|
185
|
+
"_named_params",
|
|
186
|
+
"_original_parameters",
|
|
187
|
+
"_original_sql",
|
|
188
|
+
"_parameter_normalization_state",
|
|
189
|
+
"_placeholder_mapping",
|
|
190
|
+
"_positional_params",
|
|
191
|
+
"_processed_state",
|
|
192
|
+
"_processing_context",
|
|
193
|
+
"_raw_sql",
|
|
194
|
+
"_statement",
|
|
157
195
|
)
|
|
158
196
|
|
|
159
197
|
def __init__(
|
|
160
198
|
self,
|
|
161
|
-
statement: Union[str, exp.Expression,
|
|
162
|
-
*parameters: Union[Any, StatementFilter, list[Union[Any, StatementFilter]]],
|
|
163
|
-
_dialect: DialectType = None,
|
|
164
|
-
_config: Optional[SQLConfig] = None,
|
|
165
|
-
_builder_result_type: Optional[type] = None,
|
|
166
|
-
_existing_state: Optional[dict[str, Any]] = None,
|
|
199
|
+
statement: "Union[str, exp.Expression, 'SQL']",
|
|
200
|
+
*parameters: "Union[Any, StatementFilter, list[Union[Any, StatementFilter]]]",
|
|
201
|
+
_dialect: "DialectType" = None,
|
|
202
|
+
_config: "Optional[SQLConfig]" = None,
|
|
203
|
+
_builder_result_type: "Optional[type]" = None,
|
|
204
|
+
_existing_state: "Optional[dict[str, Any]]" = None,
|
|
167
205
|
**kwargs: Any,
|
|
168
206
|
) -> None:
|
|
169
207
|
"""Initialize SQL with centralized parameter management."""
|
|
208
|
+
if "config" in kwargs and _config is None:
|
|
209
|
+
_config = kwargs.pop("config")
|
|
170
210
|
self._config = _config or SQLConfig()
|
|
171
|
-
self._dialect = _dialect
|
|
211
|
+
self._dialect = _dialect or (self._config.dialect if self._config else None)
|
|
172
212
|
self._builder_result_type = _builder_result_type
|
|
173
213
|
self._processed_state: Optional[_ProcessedState] = None
|
|
174
214
|
self._processing_context: Optional[SQLProcessingContext] = None
|
|
@@ -180,6 +220,7 @@ class SQL:
|
|
|
180
220
|
self._original_parameters: Any = None
|
|
181
221
|
self._original_sql: str = ""
|
|
182
222
|
self._placeholder_mapping: dict[str, Union[str, int]] = {}
|
|
223
|
+
self._parameter_normalization_state: Optional[ParameterNormalizationState] = None
|
|
183
224
|
self._is_many: bool = False
|
|
184
225
|
self._is_script: bool = False
|
|
185
226
|
|
|
@@ -197,7 +238,11 @@ class SQL:
|
|
|
197
238
|
self._process_parameters(*parameters, **kwargs)
|
|
198
239
|
|
|
199
240
|
def _init_from_sql_object(
|
|
200
|
-
self,
|
|
241
|
+
self,
|
|
242
|
+
statement: "SQL",
|
|
243
|
+
dialect: "DialectType",
|
|
244
|
+
config: "Optional[SQLConfig]",
|
|
245
|
+
builder_result_type: "Optional[type]",
|
|
201
246
|
) -> None:
|
|
202
247
|
"""Initialize attributes from an existing SQL object."""
|
|
203
248
|
self._statement = statement._statement
|
|
@@ -210,24 +255,21 @@ class SQL:
|
|
|
210
255
|
self._original_parameters = statement._original_parameters
|
|
211
256
|
self._original_sql = statement._original_sql
|
|
212
257
|
self._placeholder_mapping = statement._placeholder_mapping.copy()
|
|
258
|
+
self._parameter_normalization_state = statement._parameter_normalization_state
|
|
213
259
|
self._positional_params.extend(statement._positional_params)
|
|
214
260
|
self._named_params.update(statement._named_params)
|
|
215
261
|
self._filters.extend(statement._filters)
|
|
216
262
|
|
|
217
|
-
def _init_from_str_or_expression(self, statement: Union[str, exp.Expression]) -> None:
|
|
263
|
+
def _init_from_str_or_expression(self, statement: "Union[str, exp.Expression]") -> None:
|
|
218
264
|
"""Initialize attributes from a SQL string or expression."""
|
|
219
265
|
if isinstance(statement, str):
|
|
220
266
|
self._raw_sql = statement
|
|
221
|
-
if self._raw_sql and not self._config.input_sql_had_placeholders:
|
|
222
|
-
param_info = self._config.parameter_validator.extract_parameters(self._raw_sql)
|
|
223
|
-
if param_info:
|
|
224
|
-
self._config = replace(self._config, input_sql_had_placeholders=True)
|
|
225
267
|
self._statement = self._to_expression(statement)
|
|
226
268
|
else:
|
|
227
269
|
self._raw_sql = statement.sql(dialect=self._dialect) # pyright: ignore
|
|
228
270
|
self._statement = statement
|
|
229
271
|
|
|
230
|
-
def _load_from_existing_state(self, existing_state: dict[str, Any]) -> None:
|
|
272
|
+
def _load_from_existing_state(self, existing_state: "dict[str, Any]") -> None:
|
|
231
273
|
"""Load state from a dictionary (used by copy)."""
|
|
232
274
|
self._positional_params = list(existing_state.get("positional_params", self._positional_params))
|
|
233
275
|
self._named_params = dict(existing_state.get("named_params", self._named_params))
|
|
@@ -239,12 +281,12 @@ class SQL:
|
|
|
239
281
|
|
|
240
282
|
def _set_original_parameters(self, *parameters: Any) -> None:
|
|
241
283
|
"""Store the original parameters for compatibility."""
|
|
242
|
-
if len(parameters) == 1 and
|
|
284
|
+
if len(parameters) == 0 or (len(parameters) == 1 and is_statement_filter(parameters[0])):
|
|
285
|
+
self._original_parameters = None
|
|
286
|
+
elif len(parameters) == 1 and isinstance(parameters[0], (list, tuple)):
|
|
243
287
|
self._original_parameters = parameters[0]
|
|
244
|
-
elif len(parameters) > 1:
|
|
245
|
-
self._original_parameters = parameters
|
|
246
288
|
else:
|
|
247
|
-
self._original_parameters =
|
|
289
|
+
self._original_parameters = parameters
|
|
248
290
|
|
|
249
291
|
def _process_parameters(self, *parameters: Any, **kwargs: Any) -> None:
|
|
250
292
|
"""Process positional and keyword arguments for parameters and filters."""
|
|
@@ -255,7 +297,7 @@ class SQL:
|
|
|
255
297
|
param_value = kwargs.pop("parameters")
|
|
256
298
|
if isinstance(param_value, (list, tuple)):
|
|
257
299
|
self._positional_params.extend(param_value)
|
|
258
|
-
elif
|
|
300
|
+
elif is_dict(param_value):
|
|
259
301
|
self._named_params.update(param_value)
|
|
260
302
|
else:
|
|
261
303
|
self._positional_params.append(param_value)
|
|
@@ -266,7 +308,7 @@ class SQL:
|
|
|
266
308
|
|
|
267
309
|
def _process_parameter_item(self, item: Any) -> None:
|
|
268
310
|
"""Process a single item from the parameters list."""
|
|
269
|
-
if
|
|
311
|
+
if is_statement_filter(item):
|
|
270
312
|
self._filters.append(item)
|
|
271
313
|
pos_params, named_params = self._extract_filter_parameters(item)
|
|
272
314
|
self._positional_params.extend(pos_params)
|
|
@@ -274,7 +316,7 @@ class SQL:
|
|
|
274
316
|
elif isinstance(item, list):
|
|
275
317
|
for sub_item in item:
|
|
276
318
|
self._process_parameter_item(sub_item)
|
|
277
|
-
elif
|
|
319
|
+
elif is_dict(item):
|
|
278
320
|
self._named_params.update(item)
|
|
279
321
|
elif isinstance(item, tuple):
|
|
280
322
|
self._positional_params.extend(item)
|
|
@@ -290,120 +332,255 @@ class SQL:
|
|
|
290
332
|
if self._processed_state is not None:
|
|
291
333
|
return
|
|
292
334
|
|
|
293
|
-
# Get the final expression and parameters after filters
|
|
294
335
|
final_expr, final_params = self._build_final_state()
|
|
336
|
+
has_placeholders = self._detect_placeholders()
|
|
337
|
+
initial_sql_for_context, final_params = self._prepare_context_sql(final_expr, final_params)
|
|
338
|
+
|
|
339
|
+
context = self._create_processing_context(initial_sql_for_context, final_expr, final_params, has_placeholders)
|
|
340
|
+
result = self._run_pipeline(context)
|
|
341
|
+
|
|
342
|
+
processed_sql, merged_params = self._process_pipeline_result(result, final_params, context)
|
|
295
343
|
|
|
296
|
-
|
|
344
|
+
self._finalize_processed_state(result, processed_sql, merged_params)
|
|
345
|
+
|
|
346
|
+
def _detect_placeholders(self) -> bool:
|
|
347
|
+
"""Detect if the raw SQL has placeholders."""
|
|
297
348
|
if self._raw_sql:
|
|
298
349
|
validator = self._config.parameter_validator
|
|
299
350
|
raw_param_info = validator.extract_parameters(self._raw_sql)
|
|
300
351
|
has_placeholders = bool(raw_param_info)
|
|
301
|
-
|
|
302
|
-
|
|
352
|
+
if has_placeholders:
|
|
353
|
+
self._config.input_sql_had_placeholders = True
|
|
354
|
+
return has_placeholders
|
|
355
|
+
return self._config.input_sql_had_placeholders
|
|
356
|
+
|
|
357
|
+
def _prepare_context_sql(self, final_expr: exp.Expression, final_params: Any) -> tuple[str, Any]:
|
|
358
|
+
"""Prepare SQL string and parameters for context."""
|
|
359
|
+
initial_sql_for_context = self._raw_sql or final_expr.sql(dialect=self._dialect or self._config.dialect)
|
|
360
|
+
|
|
361
|
+
if is_expression(final_expr) and self._placeholder_mapping:
|
|
362
|
+
initial_sql_for_context = final_expr.sql(dialect=self._dialect or self._config.dialect)
|
|
363
|
+
if self._placeholder_mapping:
|
|
364
|
+
final_params = self._normalize_parameters(final_params)
|
|
365
|
+
|
|
366
|
+
return initial_sql_for_context, final_params
|
|
367
|
+
|
|
368
|
+
def _normalize_parameters(self, final_params: Any) -> Any:
|
|
369
|
+
"""Normalize parameters based on placeholder mapping."""
|
|
370
|
+
if is_dict(final_params):
|
|
371
|
+
normalized_params = {}
|
|
372
|
+
for placeholder_key, original_name in self._placeholder_mapping.items():
|
|
373
|
+
if str(original_name) in final_params:
|
|
374
|
+
normalized_params[placeholder_key] = final_params[str(original_name)]
|
|
375
|
+
non_oracle_params = {
|
|
376
|
+
key: value
|
|
377
|
+
for key, value in final_params.items()
|
|
378
|
+
if key not in {str(name) for name in self._placeholder_mapping.values()}
|
|
379
|
+
}
|
|
380
|
+
normalized_params.update(non_oracle_params)
|
|
381
|
+
return normalized_params
|
|
382
|
+
if isinstance(final_params, (list, tuple)):
|
|
383
|
+
validator = self._config.parameter_validator
|
|
384
|
+
param_info = validator.extract_parameters(self._raw_sql)
|
|
385
|
+
|
|
386
|
+
all_numeric = all(p.name and p.name.isdigit() for p in param_info)
|
|
303
387
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
self._config = replace(self._config, input_sql_had_placeholders=True)
|
|
388
|
+
if all_numeric:
|
|
389
|
+
normalized_params = {}
|
|
307
390
|
|
|
308
|
-
|
|
391
|
+
min_param_num = min(int(p.name) for p in param_info if p.name)
|
|
392
|
+
|
|
393
|
+
for i, param in enumerate(final_params):
|
|
394
|
+
param_num = str(i + min_param_num)
|
|
395
|
+
normalized_params[param_num] = param
|
|
396
|
+
|
|
397
|
+
return normalized_params
|
|
398
|
+
normalized_params = {}
|
|
399
|
+
for i, param in enumerate(final_params):
|
|
400
|
+
if i < len(param_info):
|
|
401
|
+
placeholder_key = f"{PARAM_PREFIX}{param_info[i].ordinal}"
|
|
402
|
+
normalized_params[placeholder_key] = param
|
|
403
|
+
return normalized_params
|
|
404
|
+
return final_params
|
|
405
|
+
|
|
406
|
+
def _create_processing_context(
|
|
407
|
+
self, initial_sql_for_context: str, final_expr: exp.Expression, final_params: Any, has_placeholders: bool
|
|
408
|
+
) -> SQLProcessingContext:
|
|
409
|
+
"""Create SQL processing context."""
|
|
309
410
|
context = SQLProcessingContext(
|
|
310
|
-
initial_sql_string=
|
|
311
|
-
dialect=self._dialect,
|
|
411
|
+
initial_sql_string=initial_sql_for_context,
|
|
412
|
+
dialect=self._dialect or self._config.dialect,
|
|
312
413
|
config=self._config,
|
|
313
|
-
current_expression=final_expr,
|
|
314
414
|
initial_expression=final_expr,
|
|
415
|
+
current_expression=final_expr,
|
|
315
416
|
merged_parameters=final_params,
|
|
316
|
-
input_sql_had_placeholders=has_placeholders,
|
|
417
|
+
input_sql_had_placeholders=has_placeholders or self._config.input_sql_had_placeholders,
|
|
317
418
|
)
|
|
318
419
|
|
|
319
|
-
|
|
420
|
+
if self._placeholder_mapping:
|
|
421
|
+
context.extra_info["placeholder_map"] = self._placeholder_mapping
|
|
422
|
+
|
|
423
|
+
# Set normalization state if available
|
|
424
|
+
if self._parameter_normalization_state:
|
|
425
|
+
context.parameter_normalization = self._parameter_normalization_state
|
|
426
|
+
|
|
320
427
|
validator = self._config.parameter_validator
|
|
321
428
|
context.parameter_info = validator.extract_parameters(context.initial_sql_string)
|
|
322
429
|
|
|
323
|
-
|
|
430
|
+
return context
|
|
431
|
+
|
|
432
|
+
def _run_pipeline(self, context: SQLProcessingContext) -> Any:
|
|
433
|
+
"""Run the SQL processing pipeline."""
|
|
324
434
|
pipeline = self._config.get_statement_pipeline()
|
|
325
435
|
result = pipeline.execute_pipeline(context)
|
|
326
|
-
|
|
327
|
-
# Store the processing context for later use
|
|
328
436
|
self._processing_context = result.context
|
|
437
|
+
return result
|
|
329
438
|
|
|
330
|
-
|
|
439
|
+
def _process_pipeline_result(
|
|
440
|
+
self, result: Any, final_params: Any, context: SQLProcessingContext
|
|
441
|
+
) -> tuple[str, Any]:
|
|
442
|
+
"""Process the result from the pipeline."""
|
|
331
443
|
processed_expr = result.expression
|
|
444
|
+
|
|
332
445
|
if isinstance(processed_expr, exp.Anonymous):
|
|
333
446
|
processed_sql = self._raw_sql or context.initial_sql_string
|
|
334
447
|
else:
|
|
335
|
-
processed_sql = processed_expr.sql(dialect=self._dialect, comments=False)
|
|
448
|
+
processed_sql = processed_expr.sql(dialect=self._dialect or self._config.dialect, comments=False)
|
|
336
449
|
logger.debug("Processed expression SQL: '%s'", processed_sql)
|
|
337
450
|
|
|
338
|
-
# Check if we need to denormalize pyformat placeholders
|
|
339
451
|
if self._placeholder_mapping and self._original_sql:
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
452
|
+
processed_sql, result = self._denormalize_sql(processed_sql, result)
|
|
453
|
+
|
|
454
|
+
merged_params = self._merge_pipeline_parameters(result, final_params)
|
|
455
|
+
|
|
456
|
+
return processed_sql, merged_params
|
|
457
|
+
|
|
458
|
+
def _denormalize_sql(self, processed_sql: str, result: Any) -> tuple[str, Any]:
|
|
459
|
+
"""Denormalize SQL back to original parameter style."""
|
|
460
|
+
|
|
461
|
+
original_sql = self._original_sql
|
|
462
|
+
param_info = self._config.parameter_validator.extract_parameters(original_sql)
|
|
463
|
+
target_styles = {p.style for p in param_info}
|
|
464
|
+
|
|
465
|
+
logger.debug(
|
|
466
|
+
"Denormalizing SQL: before='%s', original='%s', styles=%s", processed_sql, original_sql, target_styles
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
if ParameterStyle.POSITIONAL_PYFORMAT in target_styles:
|
|
470
|
+
processed_sql = self._config.parameter_converter._convert_sql_placeholders(
|
|
471
|
+
processed_sql, param_info, ParameterStyle.POSITIONAL_PYFORMAT
|
|
472
|
+
)
|
|
473
|
+
logger.debug("Denormalized SQL to: '%s'", processed_sql)
|
|
474
|
+
elif ParameterStyle.NAMED_PYFORMAT in target_styles:
|
|
475
|
+
processed_sql = self._config.parameter_converter._convert_sql_placeholders(
|
|
476
|
+
processed_sql, param_info, ParameterStyle.NAMED_PYFORMAT
|
|
477
|
+
)
|
|
478
|
+
logger.debug("Denormalized SQL to: '%s'", processed_sql)
|
|
479
|
+
# Also denormalize the parameters back to their original names
|
|
480
|
+
if (
|
|
481
|
+
self._placeholder_mapping
|
|
482
|
+
and result.context.merged_parameters
|
|
483
|
+
and is_dict(result.context.merged_parameters)
|
|
484
|
+
):
|
|
485
|
+
result.context.merged_parameters = self._denormalize_pyformat_params(result.context.merged_parameters)
|
|
486
|
+
elif ParameterStyle.POSITIONAL_COLON in target_styles:
|
|
487
|
+
processed_param_info = self._config.parameter_validator.extract_parameters(processed_sql)
|
|
488
|
+
has_param_placeholders = any(p.name and p.name.startswith(PARAM_PREFIX) for p in processed_param_info)
|
|
489
|
+
|
|
490
|
+
if has_param_placeholders:
|
|
491
|
+
logger.debug("Skipping denormalization for param_N placeholders")
|
|
367
492
|
else:
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
bool(self._placeholder_mapping),
|
|
371
|
-
bool(self._original_sql),
|
|
493
|
+
processed_sql = self._config.parameter_converter._convert_sql_placeholders(
|
|
494
|
+
processed_sql, param_info, ParameterStyle.POSITIONAL_COLON
|
|
372
495
|
)
|
|
496
|
+
logger.debug("Denormalized SQL to: '%s'", processed_sql)
|
|
497
|
+
if (
|
|
498
|
+
self._placeholder_mapping
|
|
499
|
+
and result.context.merged_parameters
|
|
500
|
+
and is_dict(result.context.merged_parameters)
|
|
501
|
+
):
|
|
502
|
+
result.context.merged_parameters = self._denormalize_colon_params(result.context.merged_parameters)
|
|
503
|
+
else:
|
|
504
|
+
logger.debug(
|
|
505
|
+
"No denormalization needed: mapping=%s, original=%s",
|
|
506
|
+
bool(self._placeholder_mapping),
|
|
507
|
+
bool(self._original_sql),
|
|
508
|
+
)
|
|
373
509
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
510
|
+
return processed_sql, result
|
|
511
|
+
|
|
512
|
+
def _denormalize_colon_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
|
|
513
|
+
"""Denormalize colon-style parameters back to numeric format."""
|
|
514
|
+
# For positional colon style, all params should have numeric keys
|
|
515
|
+
# Just return the params as-is if they already have the right format
|
|
516
|
+
if all(key.isdigit() for key in params):
|
|
517
|
+
return params
|
|
518
|
+
|
|
519
|
+
# For positional colon, we need ALL parameters in the final result
|
|
520
|
+
# This includes both user parameters and extracted literals
|
|
521
|
+
# We should NOT filter out extracted parameters (param_0, param_1, etc)
|
|
522
|
+
# because they need to be included in the final parameter conversion
|
|
523
|
+
return params
|
|
524
|
+
|
|
525
|
+
def _denormalize_pyformat_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
|
|
526
|
+
"""Denormalize pyformat parameters back to their original names."""
|
|
527
|
+
denormalized_params = {}
|
|
528
|
+
for placeholder_key, original_name in self._placeholder_mapping.items():
|
|
529
|
+
if placeholder_key in params:
|
|
530
|
+
# For pyformat, the original_name is the actual parameter name (e.g., 'max_value')
|
|
531
|
+
denormalized_params[str(original_name)] = params[placeholder_key]
|
|
532
|
+
# Include any parameters that weren't normalized
|
|
533
|
+
non_normalized_params = {key: value for key, value in params.items() if not key.startswith(PARAM_PREFIX)}
|
|
534
|
+
denormalized_params.update(non_normalized_params)
|
|
535
|
+
return denormalized_params
|
|
536
|
+
|
|
537
|
+
def _merge_pipeline_parameters(self, result: Any, final_params: Any) -> Any:
|
|
538
|
+
"""Merge parameters from the pipeline processing."""
|
|
539
|
+
merged_params = result.context.merged_parameters
|
|
540
|
+
|
|
541
|
+
# If we have extracted parameters from the pipeline, only merge them if:
|
|
542
|
+
# 1. We don't already have parameters in merged_params, OR
|
|
543
|
+
# 2. The original params were None and we need to use the extracted ones
|
|
544
|
+
if result.context.extracted_parameters_from_pipeline:
|
|
545
|
+
if merged_params is None:
|
|
546
|
+
# No existing parameters - use the extracted ones
|
|
386
547
|
merged_params = result.context.extracted_parameters_from_pipeline
|
|
387
|
-
|
|
388
|
-
#
|
|
389
|
-
merged_params =
|
|
548
|
+
elif merged_params == final_params and final_params is None:
|
|
549
|
+
# Both are None, use extracted parameters
|
|
550
|
+
merged_params = result.context.extracted_parameters_from_pipeline
|
|
551
|
+
elif merged_params != result.context.extracted_parameters_from_pipeline:
|
|
552
|
+
# Only merge if the extracted parameters are different from what we already have
|
|
553
|
+
# This prevents the duplication issue where the same parameters get added twice
|
|
554
|
+
if is_dict(merged_params):
|
|
555
|
+
for i, param in enumerate(result.context.extracted_parameters_from_pipeline):
|
|
556
|
+
param_name = f"{PARAM_PREFIX}{i}"
|
|
557
|
+
merged_params[param_name] = param
|
|
558
|
+
elif isinstance(merged_params, (list, tuple)):
|
|
559
|
+
# Only extend if we don't already have these parameters
|
|
560
|
+
# Convert to list and extend with extracted parameters
|
|
561
|
+
if isinstance(merged_params, tuple):
|
|
562
|
+
merged_params = list(merged_params)
|
|
563
|
+
merged_params.extend(result.context.extracted_parameters_from_pipeline)
|
|
564
|
+
else:
|
|
565
|
+
# Single parameter case - convert to list with original + extracted
|
|
566
|
+
merged_params = [merged_params, *list(result.context.extracted_parameters_from_pipeline)]
|
|
567
|
+
|
|
568
|
+
return merged_params
|
|
390
569
|
|
|
391
|
-
|
|
570
|
+
def _finalize_processed_state(self, result: Any, processed_sql: str, merged_params: Any) -> None:
|
|
571
|
+
"""Finalize the processed state."""
|
|
392
572
|
self._processed_state = _ProcessedState(
|
|
393
|
-
processed_expression=
|
|
573
|
+
processed_expression=result.expression,
|
|
394
574
|
processed_sql=processed_sql,
|
|
395
575
|
merged_parameters=merged_params,
|
|
396
576
|
validation_errors=list(result.context.validation_errors),
|
|
397
|
-
analysis_results={},
|
|
398
|
-
transformation_results={},
|
|
577
|
+
analysis_results={},
|
|
578
|
+
transformation_results={},
|
|
399
579
|
)
|
|
400
580
|
|
|
401
|
-
# Check strict mode
|
|
402
581
|
if self._config.strict_mode and self._processed_state.validation_errors:
|
|
403
|
-
# Find the highest risk error
|
|
404
582
|
highest_risk_error = max(
|
|
405
|
-
self._processed_state.validation_errors,
|
|
406
|
-
key=lambda e: e.risk_level.value if hasattr(e, "risk_level") else 0,
|
|
583
|
+
self._processed_state.validation_errors, key=lambda e: e.risk_level.value if has_risk_level(e) else 0
|
|
407
584
|
)
|
|
408
585
|
raise SQLValidationError(
|
|
409
586
|
message=highest_risk_error.message,
|
|
@@ -411,81 +588,85 @@ class SQL:
|
|
|
411
588
|
risk_level=getattr(highest_risk_error, "risk_level", RiskLevel.HIGH),
|
|
412
589
|
)
|
|
413
590
|
|
|
414
|
-
def _to_expression(self, statement: Union[str, exp.Expression]) -> exp.Expression:
|
|
591
|
+
def _to_expression(self, statement: "Union[str, exp.Expression]") -> exp.Expression:
|
|
415
592
|
"""Convert string to sqlglot expression."""
|
|
416
|
-
if
|
|
593
|
+
if is_expression(statement):
|
|
417
594
|
return statement
|
|
418
595
|
|
|
419
|
-
|
|
420
|
-
if not statement or not statement.strip():
|
|
421
|
-
# Return an empty select instead of Anonymous for empty strings
|
|
596
|
+
if not statement or (isinstance(statement, str) and not statement.strip()):
|
|
422
597
|
return exp.Select()
|
|
423
598
|
|
|
424
|
-
# Check if parsing is disabled
|
|
425
599
|
if not self._config.enable_parsing:
|
|
426
|
-
# Return an anonymous expression that preserves the raw SQL
|
|
427
600
|
return exp.Anonymous(this=statement)
|
|
428
601
|
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
602
|
+
if not isinstance(statement, str):
|
|
603
|
+
return exp.Anonymous(this="")
|
|
432
604
|
validator = self._config.parameter_validator
|
|
433
605
|
param_info = validator.extract_parameters(statement)
|
|
434
606
|
|
|
435
|
-
# Check if
|
|
436
|
-
|
|
437
|
-
p.style in {ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT} for p in param_info
|
|
438
|
-
)
|
|
607
|
+
# Check if normalization is needed
|
|
608
|
+
needs_normalization = any(p.style in SQLGLOT_INCOMPATIBLE_STYLES for p in param_info)
|
|
439
609
|
|
|
440
610
|
normalized_sql = statement
|
|
441
611
|
placeholder_mapping: dict[str, Any] = {}
|
|
442
612
|
|
|
443
|
-
if
|
|
444
|
-
# Normalize pyformat placeholders to named placeholders for SQLGlot
|
|
613
|
+
if needs_normalization:
|
|
445
614
|
converter = self._config.parameter_converter
|
|
446
615
|
normalized_sql, placeholder_mapping = converter._transform_sql_for_parsing(statement, param_info)
|
|
447
|
-
# Store the original SQL before normalization
|
|
448
616
|
self._original_sql = statement
|
|
449
617
|
self._placeholder_mapping = placeholder_mapping
|
|
450
618
|
|
|
619
|
+
# Create normalization state
|
|
620
|
+
from sqlspec.statement.parameters import ParameterNormalizationState
|
|
621
|
+
|
|
622
|
+
self._parameter_normalization_state = ParameterNormalizationState(
|
|
623
|
+
was_normalized=True,
|
|
624
|
+
original_styles=list({p.style for p in param_info}),
|
|
625
|
+
normalized_style=ParameterStyle.NAMED_COLON,
|
|
626
|
+
placeholder_map=placeholder_mapping,
|
|
627
|
+
original_param_info=param_info,
|
|
628
|
+
)
|
|
629
|
+
else:
|
|
630
|
+
self._parameter_normalization_state = None
|
|
631
|
+
|
|
451
632
|
try:
|
|
452
|
-
# Parse with sqlglot
|
|
453
633
|
expressions = sqlglot.parse(normalized_sql, dialect=self._dialect) # pyright: ignore
|
|
454
634
|
if not expressions:
|
|
455
|
-
# Empty statement
|
|
456
635
|
return exp.Anonymous(this=statement)
|
|
457
636
|
first_expr = expressions[0]
|
|
458
637
|
if first_expr is None:
|
|
459
|
-
# Could not parse
|
|
460
638
|
return exp.Anonymous(this=statement)
|
|
461
639
|
|
|
462
640
|
except ParseError as e:
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
641
|
+
if getattr(self._config, "parse_errors_as_warnings", False):
|
|
642
|
+
logger.warning(
|
|
643
|
+
"Failed to parse SQL, returning Anonymous expression.", extra={"sql": statement, "error": str(e)}
|
|
644
|
+
)
|
|
645
|
+
return exp.Anonymous(this=statement)
|
|
646
|
+
|
|
647
|
+
msg = f"Failed to parse SQL: {statement}"
|
|
648
|
+
raise SQLParsingError(msg) from e
|
|
466
649
|
return first_expr
|
|
467
650
|
|
|
468
651
|
@staticmethod
|
|
469
652
|
def _extract_filter_parameters(filter_obj: StatementFilter) -> tuple[list[Any], dict[str, Any]]:
|
|
470
653
|
"""Extract parameters from a filter object."""
|
|
471
|
-
if
|
|
654
|
+
if can_extract_parameters(filter_obj):
|
|
472
655
|
return filter_obj.extract_parameters()
|
|
473
|
-
# Fallback for filters that don't implement the new method yet
|
|
474
656
|
return [], {}
|
|
475
657
|
|
|
476
658
|
def copy(
|
|
477
659
|
self,
|
|
478
|
-
statement: Optional[Union[str, exp.Expression]] = None,
|
|
479
|
-
parameters: Optional[Any] = None,
|
|
480
|
-
dialect: DialectType = None,
|
|
481
|
-
config: Optional[SQLConfig] = None,
|
|
660
|
+
statement: "Optional[Union[str, exp.Expression]]" = None,
|
|
661
|
+
parameters: "Optional[Any]" = None,
|
|
662
|
+
dialect: "DialectType" = None,
|
|
663
|
+
config: "Optional[SQLConfig]" = None,
|
|
482
664
|
**kwargs: Any,
|
|
483
665
|
) -> "SQL":
|
|
484
666
|
"""Create a copy with optional modifications.
|
|
485
667
|
|
|
486
668
|
This is the primary method for creating modified SQL objects.
|
|
487
669
|
"""
|
|
488
|
-
# Prepare existing state
|
|
489
670
|
existing_state = {
|
|
490
671
|
"positional_params": list(self._positional_params),
|
|
491
672
|
"named_params": dict(self._named_params),
|
|
@@ -494,27 +675,22 @@ class SQL:
|
|
|
494
675
|
"is_script": self._is_script,
|
|
495
676
|
"raw_sql": self._raw_sql,
|
|
496
677
|
}
|
|
497
|
-
# Always include original_parameters in existing_state
|
|
498
678
|
existing_state["original_parameters"] = self._original_parameters
|
|
499
679
|
|
|
500
|
-
# Create new instance
|
|
501
680
|
new_statement = statement if statement is not None else self._statement
|
|
502
681
|
new_dialect = dialect if dialect is not None else self._dialect
|
|
503
682
|
new_config = config if config is not None else self._config
|
|
504
683
|
|
|
505
|
-
# If parameters are explicitly provided, they replace existing ones
|
|
506
684
|
if parameters is not None:
|
|
507
|
-
# Clear existing state so only new parameters are used
|
|
508
685
|
existing_state["positional_params"] = []
|
|
509
686
|
existing_state["named_params"] = {}
|
|
510
|
-
# Pass parameters through normal processing
|
|
511
687
|
return SQL(
|
|
512
688
|
new_statement,
|
|
513
689
|
parameters,
|
|
514
690
|
_dialect=new_dialect,
|
|
515
691
|
_config=new_config,
|
|
516
692
|
_builder_result_type=self._builder_result_type,
|
|
517
|
-
_existing_state=None,
|
|
693
|
+
_existing_state=None,
|
|
518
694
|
**kwargs,
|
|
519
695
|
)
|
|
520
696
|
|
|
@@ -527,14 +703,14 @@ class SQL:
|
|
|
527
703
|
**kwargs,
|
|
528
704
|
)
|
|
529
705
|
|
|
530
|
-
def add_named_parameter(self, name: str, value: Any) -> "SQL":
|
|
706
|
+
def add_named_parameter(self, name: "str", value: Any) -> "SQL":
|
|
531
707
|
"""Add a named parameter and return a new SQL instance."""
|
|
532
708
|
new_obj = self.copy()
|
|
533
709
|
new_obj._named_params[name] = value
|
|
534
710
|
return new_obj
|
|
535
711
|
|
|
536
712
|
def get_unique_parameter_name(
|
|
537
|
-
self, base_name: str, namespace: Optional[str] = None, preserve_original: bool = False
|
|
713
|
+
self, base_name: "str", namespace: "Optional[str]" = None, preserve_original: bool = False
|
|
538
714
|
) -> str:
|
|
539
715
|
"""Generate a unique parameter name.
|
|
540
716
|
|
|
@@ -546,21 +722,16 @@ class SQL:
|
|
|
546
722
|
Returns:
|
|
547
723
|
A unique parameter name
|
|
548
724
|
"""
|
|
549
|
-
# Check both positional and named params
|
|
550
725
|
all_param_names = set(self._named_params.keys())
|
|
551
726
|
|
|
552
|
-
# Build the candidate name
|
|
553
727
|
candidate = f"{namespace}_{base_name}" if namespace else base_name
|
|
554
728
|
|
|
555
|
-
# If preserve_original and the name is unique, use it
|
|
556
729
|
if preserve_original and candidate not in all_param_names:
|
|
557
730
|
return candidate
|
|
558
731
|
|
|
559
|
-
# If not preserving or name exists, generate unique name
|
|
560
732
|
if candidate not in all_param_names:
|
|
561
733
|
return candidate
|
|
562
734
|
|
|
563
|
-
# Generate unique name with counter
|
|
564
735
|
counter = 1
|
|
565
736
|
while True:
|
|
566
737
|
new_candidate = f"{candidate}_{counter}"
|
|
@@ -570,24 +741,19 @@ class SQL:
|
|
|
570
741
|
|
|
571
742
|
def where(self, condition: "Union[str, exp.Expression, exp.Condition]") -> "SQL":
|
|
572
743
|
"""Apply WHERE clause and return new SQL instance."""
|
|
573
|
-
# Convert condition to expression
|
|
574
744
|
condition_expr = self._to_expression(condition) if isinstance(condition, str) else condition
|
|
575
745
|
|
|
576
|
-
|
|
577
|
-
if hasattr(self._statement, "where"):
|
|
746
|
+
if supports_where(self._statement):
|
|
578
747
|
new_statement = self._statement.where(condition_expr) # pyright: ignore
|
|
579
748
|
else:
|
|
580
|
-
# Wrap in SELECT if needed
|
|
581
749
|
new_statement = exp.Select().from_(self._statement).where(condition_expr) # pyright: ignore
|
|
582
750
|
|
|
583
751
|
return self.copy(statement=new_statement)
|
|
584
752
|
|
|
585
753
|
def filter(self, filter_obj: StatementFilter) -> "SQL":
|
|
586
754
|
"""Apply a filter and return a new SQL instance."""
|
|
587
|
-
# Create a new SQL object with the filter added
|
|
588
755
|
new_obj = self.copy()
|
|
589
756
|
new_obj._filters.append(filter_obj)
|
|
590
|
-
# Extract filter parameters
|
|
591
757
|
pos_params, named_params = self._extract_filter_parameters(filter_obj)
|
|
592
758
|
new_obj._positional_params.extend(pos_params)
|
|
593
759
|
new_obj._named_params.update(named_params)
|
|
@@ -611,81 +777,82 @@ class SQL:
|
|
|
611
777
|
|
|
612
778
|
def _build_final_state(self) -> tuple[exp.Expression, Any]:
|
|
613
779
|
"""Build final expression and parameters after applying filters."""
|
|
614
|
-
# Start with current statement
|
|
615
780
|
final_expr = self._statement
|
|
616
781
|
|
|
617
|
-
# Apply all filters to the expression
|
|
618
782
|
for filter_obj in self._filters:
|
|
619
|
-
if
|
|
783
|
+
if can_append_to_statement(filter_obj):
|
|
620
784
|
temp_sql = SQL(final_expr, config=self._config, dialect=self._dialect)
|
|
621
785
|
temp_sql._positional_params = list(self._positional_params)
|
|
622
786
|
temp_sql._named_params = dict(self._named_params)
|
|
623
787
|
result = filter_obj.append_to_statement(temp_sql)
|
|
624
788
|
final_expr = result._statement if isinstance(result, SQL) else result
|
|
625
789
|
|
|
626
|
-
# Determine final parameters format
|
|
627
790
|
final_params: Any
|
|
628
791
|
if self._named_params and not self._positional_params:
|
|
629
|
-
# Only named params
|
|
630
792
|
final_params = dict(self._named_params)
|
|
631
793
|
elif self._positional_params and not self._named_params:
|
|
632
|
-
# Always return a list for positional params to maintain sequence type
|
|
633
794
|
final_params = list(self._positional_params)
|
|
634
795
|
elif self._positional_params and self._named_params:
|
|
635
|
-
# Mixed - merge into dict
|
|
636
796
|
final_params = dict(self._named_params)
|
|
637
|
-
# Add positional params with generated names
|
|
638
797
|
for i, param in enumerate(self._positional_params):
|
|
639
798
|
param_name = f"arg_{i}"
|
|
640
799
|
while param_name in final_params:
|
|
641
800
|
param_name = f"arg_{i}_{id(param)}"
|
|
642
801
|
final_params[param_name] = param
|
|
643
802
|
else:
|
|
644
|
-
# No parameters
|
|
645
803
|
final_params = None
|
|
646
804
|
|
|
647
805
|
return final_expr, final_params
|
|
648
806
|
|
|
649
|
-
# Properties for compatibility
|
|
650
807
|
@property
|
|
651
808
|
def sql(self) -> str:
|
|
652
809
|
"""Get SQL string."""
|
|
653
|
-
# Handle empty string case
|
|
654
810
|
if not self._raw_sql or (self._raw_sql and not self._raw_sql.strip()):
|
|
655
811
|
return ""
|
|
656
812
|
|
|
657
|
-
# For scripts, always return the raw SQL to preserve multi-statement scripts
|
|
658
813
|
if self._is_script and self._raw_sql:
|
|
659
814
|
return self._raw_sql
|
|
660
|
-
# If parsing is disabled, return the raw SQL
|
|
661
815
|
if not self._config.enable_parsing and self._raw_sql:
|
|
662
816
|
return self._raw_sql
|
|
663
817
|
|
|
664
|
-
# Ensure processed
|
|
665
818
|
self._ensure_processed()
|
|
666
|
-
|
|
819
|
+
if self._processed_state is None:
|
|
820
|
+
msg = "Failed to process SQL statement"
|
|
821
|
+
raise RuntimeError(msg)
|
|
667
822
|
return self._processed_state.processed_sql
|
|
668
823
|
|
|
669
824
|
@property
|
|
670
|
-
def expression(self) -> Optional[exp.Expression]:
|
|
825
|
+
def expression(self) -> "Optional[exp.Expression]":
|
|
671
826
|
"""Get the final expression."""
|
|
672
|
-
# Return None if parsing is disabled
|
|
673
827
|
if not self._config.enable_parsing:
|
|
674
828
|
return None
|
|
675
829
|
self._ensure_processed()
|
|
676
|
-
|
|
830
|
+
if self._processed_state is None:
|
|
831
|
+
msg = "Failed to process SQL statement"
|
|
832
|
+
raise RuntimeError(msg)
|
|
677
833
|
return self._processed_state.processed_expression
|
|
678
834
|
|
|
679
835
|
@property
|
|
680
836
|
def parameters(self) -> Any:
|
|
681
837
|
"""Get merged parameters."""
|
|
682
|
-
# For executemany operations, return the original parameters list
|
|
683
838
|
if self._is_many and self._original_parameters is not None:
|
|
684
839
|
return self._original_parameters
|
|
685
840
|
|
|
841
|
+
if (
|
|
842
|
+
self._original_parameters is not None
|
|
843
|
+
and isinstance(self._original_parameters, tuple)
|
|
844
|
+
and not self._named_params
|
|
845
|
+
):
|
|
846
|
+
return self._original_parameters
|
|
847
|
+
|
|
686
848
|
self._ensure_processed()
|
|
687
|
-
|
|
688
|
-
|
|
849
|
+
if self._processed_state is None:
|
|
850
|
+
msg = "Failed to process SQL statement"
|
|
851
|
+
raise RuntimeError(msg)
|
|
852
|
+
params = self._processed_state.merged_parameters
|
|
853
|
+
if params is None:
|
|
854
|
+
return {}
|
|
855
|
+
return params
|
|
689
856
|
|
|
690
857
|
@property
|
|
691
858
|
def is_many(self) -> bool:
|
|
@@ -697,66 +864,173 @@ class SQL:
|
|
|
697
864
|
"""Check if this is a script."""
|
|
698
865
|
return self._is_script
|
|
699
866
|
|
|
700
|
-
|
|
867
|
+
@property
|
|
868
|
+
def dialect(self) -> "Optional[DialectType]":
|
|
869
|
+
"""Get the SQL dialect."""
|
|
870
|
+
return self._dialect
|
|
871
|
+
|
|
872
|
+
def to_sql(self, placeholder_style: "Optional[str]" = None) -> "str":
|
|
701
873
|
"""Convert to SQL string with given placeholder style."""
|
|
702
874
|
if self._is_script:
|
|
703
875
|
return self.sql
|
|
704
876
|
sql, _ = self.compile(placeholder_style=placeholder_style)
|
|
705
877
|
return sql
|
|
706
878
|
|
|
707
|
-
def get_parameters(self, style: Optional[str] = None) -> Any:
|
|
879
|
+
def get_parameters(self, style: "Optional[str]" = None) -> Any:
|
|
708
880
|
"""Get parameters in the requested style."""
|
|
709
|
-
# Get compiled parameters with style
|
|
710
881
|
_, params = self.compile(placeholder_style=style)
|
|
711
882
|
return params
|
|
712
883
|
|
|
713
|
-
def
|
|
884
|
+
def _compile_execute_many(self, placeholder_style: "Optional[str]") -> "tuple[str, Any]":
|
|
885
|
+
"""Handle compilation for execute_many operations."""
|
|
886
|
+
sql = self.sql
|
|
887
|
+
|
|
888
|
+
self._ensure_processed()
|
|
889
|
+
|
|
890
|
+
params = self._original_parameters
|
|
891
|
+
|
|
892
|
+
extracted_params = self._get_extracted_parameters()
|
|
893
|
+
|
|
894
|
+
if extracted_params:
|
|
895
|
+
params = self._merge_extracted_params_with_sets(params, extracted_params)
|
|
896
|
+
|
|
897
|
+
if placeholder_style:
|
|
898
|
+
sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
|
|
899
|
+
|
|
900
|
+
return sql, params
|
|
901
|
+
|
|
902
|
+
def _get_extracted_parameters(self) -> "list[Any]":
|
|
903
|
+
"""Get extracted parameters from pipeline processing."""
|
|
904
|
+
extracted_params = []
|
|
905
|
+
if self._processed_state and self._processed_state.merged_parameters:
|
|
906
|
+
merged = self._processed_state.merged_parameters
|
|
907
|
+
if isinstance(merged, list):
|
|
908
|
+
if merged and not isinstance(merged[0], (tuple, list)):
|
|
909
|
+
extracted_params = merged
|
|
910
|
+
elif self._processing_context and self._processing_context.extracted_parameters_from_pipeline:
|
|
911
|
+
extracted_params = self._processing_context.extracted_parameters_from_pipeline
|
|
912
|
+
return extracted_params
|
|
913
|
+
|
|
914
|
+
def _merge_extracted_params_with_sets(self, params: Any, extracted_params: "list[Any]") -> "list[tuple[Any, ...]]":
|
|
915
|
+
"""Merge extracted parameters with each parameter set."""
|
|
916
|
+
enhanced_params = []
|
|
917
|
+
for param_set in params:
|
|
918
|
+
if isinstance(param_set, (list, tuple)):
|
|
919
|
+
extracted_values = []
|
|
920
|
+
for extracted in extracted_params:
|
|
921
|
+
if has_parameter_value(extracted):
|
|
922
|
+
extracted_values.append(extracted.value)
|
|
923
|
+
else:
|
|
924
|
+
extracted_values.append(extracted)
|
|
925
|
+
enhanced_set = list(param_set) + extracted_values
|
|
926
|
+
enhanced_params.append(tuple(enhanced_set))
|
|
927
|
+
else:
|
|
928
|
+
extracted_values = []
|
|
929
|
+
for extracted in extracted_params:
|
|
930
|
+
if has_parameter_value(extracted):
|
|
931
|
+
extracted_values.append(extracted.value)
|
|
932
|
+
else:
|
|
933
|
+
extracted_values.append(extracted)
|
|
934
|
+
enhanced_params.append((param_set, *extracted_values))
|
|
935
|
+
return enhanced_params
|
|
936
|
+
|
|
937
|
+
def compile(self, placeholder_style: "Optional[str]" = None) -> "tuple[str, Any]":
|
|
714
938
|
"""Compile to SQL and parameters."""
|
|
715
|
-
# For scripts, return raw SQL directly without processing
|
|
716
939
|
if self._is_script:
|
|
717
940
|
return self.sql, None
|
|
718
941
|
|
|
719
|
-
# For executemany operations with original parameters, handle specially
|
|
720
942
|
if self._is_many and self._original_parameters is not None:
|
|
721
|
-
|
|
722
|
-
sql = self.sql # This will ensure processing if needed
|
|
723
|
-
params = self._original_parameters
|
|
724
|
-
|
|
725
|
-
# Convert placeholder style if requested
|
|
726
|
-
if placeholder_style:
|
|
727
|
-
sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
|
|
728
|
-
|
|
729
|
-
return sql, params
|
|
943
|
+
return self._compile_execute_many(placeholder_style)
|
|
730
944
|
|
|
731
|
-
# If parsing is disabled, return raw SQL without transformation
|
|
732
945
|
if not self._config.enable_parsing and self._raw_sql:
|
|
733
946
|
return self._raw_sql, self._raw_parameters
|
|
734
947
|
|
|
735
|
-
# Ensure processed
|
|
736
948
|
self._ensure_processed()
|
|
737
949
|
|
|
738
|
-
|
|
739
|
-
|
|
950
|
+
if self._processed_state is None:
|
|
951
|
+
msg = "Failed to process SQL statement"
|
|
952
|
+
raise RuntimeError(msg)
|
|
740
953
|
sql = self._processed_state.processed_sql
|
|
741
954
|
params = self._processed_state.merged_parameters
|
|
742
955
|
|
|
743
|
-
|
|
744
|
-
if params is not None and hasattr(self, "_processing_context") and self._processing_context:
|
|
956
|
+
if params is not None and self._processing_context:
|
|
745
957
|
parameter_mapping = self._processing_context.metadata.get("parameter_position_mapping")
|
|
746
958
|
if parameter_mapping:
|
|
747
|
-
# Apply parameter reordering based on the mapping
|
|
748
959
|
params = self._reorder_parameters(params, parameter_mapping)
|
|
749
960
|
|
|
750
|
-
#
|
|
961
|
+
# Handle denormalization if needed
|
|
962
|
+
if self._processing_context and self._processing_context.parameter_normalization:
|
|
963
|
+
norm_state = self._processing_context.parameter_normalization
|
|
964
|
+
|
|
965
|
+
# If original SQL had incompatible styles, denormalize back to the original style
|
|
966
|
+
# when no specific style requested OR when the requested style matches the original
|
|
967
|
+
if norm_state.was_normalized and norm_state.original_styles:
|
|
968
|
+
original_style = norm_state.original_styles[0]
|
|
969
|
+
should_denormalize = placeholder_style is None or (
|
|
970
|
+
placeholder_style and ParameterStyle(placeholder_style) == original_style
|
|
971
|
+
)
|
|
972
|
+
|
|
973
|
+
if should_denormalize and original_style in SQLGLOT_INCOMPATIBLE_STYLES:
|
|
974
|
+
# Denormalize SQL back to original style
|
|
975
|
+
sql = self._config.parameter_converter._convert_sql_placeholders(
|
|
976
|
+
sql, norm_state.original_param_info, original_style
|
|
977
|
+
)
|
|
978
|
+
# Also denormalize parameters if needed
|
|
979
|
+
if original_style == ParameterStyle.POSITIONAL_COLON and is_dict(params):
|
|
980
|
+
params = self._denormalize_colon_params(params)
|
|
981
|
+
|
|
982
|
+
params = self._unwrap_typed_parameters(params)
|
|
983
|
+
|
|
751
984
|
if placeholder_style is None:
|
|
752
985
|
return sql, params
|
|
753
986
|
|
|
754
|
-
# Convert to requested placeholder style
|
|
755
987
|
if placeholder_style:
|
|
756
|
-
sql, params = self.
|
|
988
|
+
sql, params = self._apply_placeholder_style(sql, params, placeholder_style)
|
|
989
|
+
|
|
990
|
+
return sql, params
|
|
757
991
|
|
|
992
|
+
def _apply_placeholder_style(self, sql: "str", params: Any, placeholder_style: "str") -> "tuple[str, Any]":
|
|
993
|
+
"""Apply placeholder style conversion to SQL and parameters."""
|
|
994
|
+
# Just use the params passed in - they've already been processed
|
|
995
|
+
sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
|
|
758
996
|
return sql, params
|
|
759
997
|
|
|
998
|
+
@staticmethod
|
|
999
|
+
def _unwrap_typed_parameters(params: Any) -> Any:
|
|
1000
|
+
"""Unwrap TypedParameter objects to their actual values.
|
|
1001
|
+
|
|
1002
|
+
Args:
|
|
1003
|
+
params: Parameters that may contain TypedParameter objects
|
|
1004
|
+
|
|
1005
|
+
Returns:
|
|
1006
|
+
Parameters with TypedParameter objects unwrapped to their values
|
|
1007
|
+
"""
|
|
1008
|
+
if params is None:
|
|
1009
|
+
return None
|
|
1010
|
+
|
|
1011
|
+
if is_dict(params):
|
|
1012
|
+
unwrapped_dict = {}
|
|
1013
|
+
for key, value in params.items():
|
|
1014
|
+
if has_parameter_value(value):
|
|
1015
|
+
unwrapped_dict[key] = value.value
|
|
1016
|
+
else:
|
|
1017
|
+
unwrapped_dict[key] = value
|
|
1018
|
+
return unwrapped_dict
|
|
1019
|
+
|
|
1020
|
+
if isinstance(params, (list, tuple)):
|
|
1021
|
+
unwrapped_list = []
|
|
1022
|
+
for value in params:
|
|
1023
|
+
if has_parameter_value(value):
|
|
1024
|
+
unwrapped_list.append(value.value)
|
|
1025
|
+
else:
|
|
1026
|
+
unwrapped_list.append(value)
|
|
1027
|
+
return type(params)(unwrapped_list)
|
|
1028
|
+
|
|
1029
|
+
if has_parameter_value(params):
|
|
1030
|
+
return params.value
|
|
1031
|
+
|
|
1032
|
+
return params
|
|
1033
|
+
|
|
760
1034
|
@staticmethod
|
|
761
1035
|
def _reorder_parameters(params: Any, mapping: dict[int, int]) -> Any:
|
|
762
1036
|
"""Reorder parameters based on the position mapping.
|
|
@@ -769,43 +1043,34 @@ class SQL:
|
|
|
769
1043
|
Reordered parameters in the same format as input
|
|
770
1044
|
"""
|
|
771
1045
|
if isinstance(params, (list, tuple)):
|
|
772
|
-
# Create a new list with reordered parameters
|
|
773
1046
|
reordered_list = [None] * len(params) # pyright: ignore
|
|
774
1047
|
for new_pos, old_pos in mapping.items():
|
|
775
1048
|
if old_pos < len(params):
|
|
776
1049
|
reordered_list[new_pos] = params[old_pos] # pyright: ignore
|
|
777
1050
|
|
|
778
|
-
# Handle any unmapped positions
|
|
779
1051
|
for i, val in enumerate(reordered_list):
|
|
780
1052
|
if val is None and i < len(params) and i not in mapping:
|
|
781
|
-
# If position wasn't mapped, try to use original
|
|
782
1053
|
reordered_list[i] = params[i] # pyright: ignore
|
|
783
1054
|
|
|
784
|
-
# Return in same format as input
|
|
785
1055
|
return tuple(reordered_list) if isinstance(params, tuple) else reordered_list
|
|
786
1056
|
|
|
787
|
-
if
|
|
788
|
-
|
|
789
|
-
# If keys are like param_0, param_1, we can reorder them
|
|
790
|
-
if all(key.startswith("param_") and key[6:].isdigit() for key in params):
|
|
1057
|
+
if is_dict(params):
|
|
1058
|
+
if all(key.startswith(PARAM_PREFIX) and key[len(PARAM_PREFIX) :].isdigit() for key in params):
|
|
791
1059
|
reordered_dict: dict[str, Any] = {}
|
|
792
1060
|
for new_pos, old_pos in mapping.items():
|
|
793
|
-
old_key = f"
|
|
794
|
-
new_key = f"
|
|
1061
|
+
old_key = f"{PARAM_PREFIX}{old_pos}"
|
|
1062
|
+
new_key = f"{PARAM_PREFIX}{new_pos}"
|
|
795
1063
|
if old_key in params:
|
|
796
1064
|
reordered_dict[new_key] = params[old_key]
|
|
797
1065
|
|
|
798
|
-
# Add any unmapped parameters
|
|
799
1066
|
for key, value in params.items():
|
|
800
|
-
if key not in reordered_dict and key.startswith(
|
|
1067
|
+
if key not in reordered_dict and key.startswith(PARAM_PREFIX):
|
|
801
1068
|
idx = int(key[6:])
|
|
802
1069
|
if idx not in mapping:
|
|
803
1070
|
reordered_dict[key] = value
|
|
804
1071
|
|
|
805
1072
|
return reordered_dict
|
|
806
|
-
# Can't reorder named parameters, return as-is
|
|
807
1073
|
return params
|
|
808
|
-
# Single value or unknown format, return as-is
|
|
809
1074
|
return params
|
|
810
1075
|
|
|
811
1076
|
def _convert_placeholder_style(self, sql: str, params: Any, placeholder_style: str) -> tuple[str, Any]:
|
|
@@ -819,45 +1084,119 @@ class SQL:
|
|
|
819
1084
|
Returns:
|
|
820
1085
|
Tuple of (converted_sql, converted_params)
|
|
821
1086
|
"""
|
|
822
|
-
# Handle execute_many case where params is a list of parameter sets
|
|
823
1087
|
if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
|
|
824
|
-
# For execute_many, we only need to convert the SQL once
|
|
825
|
-
# The parameters remain as a list of tuples
|
|
826
1088
|
converter = self._config.parameter_converter
|
|
827
1089
|
param_info = converter.validator.extract_parameters(sql)
|
|
828
1090
|
|
|
829
1091
|
if param_info:
|
|
830
|
-
from sqlspec.statement.parameters import ParameterStyle
|
|
831
|
-
|
|
832
1092
|
target_style = (
|
|
833
1093
|
ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
|
|
834
1094
|
)
|
|
835
1095
|
sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
|
|
836
1096
|
|
|
837
|
-
# Parameters remain as list of tuples for execute_many
|
|
838
1097
|
return sql, params
|
|
839
1098
|
|
|
840
|
-
# Extract parameter info from current SQL
|
|
841
1099
|
converter = self._config.parameter_converter
|
|
842
|
-
|
|
1100
|
+
|
|
1101
|
+
# For POSITIONAL_COLON style, use original parameter info if available to preserve numeric identifiers
|
|
1102
|
+
target_style = ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
|
|
1103
|
+
if (
|
|
1104
|
+
target_style == ParameterStyle.POSITIONAL_COLON
|
|
1105
|
+
and self._processing_context
|
|
1106
|
+
and self._processing_context.parameter_normalization
|
|
1107
|
+
and self._processing_context.parameter_normalization.original_param_info
|
|
1108
|
+
):
|
|
1109
|
+
param_info = self._processing_context.parameter_normalization.original_param_info
|
|
1110
|
+
else:
|
|
1111
|
+
param_info = converter.validator.extract_parameters(sql)
|
|
1112
|
+
|
|
1113
|
+
# CRITICAL FIX: For POSITIONAL_COLON, we need to ensure param_info reflects
|
|
1114
|
+
# all placeholders in the current SQL, not just the original ones.
|
|
1115
|
+
# This handles cases where transformers (like ParameterizeLiterals) add new placeholders.
|
|
1116
|
+
if target_style == ParameterStyle.POSITIONAL_COLON and param_info:
|
|
1117
|
+
# Re-extract from current SQL to get all placeholders
|
|
1118
|
+
current_param_info = converter.validator.extract_parameters(sql)
|
|
1119
|
+
if len(current_param_info) > len(param_info):
|
|
1120
|
+
# More placeholders in current SQL means transformers added some
|
|
1121
|
+
# Use the current info to ensure all placeholders get parameters
|
|
1122
|
+
param_info = current_param_info
|
|
843
1123
|
|
|
844
1124
|
if not param_info:
|
|
845
1125
|
return sql, params
|
|
846
1126
|
|
|
847
|
-
|
|
848
|
-
|
|
1127
|
+
if target_style == ParameterStyle.STATIC:
|
|
1128
|
+
return self._embed_static_parameters(sql, params, param_info)
|
|
849
1129
|
|
|
850
|
-
|
|
1130
|
+
if param_info and all(p.style == target_style for p in param_info):
|
|
1131
|
+
converted_params = self._convert_parameters_format(params, param_info, target_style)
|
|
1132
|
+
return sql, converted_params
|
|
851
1133
|
|
|
852
|
-
# Replace placeholders in SQL
|
|
853
1134
|
sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
|
|
854
1135
|
|
|
855
|
-
# Convert parameters to appropriate format
|
|
856
1136
|
params = self._convert_parameters_format(params, param_info, target_style)
|
|
857
1137
|
|
|
858
1138
|
return sql, params
|
|
859
1139
|
|
|
860
|
-
def
|
|
1140
|
+
def _embed_static_parameters(self, sql: str, params: Any, param_info: list[Any]) -> tuple[str, Any]:
|
|
1141
|
+
"""Embed parameter values directly into SQL for STATIC style.
|
|
1142
|
+
|
|
1143
|
+
This is used for scripts and other cases where parameters need to be
|
|
1144
|
+
embedded directly in the SQL string rather than passed separately.
|
|
1145
|
+
|
|
1146
|
+
Args:
|
|
1147
|
+
sql: The SQL string with placeholders
|
|
1148
|
+
params: The parameter values
|
|
1149
|
+
param_info: List of parameter information from extraction
|
|
1150
|
+
|
|
1151
|
+
Returns:
|
|
1152
|
+
Tuple of (sql_with_embedded_values, None)
|
|
1153
|
+
"""
|
|
1154
|
+
param_list: list[Any] = []
|
|
1155
|
+
if is_dict(params):
|
|
1156
|
+
for p in param_info:
|
|
1157
|
+
if p.name and p.name in params:
|
|
1158
|
+
param_list.append(params[p.name])
|
|
1159
|
+
elif f"{PARAM_PREFIX}{p.ordinal}" in params:
|
|
1160
|
+
param_list.append(params[f"{PARAM_PREFIX}{p.ordinal}"])
|
|
1161
|
+
elif f"arg_{p.ordinal}" in params:
|
|
1162
|
+
param_list.append(params[f"arg_{p.ordinal}"])
|
|
1163
|
+
else:
|
|
1164
|
+
param_list.append(params.get(str(p.ordinal), None))
|
|
1165
|
+
elif isinstance(params, (list, tuple)):
|
|
1166
|
+
param_list = list(params)
|
|
1167
|
+
elif params is not None:
|
|
1168
|
+
param_list = [params]
|
|
1169
|
+
|
|
1170
|
+
sorted_params = sorted(param_info, key=lambda p: p.position, reverse=True)
|
|
1171
|
+
|
|
1172
|
+
for p in sorted_params:
|
|
1173
|
+
if p.ordinal < len(param_list):
|
|
1174
|
+
value = param_list[p.ordinal]
|
|
1175
|
+
|
|
1176
|
+
if has_parameter_value(value):
|
|
1177
|
+
value = value.value
|
|
1178
|
+
|
|
1179
|
+
if value is None:
|
|
1180
|
+
literal_str = "NULL"
|
|
1181
|
+
elif isinstance(value, bool):
|
|
1182
|
+
literal_str = "TRUE" if value else "FALSE"
|
|
1183
|
+
elif isinstance(value, str):
|
|
1184
|
+
literal_expr = sqlglot.exp.Literal.string(value)
|
|
1185
|
+
literal_str = literal_expr.sql(dialect=self._dialect)
|
|
1186
|
+
elif isinstance(value, (int, float)):
|
|
1187
|
+
literal_expr = sqlglot.exp.Literal.number(value)
|
|
1188
|
+
literal_str = literal_expr.sql(dialect=self._dialect)
|
|
1189
|
+
else:
|
|
1190
|
+
literal_expr = sqlglot.exp.Literal.string(str(value))
|
|
1191
|
+
literal_str = literal_expr.sql(dialect=self._dialect)
|
|
1192
|
+
|
|
1193
|
+
start = p.position
|
|
1194
|
+
end = start + len(p.placeholder_text)
|
|
1195
|
+
sql = sql[:start] + literal_str + sql[end:]
|
|
1196
|
+
|
|
1197
|
+
return sql, None
|
|
1198
|
+
|
|
1199
|
+
def _replace_placeholders_in_sql(self, sql: str, param_info: list[Any], target_style: ParameterStyle) -> str:
|
|
861
1200
|
"""Replace placeholders in SQL string with target style placeholders.
|
|
862
1201
|
|
|
863
1202
|
Args:
|
|
@@ -868,12 +1207,10 @@ class SQL:
|
|
|
868
1207
|
Returns:
|
|
869
1208
|
SQL string with replaced placeholders
|
|
870
1209
|
"""
|
|
871
|
-
# Sort by position in reverse to avoid position shifts
|
|
872
1210
|
sorted_params = sorted(param_info, key=lambda p: p.position, reverse=True)
|
|
873
1211
|
|
|
874
1212
|
for p in sorted_params:
|
|
875
1213
|
new_placeholder = self._generate_placeholder(p, target_style)
|
|
876
|
-
# Replace the placeholder in SQL
|
|
877
1214
|
start = p.position
|
|
878
1215
|
end = start + len(p.placeholder_text)
|
|
879
1216
|
sql = sql[:start] + new_placeholder + sql[end:]
|
|
@@ -881,7 +1218,7 @@ class SQL:
|
|
|
881
1218
|
return sql
|
|
882
1219
|
|
|
883
1220
|
@staticmethod
|
|
884
|
-
def _generate_placeholder(param: Any, target_style:
|
|
1221
|
+
def _generate_placeholder(param: Any, target_style: ParameterStyle) -> str:
|
|
885
1222
|
"""Generate a placeholder string for the given parameter style.
|
|
886
1223
|
|
|
887
1224
|
Args:
|
|
@@ -891,36 +1228,34 @@ class SQL:
|
|
|
891
1228
|
Returns:
|
|
892
1229
|
Placeholder string
|
|
893
1230
|
"""
|
|
894
|
-
if target_style
|
|
1231
|
+
if target_style in {ParameterStyle.STATIC, ParameterStyle.QMARK}:
|
|
895
1232
|
return "?"
|
|
896
1233
|
if target_style == ParameterStyle.NUMERIC:
|
|
897
|
-
# Use 1-based numbering for numeric style
|
|
898
1234
|
return f"${param.ordinal + 1}"
|
|
899
1235
|
if target_style == ParameterStyle.NAMED_COLON:
|
|
900
|
-
# Use original name if available, otherwise generate one
|
|
901
|
-
# Oracle doesn't like underscores at the start of parameter names
|
|
902
1236
|
if param.name and not param.name.isdigit():
|
|
903
|
-
# Use the name if it's not just a number
|
|
904
1237
|
return f":{param.name}"
|
|
905
|
-
# Generate a new name for numeric placeholders or missing names
|
|
906
1238
|
return f":arg_{param.ordinal}"
|
|
907
1239
|
if target_style == ParameterStyle.NAMED_AT:
|
|
908
|
-
# Use @ prefix for BigQuery style
|
|
909
|
-
# BigQuery requires parameter names to start with a letter, not underscore
|
|
910
1240
|
return f"@{param.name or f'param_{param.ordinal}'}"
|
|
911
1241
|
if target_style == ParameterStyle.POSITIONAL_COLON:
|
|
912
|
-
#
|
|
1242
|
+
# For Oracle positional colon, preserve the original numeric identifier if it was already :N style
|
|
1243
|
+
if (
|
|
1244
|
+
hasattr(param, "style")
|
|
1245
|
+
and param.style == ParameterStyle.POSITIONAL_COLON
|
|
1246
|
+
and hasattr(param, "name")
|
|
1247
|
+
and param.name
|
|
1248
|
+
and param.name.isdigit()
|
|
1249
|
+
):
|
|
1250
|
+
return f":{param.name}"
|
|
913
1251
|
return f":{param.ordinal + 1}"
|
|
914
1252
|
if target_style == ParameterStyle.POSITIONAL_PYFORMAT:
|
|
915
|
-
# Use %s for positional pyformat
|
|
916
1253
|
return "%s"
|
|
917
1254
|
if target_style == ParameterStyle.NAMED_PYFORMAT:
|
|
918
|
-
|
|
919
|
-
return f"%({param.name or f'_arg_{param.ordinal}'})s"
|
|
920
|
-
# Keep original for unknown styles
|
|
1255
|
+
return f"%({param.name or f'arg_{param.ordinal}'})s"
|
|
921
1256
|
return str(param.placeholder_text)
|
|
922
1257
|
|
|
923
|
-
def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style:
|
|
1258
|
+
def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style: ParameterStyle) -> Any:
|
|
924
1259
|
"""Convert parameters to the appropriate format for the target style.
|
|
925
1260
|
|
|
926
1261
|
Args:
|
|
@@ -941,10 +1276,96 @@ class SQL:
|
|
|
941
1276
|
return self._convert_to_named_pyformat_format(params, param_info)
|
|
942
1277
|
return params
|
|
943
1278
|
|
|
1279
|
+
def _convert_list_to_colon_dict(
|
|
1280
|
+
self, params: "Union[list[Any], tuple[Any, ...]]", param_info: "list[Any]"
|
|
1281
|
+
) -> "dict[str, Any]":
|
|
1282
|
+
"""Convert list/tuple parameters to colon-style dict format."""
|
|
1283
|
+
result_dict: dict[str, Any] = {}
|
|
1284
|
+
|
|
1285
|
+
if param_info:
|
|
1286
|
+
all_numeric = all(p.name and p.name.isdigit() for p in param_info)
|
|
1287
|
+
if all_numeric:
|
|
1288
|
+
for i, value in enumerate(params):
|
|
1289
|
+
result_dict[str(i + 1)] = value
|
|
1290
|
+
else:
|
|
1291
|
+
for i, value in enumerate(params):
|
|
1292
|
+
if i < len(param_info):
|
|
1293
|
+
param_name = param_info[i].name or str(i + 1)
|
|
1294
|
+
result_dict[param_name] = value
|
|
1295
|
+
else:
|
|
1296
|
+
result_dict[str(i + 1)] = value
|
|
1297
|
+
else:
|
|
1298
|
+
for i, value in enumerate(params):
|
|
1299
|
+
result_dict[str(i + 1)] = value
|
|
1300
|
+
|
|
1301
|
+
return result_dict
|
|
1302
|
+
|
|
1303
|
+
def _convert_single_value_to_colon_dict(self, params: Any, param_info: "list[Any]") -> "dict[str, Any]":
|
|
1304
|
+
"""Convert single value parameter to colon-style dict format."""
|
|
1305
|
+
result_dict: dict[str, Any] = {}
|
|
1306
|
+
if param_info and param_info[0].name and param_info[0].name.isdigit():
|
|
1307
|
+
result_dict[param_info[0].name] = params
|
|
1308
|
+
else:
|
|
1309
|
+
result_dict["1"] = params
|
|
1310
|
+
return result_dict
|
|
1311
|
+
|
|
1312
|
+
def _process_mixed_colon_params(self, params: "dict[str, Any]", param_info: "list[Any]") -> "dict[str, Any]":
|
|
1313
|
+
"""Process mixed colon-style numeric and normalized parameters."""
|
|
1314
|
+
result_dict: dict[str, Any] = {}
|
|
1315
|
+
|
|
1316
|
+
# When we have mixed parameters (extracted literals + user oracle params),
|
|
1317
|
+
# we need to be careful about the ordering. The extracted literals should
|
|
1318
|
+
# fill positions based on where they appear in the SQL, not based on
|
|
1319
|
+
# matching parameter names.
|
|
1320
|
+
|
|
1321
|
+
# Separate extracted parameters and user oracle parameters
|
|
1322
|
+
extracted_params = []
|
|
1323
|
+
user_oracle_params = {}
|
|
1324
|
+
extracted_keys_sorted = []
|
|
1325
|
+
|
|
1326
|
+
for key, value in params.items():
|
|
1327
|
+
if has_parameter_value(value):
|
|
1328
|
+
extracted_params.append((key, value))
|
|
1329
|
+
elif key.isdigit():
|
|
1330
|
+
user_oracle_params[key] = value
|
|
1331
|
+
elif key.startswith("param_") and key[6:].isdigit():
|
|
1332
|
+
param_idx = int(key[6:])
|
|
1333
|
+
oracle_key = str(param_idx + 1)
|
|
1334
|
+
if oracle_key not in user_oracle_params:
|
|
1335
|
+
extracted_keys_sorted.append((param_idx, key, value))
|
|
1336
|
+
else:
|
|
1337
|
+
extracted_params.append((key, value))
|
|
1338
|
+
|
|
1339
|
+
extracted_keys_sorted.sort(key=operator.itemgetter(0))
|
|
1340
|
+
for _, key, value in extracted_keys_sorted:
|
|
1341
|
+
extracted_params.append((key, value))
|
|
1342
|
+
|
|
1343
|
+
# Build lists of parameter values in order
|
|
1344
|
+
extracted_values = []
|
|
1345
|
+
for _, value in extracted_params:
|
|
1346
|
+
if has_parameter_value(value):
|
|
1347
|
+
extracted_values.append(value.value)
|
|
1348
|
+
else:
|
|
1349
|
+
extracted_values.append(value)
|
|
1350
|
+
|
|
1351
|
+
user_values = [user_oracle_params[key] for key in sorted(user_oracle_params.keys(), key=int)]
|
|
1352
|
+
|
|
1353
|
+
# Now assign parameters based on position
|
|
1354
|
+
# Extracted parameters go first (they were literals in original positions)
|
|
1355
|
+
# User parameters follow
|
|
1356
|
+
all_values = extracted_values + user_values
|
|
1357
|
+
|
|
1358
|
+
for i, p in enumerate(sorted(param_info, key=lambda x: x.ordinal)):
|
|
1359
|
+
oracle_key = str(p.ordinal + 1)
|
|
1360
|
+
if i < len(all_values):
|
|
1361
|
+
result_dict[oracle_key] = all_values[i]
|
|
1362
|
+
|
|
1363
|
+
return result_dict
|
|
1364
|
+
|
|
944
1365
|
def _convert_to_positional_colon_format(self, params: Any, param_info: list[Any]) -> Any:
|
|
945
|
-
"""Convert to dict format for
|
|
1366
|
+
"""Convert to dict format for positional colon style.
|
|
946
1367
|
|
|
947
|
-
|
|
1368
|
+
Positional colon style uses :1, :2, etc. placeholders and expects
|
|
948
1369
|
parameters as a dict with string keys "1", "2", etc.
|
|
949
1370
|
|
|
950
1371
|
For execute_many operations, returns a list of parameter sets.
|
|
@@ -956,68 +1377,76 @@ class SQL:
|
|
|
956
1377
|
Returns:
|
|
957
1378
|
Dict of parameters with string keys "1", "2", etc., or list for execute_many
|
|
958
1379
|
"""
|
|
959
|
-
# Special handling for execute_many
|
|
960
1380
|
if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
|
|
961
|
-
# This is execute_many - keep as list but process each item
|
|
962
1381
|
return params
|
|
963
1382
|
|
|
964
|
-
result_dict: dict[str, Any] = {}
|
|
965
|
-
|
|
966
1383
|
if isinstance(params, (list, tuple)):
|
|
967
|
-
|
|
968
|
-
if param_info:
|
|
969
|
-
# Check if all param names are numeric (positional colon style)
|
|
970
|
-
all_numeric = all(p.name and p.name.isdigit() for p in param_info)
|
|
971
|
-
if all_numeric:
|
|
972
|
-
# Sort param_info by numeric name to match list order
|
|
973
|
-
sorted_params = sorted(param_info, key=lambda p: int(p.name))
|
|
974
|
-
for i, value in enumerate(params):
|
|
975
|
-
if i < len(sorted_params):
|
|
976
|
-
# Map based on numeric order, not SQL appearance order
|
|
977
|
-
param_name = sorted_params[i].name
|
|
978
|
-
result_dict[param_name] = value
|
|
979
|
-
else:
|
|
980
|
-
# Extra parameters
|
|
981
|
-
result_dict[str(i + 1)] = value
|
|
982
|
-
else:
|
|
983
|
-
# Non-numeric names, map by ordinal
|
|
984
|
-
for i, value in enumerate(params):
|
|
985
|
-
if i < len(param_info):
|
|
986
|
-
param_name = param_info[i].name or str(i + 1)
|
|
987
|
-
result_dict[param_name] = value
|
|
988
|
-
else:
|
|
989
|
-
result_dict[str(i + 1)] = value
|
|
990
|
-
else:
|
|
991
|
-
# No param_info, default to 1-based indexing
|
|
992
|
-
for i, value in enumerate(params):
|
|
993
|
-
result_dict[str(i + 1)] = value
|
|
994
|
-
return result_dict
|
|
1384
|
+
return self._convert_list_to_colon_dict(params, param_info)
|
|
995
1385
|
|
|
996
1386
|
if not is_dict(params) and param_info:
|
|
997
|
-
|
|
998
|
-
if param_info and param_info[0].name and param_info[0].name.isdigit():
|
|
999
|
-
# Use the actual parameter name from SQL (e.g., "0")
|
|
1000
|
-
result_dict[param_info[0].name] = params
|
|
1001
|
-
else:
|
|
1002
|
-
# Default to "1"
|
|
1003
|
-
result_dict["1"] = params
|
|
1004
|
-
return result_dict
|
|
1387
|
+
return self._convert_single_value_to_colon_dict(params, param_info)
|
|
1005
1388
|
|
|
1006
1389
|
if is_dict(params):
|
|
1007
|
-
# Check if already in correct format (keys are "1", "2", etc.)
|
|
1008
1390
|
if all(key.isdigit() for key in params):
|
|
1009
1391
|
return params
|
|
1010
1392
|
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1393
|
+
if all(key.startswith("param_") for key in params):
|
|
1394
|
+
param_result_dict: dict[str, Any] = {}
|
|
1395
|
+
for p in sorted(param_info, key=lambda x: x.ordinal):
|
|
1396
|
+
# Use the parameter's ordinal to find the normalized key
|
|
1397
|
+
normalized_key = f"param_{p.ordinal}"
|
|
1398
|
+
if normalized_key in params:
|
|
1399
|
+
if p.name and p.name.isdigit():
|
|
1400
|
+
# For Oracle numeric parameters, preserve the original number
|
|
1401
|
+
param_result_dict[p.name] = params[normalized_key]
|
|
1402
|
+
else:
|
|
1403
|
+
# For other cases, use sequential numbering
|
|
1404
|
+
param_result_dict[str(p.ordinal + 1)] = params[normalized_key]
|
|
1405
|
+
return param_result_dict
|
|
1406
|
+
|
|
1407
|
+
has_oracle_numeric = any(key.isdigit() for key in params)
|
|
1408
|
+
has_param_normalized = any(key.startswith("param_") for key in params)
|
|
1409
|
+
has_typed_params = any(has_parameter_value(v) for v in params.values())
|
|
1410
|
+
|
|
1411
|
+
if (has_oracle_numeric and has_param_normalized) or has_typed_params:
|
|
1412
|
+
return self._process_mixed_colon_params(params, param_info)
|
|
1413
|
+
|
|
1414
|
+
result_dict: dict[str, Any] = {}
|
|
1415
|
+
|
|
1416
|
+
if param_info:
|
|
1417
|
+
# Process all parameters in order of their ordinals
|
|
1418
|
+
for p in sorted(param_info, key=lambda x: x.ordinal):
|
|
1419
|
+
oracle_key = str(p.ordinal + 1)
|
|
1420
|
+
value = None
|
|
1421
|
+
|
|
1422
|
+
# Try different ways to find the parameter value
|
|
1423
|
+
if p.name and (
|
|
1424
|
+
p.name in params
|
|
1425
|
+
or (p.name.isdigit() and p.name in params)
|
|
1426
|
+
or (p.name.startswith("param_") and p.name in params)
|
|
1427
|
+
):
|
|
1428
|
+
value = params[p.name]
|
|
1429
|
+
|
|
1430
|
+
# If not found by name, try by ordinal-based keys
|
|
1431
|
+
if value is None:
|
|
1432
|
+
# Try param_N format (common for pipeline parameters)
|
|
1433
|
+
param_key = f"param_{p.ordinal}"
|
|
1434
|
+
if param_key in params:
|
|
1435
|
+
value = params[param_key]
|
|
1436
|
+
# Try arg_N format
|
|
1437
|
+
elif f"arg_{p.ordinal}" in params:
|
|
1438
|
+
value = params[f"arg_{p.ordinal}"]
|
|
1439
|
+
# For positional colon, also check if there's a numeric key
|
|
1440
|
+
# that matches the ordinal position
|
|
1441
|
+
elif str(p.ordinal + 1) in params:
|
|
1442
|
+
value = params[str(p.ordinal + 1)]
|
|
1443
|
+
|
|
1444
|
+
# Unwrap TypedParameter if needed
|
|
1445
|
+
if value is not None:
|
|
1446
|
+
if has_parameter_value(value):
|
|
1447
|
+
value = value.value
|
|
1448
|
+
result_dict[oracle_key] = value
|
|
1449
|
+
|
|
1021
1450
|
return result_dict
|
|
1022
1451
|
|
|
1023
1452
|
return params
|
|
@@ -1035,33 +1464,79 @@ class SQL:
|
|
|
1035
1464
|
"""
|
|
1036
1465
|
result_list: list[Any] = []
|
|
1037
1466
|
if is_dict(params):
|
|
1467
|
+
param_values_by_ordinal: dict[int, Any] = {}
|
|
1468
|
+
|
|
1038
1469
|
for p in param_info:
|
|
1039
1470
|
if p.name and p.name in params:
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1471
|
+
param_values_by_ordinal[p.ordinal] = params[p.name]
|
|
1472
|
+
|
|
1473
|
+
for p in param_info:
|
|
1474
|
+
if p.name is None and p.ordinal not in param_values_by_ordinal:
|
|
1475
|
+
arg_key = f"arg_{p.ordinal}"
|
|
1476
|
+
param_key = f"param_{p.ordinal}"
|
|
1477
|
+
if arg_key in params:
|
|
1478
|
+
param_values_by_ordinal[p.ordinal] = params[arg_key]
|
|
1479
|
+
elif param_key in params:
|
|
1480
|
+
param_values_by_ordinal[p.ordinal] = params[param_key]
|
|
1481
|
+
|
|
1482
|
+
remaining_params = {
|
|
1483
|
+
k: v
|
|
1484
|
+
for k, v in params.items()
|
|
1485
|
+
if k not in {p.name for p in param_info if p.name} and not k.startswith(("arg_", "param_"))
|
|
1486
|
+
}
|
|
1487
|
+
|
|
1488
|
+
unmatched_ordinals = [p.ordinal for p in param_info if p.ordinal not in param_values_by_ordinal]
|
|
1489
|
+
|
|
1490
|
+
for ordinal, (_, value) in zip(unmatched_ordinals, remaining_params.items()):
|
|
1491
|
+
param_values_by_ordinal[ordinal] = value
|
|
1492
|
+
|
|
1493
|
+
for p in param_info:
|
|
1494
|
+
val = param_values_by_ordinal.get(p.ordinal)
|
|
1495
|
+
if val is not None:
|
|
1496
|
+
if has_parameter_value(val):
|
|
1043
1497
|
result_list.append(val.value)
|
|
1044
1498
|
else:
|
|
1045
1499
|
result_list.append(val)
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1500
|
+
else:
|
|
1501
|
+
result_list.append(None)
|
|
1502
|
+
|
|
1503
|
+
return result_list
|
|
1504
|
+
if isinstance(params, (list, tuple)):
|
|
1505
|
+
# Special case: if params is empty, preserve it (don't create None values)
|
|
1506
|
+
# This is important for execute_many with empty parameter lists
|
|
1507
|
+
if not params:
|
|
1508
|
+
return params
|
|
1509
|
+
|
|
1510
|
+
# Handle mixed parameter styles correctly
|
|
1511
|
+
# For mixed styles, assign parameters in order of appearance, not by numeric reference
|
|
1512
|
+
if param_info and any(p.style == ParameterStyle.NUMERIC for p in param_info):
|
|
1513
|
+
# Create mapping from ordinal to parameter value
|
|
1514
|
+
param_mapping: dict[int, Any] = {}
|
|
1515
|
+
|
|
1516
|
+
# Sort parameter info by position to get order of appearance
|
|
1517
|
+
sorted_params = sorted(param_info, key=lambda p: p.position)
|
|
1518
|
+
|
|
1519
|
+
# Assign parameters sequentially in order of appearance
|
|
1520
|
+
for i, param_info_item in enumerate(sorted_params):
|
|
1521
|
+
if i < len(params):
|
|
1522
|
+
param_mapping[param_info_item.ordinal] = params[i]
|
|
1523
|
+
|
|
1524
|
+
# Build result list ordered by original ordinal values
|
|
1525
|
+
for i in range(len(param_info)):
|
|
1526
|
+
val = param_mapping.get(i)
|
|
1527
|
+
if val is not None:
|
|
1528
|
+
if has_parameter_value(val):
|
|
1053
1529
|
result_list.append(val.value)
|
|
1054
1530
|
else:
|
|
1055
1531
|
result_list.append(val)
|
|
1056
1532
|
else:
|
|
1057
1533
|
result_list.append(None)
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
if isinstance(params, (list, tuple)):
|
|
1534
|
+
|
|
1535
|
+
return result_list
|
|
1536
|
+
|
|
1537
|
+
# Standard conversion for non-mixed styles
|
|
1063
1538
|
for param in params:
|
|
1064
|
-
if
|
|
1539
|
+
if has_parameter_value(param):
|
|
1065
1540
|
result_list.append(param.value)
|
|
1066
1541
|
else:
|
|
1067
1542
|
result_list.append(param)
|
|
@@ -1081,28 +1556,26 @@ class SQL:
|
|
|
1081
1556
|
"""
|
|
1082
1557
|
result_dict: dict[str, Any] = {}
|
|
1083
1558
|
if is_dict(params):
|
|
1084
|
-
# For dict params with matching parameter names, return as-is
|
|
1085
|
-
# Otherwise, remap to match the expected names
|
|
1086
1559
|
if all(p.name in params for p in param_info if p.name):
|
|
1087
1560
|
return params
|
|
1088
1561
|
for p in param_info:
|
|
1089
1562
|
if p.name and p.name in params:
|
|
1090
1563
|
result_dict[p.name] = params[p.name]
|
|
1091
1564
|
elif f"param_{p.ordinal}" in params:
|
|
1092
|
-
# Handle param_N style names
|
|
1093
|
-
# Oracle doesn't like underscores at the start of parameter names
|
|
1094
1565
|
result_dict[p.name or f"arg_{p.ordinal}"] = params[f"param_{p.ordinal}"]
|
|
1095
1566
|
return result_dict
|
|
1096
1567
|
if isinstance(params, (list, tuple)):
|
|
1097
|
-
# Convert list/tuple to dict with parameter names from param_info
|
|
1098
|
-
|
|
1099
1568
|
for i, value in enumerate(params):
|
|
1569
|
+
if has_parameter_value(value):
|
|
1570
|
+
value = value.value
|
|
1571
|
+
|
|
1100
1572
|
if i < len(param_info):
|
|
1101
1573
|
p = param_info[i]
|
|
1102
|
-
# Use the actual parameter name if available
|
|
1103
|
-
# Oracle doesn't like underscores at the start of parameter names
|
|
1104
1574
|
param_name = p.name or f"arg_{i}"
|
|
1105
1575
|
result_dict[param_name] = value
|
|
1576
|
+
else:
|
|
1577
|
+
param_name = f"arg_{i}"
|
|
1578
|
+
result_dict[param_name] = value
|
|
1106
1579
|
return result_dict
|
|
1107
1580
|
return params
|
|
1108
1581
|
|
|
@@ -1118,7 +1591,6 @@ class SQL:
|
|
|
1118
1591
|
Dict of parameters with names
|
|
1119
1592
|
"""
|
|
1120
1593
|
if isinstance(params, (list, tuple)):
|
|
1121
|
-
# Convert list to dict with generated names
|
|
1122
1594
|
result_dict: dict[str, Any] = {}
|
|
1123
1595
|
for i, p in enumerate(param_info):
|
|
1124
1596
|
if i < len(params):
|
|
@@ -1127,14 +1599,15 @@ class SQL:
|
|
|
1127
1599
|
return result_dict
|
|
1128
1600
|
return params
|
|
1129
1601
|
|
|
1130
|
-
# Validation properties for compatibility
|
|
1131
1602
|
@property
|
|
1132
1603
|
def validation_errors(self) -> list[Any]:
|
|
1133
1604
|
"""Get validation errors."""
|
|
1134
1605
|
if not self._config.enable_validation:
|
|
1135
1606
|
return []
|
|
1136
1607
|
self._ensure_processed()
|
|
1137
|
-
|
|
1608
|
+
if not self._processed_state:
|
|
1609
|
+
msg = "Failed to process SQL statement"
|
|
1610
|
+
raise RuntimeError(msg)
|
|
1138
1611
|
return self._processed_state.validation_errors
|
|
1139
1612
|
|
|
1140
1613
|
@property
|
|
@@ -1147,25 +1620,30 @@ class SQL:
|
|
|
1147
1620
|
"""Check if statement is safe."""
|
|
1148
1621
|
return not self.has_errors
|
|
1149
1622
|
|
|
1150
|
-
# Additional compatibility methods
|
|
1151
1623
|
def validate(self) -> list[Any]:
|
|
1152
1624
|
"""Validate the SQL statement and return validation errors."""
|
|
1153
1625
|
return self.validation_errors
|
|
1154
1626
|
|
|
1155
1627
|
@property
|
|
1156
1628
|
def parameter_info(self) -> list[Any]:
|
|
1157
|
-
"""Get parameter information from the SQL statement.
|
|
1629
|
+
"""Get parameter information from the SQL statement.
|
|
1630
|
+
|
|
1631
|
+
Returns the original parameter info before any normalization.
|
|
1632
|
+
"""
|
|
1158
1633
|
validator = self._config.parameter_validator
|
|
1159
|
-
if self.
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1634
|
+
if self._raw_sql:
|
|
1635
|
+
return validator.extract_parameters(self._raw_sql)
|
|
1636
|
+
|
|
1637
|
+
self._ensure_processed()
|
|
1638
|
+
|
|
1639
|
+
if self._processing_context:
|
|
1640
|
+
return self._processing_context.parameter_info
|
|
1641
|
+
|
|
1642
|
+
return []
|
|
1164
1643
|
|
|
1165
1644
|
@property
|
|
1166
1645
|
def _raw_parameters(self) -> Any:
|
|
1167
1646
|
"""Get raw parameters for compatibility."""
|
|
1168
|
-
# Return the original parameters as passed in
|
|
1169
1647
|
return self._original_parameters
|
|
1170
1648
|
|
|
1171
1649
|
@property
|
|
@@ -1174,7 +1652,7 @@ class SQL:
|
|
|
1174
1652
|
return self.sql
|
|
1175
1653
|
|
|
1176
1654
|
@property
|
|
1177
|
-
def _expression(self) -> Optional[exp.Expression]:
|
|
1655
|
+
def _expression(self) -> "Optional[exp.Expression]":
|
|
1178
1656
|
"""Get expression for compatibility."""
|
|
1179
1657
|
return self.expression
|
|
1180
1658
|
|
|
@@ -1186,18 +1664,15 @@ class SQL:
|
|
|
1186
1664
|
def limit(self, count: int, use_parameter: bool = False) -> "SQL":
|
|
1187
1665
|
"""Add LIMIT clause."""
|
|
1188
1666
|
if use_parameter:
|
|
1189
|
-
# Create a unique parameter name
|
|
1190
1667
|
param_name = self.get_unique_parameter_name("limit")
|
|
1191
|
-
# Add parameter to the SQL object
|
|
1192
1668
|
result = self
|
|
1193
1669
|
result = result.add_named_parameter(param_name, count)
|
|
1194
|
-
|
|
1195
|
-
if hasattr(result._statement, "limit"):
|
|
1670
|
+
if supports_limit(result._statement):
|
|
1196
1671
|
new_statement = result._statement.limit(exp.Placeholder(this=param_name)) # pyright: ignore
|
|
1197
1672
|
else:
|
|
1198
1673
|
new_statement = exp.Select().from_(result._statement).limit(exp.Placeholder(this=param_name)) # pyright: ignore
|
|
1199
1674
|
return result.copy(statement=new_statement)
|
|
1200
|
-
if
|
|
1675
|
+
if supports_limit(self._statement):
|
|
1201
1676
|
new_statement = self._statement.limit(count) # pyright: ignore
|
|
1202
1677
|
else:
|
|
1203
1678
|
new_statement = exp.Select().from_(self._statement).limit(count) # pyright: ignore
|
|
@@ -1206,18 +1681,15 @@ class SQL:
|
|
|
1206
1681
|
def offset(self, count: int, use_parameter: bool = False) -> "SQL":
|
|
1207
1682
|
"""Add OFFSET clause."""
|
|
1208
1683
|
if use_parameter:
|
|
1209
|
-
# Create a unique parameter name
|
|
1210
1684
|
param_name = self.get_unique_parameter_name("offset")
|
|
1211
|
-
# Add parameter to the SQL object
|
|
1212
1685
|
result = self
|
|
1213
1686
|
result = result.add_named_parameter(param_name, count)
|
|
1214
|
-
|
|
1215
|
-
if hasattr(result._statement, "offset"):
|
|
1687
|
+
if supports_offset(result._statement):
|
|
1216
1688
|
new_statement = result._statement.offset(exp.Placeholder(this=param_name)) # pyright: ignore
|
|
1217
1689
|
else:
|
|
1218
1690
|
new_statement = exp.Select().from_(result._statement).offset(exp.Placeholder(this=param_name)) # pyright: ignore
|
|
1219
1691
|
return result.copy(statement=new_statement)
|
|
1220
|
-
if
|
|
1692
|
+
if supports_offset(self._statement):
|
|
1221
1693
|
new_statement = self._statement.offset(count) # pyright: ignore
|
|
1222
1694
|
else:
|
|
1223
1695
|
new_statement = exp.Select().from_(self._statement).offset(count) # pyright: ignore
|
|
@@ -1225,7 +1697,7 @@ class SQL:
|
|
|
1225
1697
|
|
|
1226
1698
|
def order_by(self, expression: exp.Expression) -> "SQL":
|
|
1227
1699
|
"""Add ORDER BY clause."""
|
|
1228
|
-
if
|
|
1700
|
+
if supports_order_by(self._statement):
|
|
1229
1701
|
new_statement = self._statement.order_by(expression) # pyright: ignore
|
|
1230
1702
|
else:
|
|
1231
1703
|
new_statement = exp.Select().from_(self._statement).order_by(expression) # pyright: ignore
|