sqlspec 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +50 -25
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +1 -3
- sqlspec/_serialization.py +1 -2
- sqlspec/_sql.py +256 -120
- sqlspec/_typing.py +278 -142
- sqlspec/adapters/adbc/__init__.py +4 -3
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +115 -248
- sqlspec/adapters/adbc/driver.py +462 -353
- sqlspec/adapters/aiosqlite/__init__.py +18 -3
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +199 -129
- sqlspec/adapters/aiosqlite/driver.py +230 -269
- sqlspec/adapters/asyncmy/__init__.py +18 -3
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +80 -168
- sqlspec/adapters/asyncmy/driver.py +260 -225
- sqlspec/adapters/asyncpg/__init__.py +19 -4
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +82 -181
- sqlspec/adapters/asyncpg/driver.py +285 -383
- sqlspec/adapters/bigquery/__init__.py +17 -3
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +191 -258
- sqlspec/adapters/bigquery/driver.py +474 -646
- sqlspec/adapters/duckdb/__init__.py +14 -3
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +415 -351
- sqlspec/adapters/duckdb/driver.py +343 -413
- sqlspec/adapters/oracledb/__init__.py +19 -5
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +123 -379
- sqlspec/adapters/oracledb/driver.py +507 -560
- sqlspec/adapters/psqlpy/__init__.py +13 -3
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +93 -254
- sqlspec/adapters/psqlpy/driver.py +505 -234
- sqlspec/adapters/psycopg/__init__.py +19 -5
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +143 -403
- sqlspec/adapters/psycopg/driver.py +706 -872
- sqlspec/adapters/sqlite/__init__.py +14 -3
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +202 -118
- sqlspec/adapters/sqlite/driver.py +264 -303
- sqlspec/base.py +105 -9
- sqlspec/{statement/builder → builder}/__init__.py +12 -14
- sqlspec/{statement/builder → builder}/_base.py +120 -55
- sqlspec/{statement/builder → builder}/_column.py +17 -6
- sqlspec/{statement/builder → builder}/_ddl.py +46 -79
- sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
- sqlspec/{statement/builder → builder}/_delete.py +6 -25
- sqlspec/{statement/builder → builder}/_insert.py +6 -64
- sqlspec/builder/_merge.py +56 -0
- sqlspec/{statement/builder → builder}/_parsing_utils.py +3 -10
- sqlspec/{statement/builder → builder}/_select.py +11 -56
- sqlspec/{statement/builder → builder}/_update.py +12 -18
- sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
- sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
- sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +22 -16
- sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
- sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +3 -5
- sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
- sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
- sqlspec/{statement/builder → builder}/mixins/_select_operations.py +21 -36
- sqlspec/{statement/builder → builder}/mixins/_update_operations.py +3 -14
- sqlspec/{statement/builder → builder}/mixins/_where_clause.py +52 -79
- sqlspec/cli.py +4 -5
- sqlspec/config.py +180 -133
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.py +873 -0
- sqlspec/core/compiler.py +396 -0
- sqlspec/core/filters.py +828 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.py +1209 -0
- sqlspec/core/result.py +664 -0
- sqlspec/{statement → core}/splitter.py +321 -191
- sqlspec/core/statement.py +651 -0
- sqlspec/driver/__init__.py +7 -10
- sqlspec/driver/_async.py +387 -176
- sqlspec/driver/_common.py +527 -289
- sqlspec/driver/_sync.py +390 -172
- sqlspec/driver/mixins/__init__.py +2 -19
- sqlspec/driver/mixins/_result_tools.py +168 -0
- sqlspec/driver/mixins/_sql_translator.py +6 -3
- sqlspec/exceptions.py +5 -252
- sqlspec/extensions/aiosql/adapter.py +93 -96
- sqlspec/extensions/litestar/config.py +0 -1
- sqlspec/extensions/litestar/handlers.py +15 -26
- sqlspec/extensions/litestar/plugin.py +16 -14
- sqlspec/extensions/litestar/providers.py +17 -52
- sqlspec/loader.py +424 -105
- sqlspec/migrations/__init__.py +12 -0
- sqlspec/migrations/base.py +92 -68
- sqlspec/migrations/commands.py +24 -106
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +49 -51
- sqlspec/migrations/tracker.py +31 -44
- sqlspec/migrations/utils.py +64 -24
- sqlspec/protocols.py +7 -183
- sqlspec/storage/__init__.py +1 -1
- sqlspec/storage/backends/base.py +37 -40
- sqlspec/storage/backends/fsspec.py +136 -112
- sqlspec/storage/backends/obstore.py +138 -160
- sqlspec/storage/capabilities.py +5 -4
- sqlspec/storage/registry.py +57 -106
- sqlspec/typing.py +136 -115
- sqlspec/utils/__init__.py +2 -3
- sqlspec/utils/correlation.py +0 -3
- sqlspec/utils/deprecation.py +6 -6
- sqlspec/utils/fixtures.py +6 -6
- sqlspec/utils/logging.py +0 -2
- sqlspec/utils/module_loader.py +7 -12
- sqlspec/utils/singleton.py +0 -1
- sqlspec/utils/sync_tools.py +16 -37
- sqlspec/utils/text.py +12 -51
- sqlspec/utils/type_guards.py +443 -232
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/METADATA +7 -2
- sqlspec-0.15.0.dist-info/RECORD +134 -0
- sqlspec-0.15.0.dist-info/entry_points.txt +2 -0
- sqlspec/driver/connection.py +0 -207
- sqlspec/driver/mixins/_cache.py +0 -114
- sqlspec/driver/mixins/_csv_writer.py +0 -91
- sqlspec/driver/mixins/_pipeline.py +0 -508
- sqlspec/driver/mixins/_query_tools.py +0 -796
- sqlspec/driver/mixins/_result_utils.py +0 -138
- sqlspec/driver/mixins/_storage.py +0 -912
- sqlspec/driver/mixins/_type_coercion.py +0 -128
- sqlspec/driver/parameters.py +0 -138
- sqlspec/statement/__init__.py +0 -21
- sqlspec/statement/builder/_merge.py +0 -95
- sqlspec/statement/cache.py +0 -50
- sqlspec/statement/filters.py +0 -625
- sqlspec/statement/parameters.py +0 -996
- sqlspec/statement/pipelines/__init__.py +0 -210
- sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
- sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
- sqlspec/statement/pipelines/context.py +0 -115
- sqlspec/statement/pipelines/transformers/__init__.py +0 -7
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
- sqlspec/statement/pipelines/validators/__init__.py +0 -23
- sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
- sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
- sqlspec/statement/pipelines/validators/_performance.py +0 -714
- sqlspec/statement/pipelines/validators/_security.py +0 -967
- sqlspec/statement/result.py +0 -435
- sqlspec/statement/sql.py +0 -1774
- sqlspec/utils/cached_property.py +0 -25
- sqlspec/utils/statement_hashing.py +0 -203
- sqlspec-0.14.0.dist-info/RECORD +0 -143
- sqlspec-0.14.0.dist-info/entry_points.txt +0 -2
- /sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.14.0.dist-info → sqlspec-0.15.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/sql.py
DELETED
|
@@ -1,1774 +0,0 @@
|
|
|
1
|
-
"""SQL statement handling with centralized parameter management."""
|
|
2
|
-
|
|
3
|
-
import operator
|
|
4
|
-
from dataclasses import dataclass, field
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
6
|
-
|
|
7
|
-
import sqlglot
|
|
8
|
-
import sqlglot.expressions as exp
|
|
9
|
-
from sqlglot.errors import ParseError
|
|
10
|
-
from typing_extensions import TypeAlias
|
|
11
|
-
|
|
12
|
-
from sqlspec.exceptions import RiskLevel, SQLParsingError, SQLValidationError
|
|
13
|
-
from sqlspec.statement.cache import sql_cache
|
|
14
|
-
from sqlspec.statement.filters import StatementFilter
|
|
15
|
-
from sqlspec.statement.parameters import (
|
|
16
|
-
SQLGLOT_INCOMPATIBLE_STYLES,
|
|
17
|
-
ParameterConverter,
|
|
18
|
-
ParameterStyle,
|
|
19
|
-
ParameterValidator,
|
|
20
|
-
)
|
|
21
|
-
from sqlspec.statement.pipelines import SQLProcessingContext, StatementPipeline
|
|
22
|
-
from sqlspec.statement.pipelines.transformers import CommentAndHintRemover, ParameterizeLiterals
|
|
23
|
-
from sqlspec.statement.pipelines.validators import DMLSafetyValidator, ParameterStyleValidator, SecurityValidator
|
|
24
|
-
from sqlspec.utils.logging import get_logger
|
|
25
|
-
from sqlspec.utils.statement_hashing import hash_sql_statement
|
|
26
|
-
from sqlspec.utils.type_guards import (
|
|
27
|
-
can_append_to_statement,
|
|
28
|
-
can_extract_parameters,
|
|
29
|
-
has_parameter_value,
|
|
30
|
-
has_risk_level,
|
|
31
|
-
is_dict,
|
|
32
|
-
is_expression,
|
|
33
|
-
is_statement_filter,
|
|
34
|
-
supports_where,
|
|
35
|
-
)
|
|
36
|
-
|
|
37
|
-
if TYPE_CHECKING:
|
|
38
|
-
from sqlglot.dialects.dialect import DialectType
|
|
39
|
-
|
|
40
|
-
from sqlspec.statement.parameters import ParameterStyleTransformationState
|
|
41
|
-
|
|
42
|
-
__all__ = ("SQL", "SQLConfig", "Statement")

# Module-level logger, registered under the "sqlspec.statement" name.
logger = get_logger("sqlspec.statement")

# Anything accepted as a statement: SQL text, a sqlglot expression, or an existing SQL object.
Statement: TypeAlias = Union[str, exp.Expression, "SQL"]

# --- Generated parameter-name prefixes ---------------------------------------
# PARAM_PREFIX is what SQL._convert_parameters / SQL._denormalize_sql use for
# normalized placeholder names ("param_<ordinal>").
PARAM_PREFIX = "param_"
POS_PARAM_PREFIX = "pos_param_"
KW_POS_PARAM_PREFIX = "kw_pos_param_"
ARG_PREFIX = "arg_"

# --- Cache sizing ------------------------------------------------------------
DEFAULT_CACHE_SIZE = 1000

# --- Oracle / colon-style numbered parameters ---------------------------------
COLON_PARAM_ONE = "1"
COLON_PARAM_MIN_INDEX = 1
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
@dataclass
class _ProcessedState:
    """Snapshot of a statement after it has gone through the processing pipeline.

    ``SQL`` stores one of these on ``_processed_state``. When caching is enabled,
    the same object is also stored in ``sql_cache`` under the statement's cache key.
    """

    # Expression as returned (and possibly rewritten) by the pipeline.
    processed_expression: exp.Expression
    # Final SQL text produced from the processed expression.
    processed_sql: str
    # Parameters after the pipeline's merge step.
    merged_parameters: Any
    # Per-stage outputs collected during processing; each starts empty.
    validation_errors: list[Any] = field(default_factory=list)
    analysis_results: dict[str, Any] = field(default_factory=dict)
    transformation_results: dict[str, Any] = field(default_factory=dict)
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
@dataclass
class SQLConfig:
    """Behavioural settings for ``SQL`` statements.

    The defaults are deliberately permissive so that diverse dialects and complex
    queries work without extra configuration.

    Pipeline switches:
        enable_parsing: Parse SQL strings with sqlglot (default True).
        enable_validation: Use the default validators when ``validators`` is None (default True).
        enable_transformations: Use the default transformers when ``transformers`` is None
            (default True).
        enable_analysis: Use the default analyzer when ``analyzers`` is None (default False).
        enable_expression_simplification: Append ``ExpressionSimplifier`` to the default
            transformers (default False).
        enable_parameter_type_wrapping: Wrap parameters with type information (default True).
        parse_errors_as_warnings: Treat parse errors as warnings. It also makes the default
            ``ParameterStyleValidator`` non-fatal (default True).
        enable_caching: Cache processed statements (default True).

    Explicit components (take precedence over the switches when not None):
        transformers, validators, analyzers.

    Internal settings:
        parameter_converter: Converts between parameter styles.
        parameter_validator: Extracts and validates parameter placeholders.
        input_sql_had_placeholders: Set to True during SQL processing when raw SQL contains
            placeholders.
        dialect: Dialect used for parsing and SQL generation.

    Parameter-style settings:
        allowed_parameter_styles: Allow-list checked by ``validate_parameter_style``.
        default_parameter_style: Target style for generated SQL. It is also the placeholder
            style given to ``ParameterizeLiterals`` ("?" when unset).
        allow_mixed_parameter_styles: Whether one query may mix parameter styles.
        analyzer_output_handler: Optional callback for analyzer output.
            NOTE(review): not referenced by SQLConfig itself.
    """

    # Pipeline switches.
    enable_parsing: bool = True
    enable_validation: bool = True
    enable_transformations: bool = True
    enable_analysis: bool = False
    enable_expression_simplification: bool = False
    enable_parameter_type_wrapping: bool = True
    parse_errors_as_warnings: bool = True
    enable_caching: bool = True

    # Explicit pipeline components (None means "derive from the switches").
    transformers: "Optional[list[Any]]" = None
    validators: "Optional[list[Any]]" = None
    analyzers: "Optional[list[Any]]" = None

    # Internal helpers and processing state.
    parameter_converter: ParameterConverter = field(default_factory=ParameterConverter)
    parameter_validator: ParameterValidator = field(default_factory=ParameterValidator)
    input_sql_had_placeholders: bool = False
    dialect: "Optional[DialectType]" = None

    # Parameter-style policy.
    allowed_parameter_styles: "Optional[tuple[str, ...]]" = None
    default_parameter_style: "Optional[str]" = None
    allow_mixed_parameter_styles: bool = False
    analyzer_output_handler: "Optional[Callable[[Any], None]]" = None

    def validate_parameter_style(self, style: "Union[ParameterStyle, str]") -> bool:
        """Report whether ``style`` is permitted by ``allowed_parameter_styles``.

        Args:
            style: A ParameterStyle member or its string form.

        Returns:
            True when no allow-list is configured or when ``str(style)`` appears in it.
        """
        # NOTE(review): enum members are compared through str(). This assumes
        # ParameterStyle.__str__ yields the style value (e.g. "qmark"); confirm.
        allowed = self.allowed_parameter_styles
        return True if allowed is None else str(style) in allowed

    def get_statement_pipeline(self) -> StatementPipeline:
        """Build the StatementPipeline described by this configuration.

        Returns:
            A pipeline whose transformers, validators and analyzers come from the
            explicit component lists when given, otherwise from the enable_* switches.
        """
        return StatementPipeline(  # pyright: ignore
            transformers=self._build_transformers(),
            validators=self._build_validators(),
            analyzers=self._build_analyzers(),
        )

    def _build_transformers(self) -> "list[Any]":
        """Return the explicit transformers, or the default chain when transformations are on."""
        if self.transformers is not None:
            return list(self.transformers)
        if not self.enable_transformations:
            return []
        chain: "list[Any]" = [
            CommentAndHintRemover(),
            ParameterizeLiterals(placeholder_style=self.default_parameter_style or "?"),
        ]
        if self.enable_expression_simplification:
            # Imported lazily: only needed when simplification is switched on.
            from sqlspec.statement.pipelines.transformers import ExpressionSimplifier

            chain.append(ExpressionSimplifier())
        return chain

    def _build_validators(self) -> "list[Any]":
        """Return the explicit validators, or the default trio when validation is on."""
        if self.validators is not None:
            return list(self.validators)
        if not self.enable_validation:
            return []
        return [
            ParameterStyleValidator(fail_on_violation=not self.parse_errors_as_warnings),
            DMLSafetyValidator(),
            SecurityValidator(),
        ]

    def _build_analyzers(self) -> "list[Any]":
        """Return the explicit analyzers, or the default analyzer when analysis is on."""
        if self.analyzers is not None:
            return list(self.analyzers)
        if not self.enable_analysis:
            return []
        # Imported lazily: analysis is off by default.
        from sqlspec.statement.pipelines.analyzers import StatementAnalyzer

        return [StatementAnalyzer()]
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
def default_analysis_handler(analysis: Any) -> None:
    """Fallback analysis handler: write ``analysis`` to the module logger at DEBUG level.

    Args:
        analysis: Analysis payload of any type; it is formatted lazily through ``%s``.
    """
    template = "SQL Analysis: %s"
    logger.debug(template, analysis)
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
class SQL:
|
|
189
|
-
"""Immutable SQL statement with centralized parameter management.
|
|
190
|
-
|
|
191
|
-
The SQL class is the single source of truth for:
|
|
192
|
-
- SQL expression/statement
|
|
193
|
-
- Positional parameters
|
|
194
|
-
- Named parameters
|
|
195
|
-
- Applied filters
|
|
196
|
-
|
|
197
|
-
All methods that modify state return new SQL instances.
|
|
198
|
-
"""
|
|
199
|
-
|
|
200
|
-
__slots__ = (
|
|
201
|
-
"_builder_result_type",
|
|
202
|
-
"_config",
|
|
203
|
-
"_dialect",
|
|
204
|
-
"_filters",
|
|
205
|
-
"_is_many",
|
|
206
|
-
"_is_script",
|
|
207
|
-
"_named_params",
|
|
208
|
-
"_original_parameters",
|
|
209
|
-
"_original_sql",
|
|
210
|
-
"_parameter_conversion_state",
|
|
211
|
-
"_placeholder_mapping",
|
|
212
|
-
"_positional_params",
|
|
213
|
-
"_processed_state",
|
|
214
|
-
"_processing_context",
|
|
215
|
-
"_raw_sql",
|
|
216
|
-
"_statement",
|
|
217
|
-
)
|
|
218
|
-
|
|
219
|
-
def __init__(
    self,
    statement: "Union[str, exp.Expression, 'SQL']",
    *parameters: "Union[Any, StatementFilter, list[Union[Any, StatementFilter]]]",
    _dialect: "DialectType" = None,
    _config: "Optional[SQLConfig]" = None,
    _builder_result_type: "Optional[type]" = None,
    _existing_state: "Optional[dict[str, Any]]" = None,
    **kwargs: Any,
) -> None:
    """Initialize SQL with centralized parameter management.

    Args:
        statement: SQL text, a sqlglot expression, or another ``SQL`` instance to copy.
        *parameters: Positional values, dicts of named values, tuples/lists of values,
            or ``StatementFilter`` objects. Each item is routed by ``_process_parameter_item``.
        _dialect: Dialect override. It falls back to the config's dialect, or to the source
            object's dialect when copying a ``SQL`` instance.
        _config: Configuration to use. When omitted and ``statement`` is a ``SQL`` instance,
            that instance's config is kept. Otherwise a default ``SQLConfig()`` is created.
        _builder_result_type: Optional result type stored on the instance (presumably set
            by query builders; confirm).
        _existing_state: Snapshot of parameter/flag state applied via
            ``_load_from_existing_state`` (used by copy).
        **kwargs: Named parameter values. ``config=`` is accepted as an alias for ``_config``
            when ``_config`` is not given. ``parameters=`` is unpacked as a parameter
            collection. Keys starting with ``_`` are ignored.
    """
    # Legacy alias: ``config=`` is honoured only when ``_config`` was not passed.
    if "config" in kwargs and _config is None:
        _config = kwargs.pop("config")
    self._config = _config or SQLConfig()
    self._dialect = _dialect or self._config.dialect
    self._builder_result_type = _builder_result_type
    self._processed_state: Optional[_ProcessedState] = None
    self._processing_context: Optional[SQLProcessingContext] = None
    self._positional_params: list[Any] = []
    self._named_params: dict[str, Any] = {}
    self._filters: list[StatementFilter] = []
    self._statement: exp.Expression
    self._raw_sql: str = ""
    self._original_parameters: Any = None
    self._original_sql: str = ""
    self._placeholder_mapping: dict[str, Union[str, int]] = {}
    self._parameter_conversion_state: Optional[ParameterStyleTransformationState] = None
    self._is_many: bool = False
    self._is_script: bool = False

    if isinstance(statement, SQL):
        # Pass ``_config`` through as given (possibly None) so that
        # _init_from_sql_object can fall back to ``statement._config``.
        # ``_config or SQLConfig()`` is always truthy, so passing it would
        # silently replace the source statement's config with defaults.
        self._init_from_sql_object(statement, _dialect, _config, _builder_result_type)
    else:
        self._init_from_str_or_expression(statement)

    if _existing_state:
        self._load_from_existing_state(_existing_state)

    # Original parameters are captured only for fresh statements. Copies get
    # them from the source object or from ``_existing_state``.
    if not isinstance(statement, SQL) and not _existing_state:
        self._set_original_parameters(*parameters)

    self._process_parameters(*parameters, **kwargs)
|
|
261
|
-
|
|
262
|
-
def _init_from_sql_object(
    self, statement: "SQL", dialect: "DialectType", config: "SQLConfig", builder_result_type: "Optional[type]"
) -> None:
    """Copy state from another ``SQL`` instance into this one.

    Truthy ``dialect``, ``config`` and ``builder_result_type`` arguments win; falsy ones
    fall back to the corresponding attribute of ``statement``. The placeholder mapping
    is copied, and the parameter and filter entries are appended into this instance's
    own containers. The expression, original parameters and conversion state are
    shared by reference.
    """
    source = statement
    self._statement = source._statement
    self._dialect = dialect or source._dialect
    self._config = config or source._config
    self._builder_result_type = builder_result_type or source._builder_result_type
    self._is_many, self._is_script = source._is_many, source._is_script
    self._raw_sql = source._raw_sql
    self._original_parameters = source._original_parameters
    self._original_sql = source._original_sql
    self._placeholder_mapping = dict(source._placeholder_mapping)
    self._parameter_conversion_state = source._parameter_conversion_state
    # Grow the containers created by __init__ in place instead of rebinding them.
    self._positional_params += source._positional_params
    self._named_params.update(source._named_params)
    self._filters += source._filters
|
|
280
|
-
|
|
281
|
-
def _init_from_str_or_expression(self, statement: "Union[str, exp.Expression]") -> None:
    """Set ``_raw_sql`` and ``_statement`` from SQL text or an already-built expression.

    Text is kept verbatim as the raw SQL and converted with ``_to_expression``. An
    expression is stored as-is, and its SQL rendered in ``self._dialect`` becomes the raw SQL.
    """
    if not isinstance(statement, str):
        self._raw_sql = statement.sql(dialect=self._dialect)
        self._statement = statement
        return
    self._raw_sql = statement
    self._statement = self._to_expression(statement)
|
|
289
|
-
|
|
290
|
-
def _load_from_existing_state(self, existing_state: "dict[str, Any]") -> None:
    """Overlay values from a state snapshot onto this instance (used by copy).

    Each recognised key replaces the matching attribute, and missing keys leave the
    current value in place. The keys are "positional_params", "named_params", "filters",
    "is_many", "is_script", "raw_sql" and "original_parameters". Parameter and filter
    containers are copied into fresh list/dict objects.
    """
    pick = existing_state.get
    # Containers: always copied so the snapshot is not shared.
    self._positional_params = list(pick("positional_params", self._positional_params))
    self._named_params = dict(pick("named_params", self._named_params))
    self._filters = list(pick("filters", self._filters))
    # Scalars / opaque values: taken as-is.
    self._is_many = pick("is_many", self._is_many)
    self._is_script = pick("is_script", self._is_script)
    self._raw_sql = pick("raw_sql", self._raw_sql)
    self._original_parameters = pick("original_parameters", self._original_parameters)
|
|
299
|
-
|
|
300
|
-
def _set_original_parameters(self, *parameters: Any) -> None:
    """Remember the parameters exactly as the caller supplied them.

    * no arguments, or a single StatementFilter -> None
    * a single list/tuple -> that list/tuple itself
    * anything else -> the full tuple of arguments
    """
    count = len(parameters)
    if count == 0 or (count == 1 and is_statement_filter(parameters[0])):
        original = None
    elif count == 1 and isinstance(parameters[0], (list, tuple)):
        original = parameters[0]
    else:
        original = parameters
    self._original_parameters = original
|
|
308
|
-
|
|
309
|
-
def _process_parameters(self, *parameters: Any, **kwargs: Any) -> None:
    """Sort incoming parameters into positional values, named values and filters.

    Positional arguments go through ``_process_parameter_item``. A ``parameters=``
    keyword is unpacked: list/tuple values extend the positional parameters, a dict
    updates the named parameters, and any other value becomes one positional parameter.
    The remaining keywords whose names do not start with ``_`` become named parameters.
    """
    for entry in parameters:
        self._process_parameter_item(entry)

    missing = object()
    bundle = kwargs.pop("parameters", missing)
    if bundle is not missing:
        if isinstance(bundle, (list, tuple)):
            self._positional_params.extend(bundle)
        elif is_dict(bundle):
            self._named_params.update(bundle)
        else:
            self._positional_params.append(bundle)

    public_kwargs = {name: value for name, value in kwargs.items() if not name.startswith("_")}
    self._named_params.update(public_kwargs)
|
|
324
|
-
|
|
325
|
-
def _cache_key(self) -> str:
    """Return the key under which this statement's processed state is cached.

    The key is whatever ``hash_sql_statement`` produces for this instance.
    """
    key = hash_sql_statement(self)
    return key
|
|
328
|
-
|
|
329
|
-
def _process_parameter_item(self, item: Any) -> None:
    """Route one supplied parameter item to the right container.

    * StatementFilter: recorded as a filter, and its extracted parameters are merged.
    * list: each element is routed recursively.
    * dict: merged into the named parameters.
    * tuple: its elements extend the positional parameters.
    * anything else: appended as a single positional parameter.
    """
    if is_statement_filter(item):
        self._filters.append(item)
        positional, named = self._extract_filter_parameters(item)
        self._positional_params.extend(positional)
        self._named_params.update(named)
        return
    if isinstance(item, list):
        for nested in item:
            self._process_parameter_item(nested)
        return
    if is_dict(item):
        self._named_params.update(item)
        return
    if isinstance(item, tuple):
        self._positional_params.extend(item)
        return
    self._positional_params.append(item)
|
|
345
|
-
|
|
346
|
-
def _ensure_processed(self) -> None:
    """Lazily run the processing pipeline and populate ``_processed_state``.

    Public accessors that need processed output call this. It returns immediately
    when a processed state already exists. With caching enabled, a cached state under
    ``_cache_key()`` is reused if present; otherwise the freshly computed state is
    stored in the cache afterwards.
    """
    if self._processed_state is not None:
        return

    cache_key = None
    if self._config.enable_caching:
        cache_key = self._cache_key()
        hit = sql_cache.get(cache_key)
        if hit is not None:
            self._processed_state = hit
            return

    # Build -> detect placeholders -> prepare -> run -> post-process -> finalize.
    expression, params = self._build_final_state()
    had_placeholders = self._detect_placeholders()
    context_sql, params = self._prepare_context_sql(expression, params)
    context = self._create_processing_context(context_sql, expression, params, had_placeholders)
    pipeline_result = self._run_pipeline(context)
    processed_sql, merged_params = self._process_pipeline_result(pipeline_result, params, context)
    self._finalize_processed_state(pipeline_result, processed_sql, merged_params)

    should_store = self._config.enable_caching and cache_key is not None and self._processed_state is not None
    if should_store:
        sql_cache.set(cache_key, self._processed_state)
|
|
379
|
-
|
|
380
|
-
def _detect_placeholders(self) -> bool:
    """Report whether the raw SQL contains parameter placeholders.

    With raw SQL present, the answer comes from the config's parameter validator. A
    positive answer is also recorded on ``config.input_sql_had_placeholders``. Without
    raw SQL, the flag already stored on the config is returned.
    """
    raw_sql = self._raw_sql
    if not raw_sql:
        return self._config.input_sql_had_placeholders
    found = bool(self._config.parameter_validator.extract_parameters(raw_sql))
    if found:
        # NOTE(review): this writes to the config object, which may be shared by
        # many SQL instances; once set, the flag stays True for all of them.
        self._config.input_sql_had_placeholders = True
    return found
|
|
390
|
-
|
|
391
|
-
def _prepare_context_sql(self, final_expr: exp.Expression, final_params: Any) -> tuple[str, Any]:
    """Choose the SQL text and parameters that seed the processing context.

    With a placeholder mapping and an expression input, the SQL is rendered from the
    expression and the parameters are re-keyed by ``_convert_parameters``. Otherwise
    the raw SQL is used (or the rendered expression when there is no raw SQL), and the
    parameters are left untouched.
    """
    dialect = self._dialect or self._config.dialect
    if is_expression(final_expr) and self._placeholder_mapping:
        return final_expr.sql(dialect=dialect), self._convert_parameters(final_params)
    return self._raw_sql or final_expr.sql(dialect=dialect), final_params
|
|
401
|
-
|
|
402
|
-
def _convert_parameters(self, final_params: Any) -> Any:
    """Re-key ``final_params`` to the placeholder names recorded in ``_placeholder_mapping``.

    Args:
        final_params: Merged statement parameters.

    Returns:
        * dict input: a new dict in which each mapped original name is re-keyed to its
          placeholder key. Keys not covered by the mapping are carried over unchanged.
        * list/tuple input: a new dict. When every placeholder extracted from the raw SQL
          has a numeric name, values are keyed by consecutive numbers starting at the
          lowest such number. Otherwise each value is keyed by
          ``f"{PARAM_PREFIX}{ordinal}"`` of the placeholder at the same position, and
          values beyond the number of placeholders are dropped. When no placeholders can
          be extracted from the raw SQL, the sequence is returned unchanged.
        * anything else: returned unchanged.
    """
    mapping = self._placeholder_mapping

    if is_dict(final_params):
        # Built once: the original names claimed by the mapping.
        mapped_names = {str(name) for name in mapping.values()}
        converted = {
            placeholder_key: final_params[str(original_name)]
            for placeholder_key, original_name in mapping.items()
            if str(original_name) in final_params
        }
        converted.update({key: value for key, value in final_params.items() if key not in mapped_names})
        return converted

    if isinstance(final_params, (list, tuple)):
        param_info = self._config.parameter_validator.extract_parameters(self._raw_sql)
        if not param_info:
            # Nothing to map onto. Without this guard, all() below is vacuously True
            # and min() raises ValueError on an empty sequence.
            return final_params

        if all(p.name and p.name.isdigit() for p in param_info):
            # Numeric placeholder names (e.g. :1, :2): number values from the lowest one.
            first = min(int(p.name) for p in param_info)
            return {str(first + offset): value for offset, value in enumerate(final_params)}

        # zip() stops at the shorter side, so surplus values are dropped.
        return {f"{PARAM_PREFIX}{info.ordinal}": value for info, value in zip(param_info, final_params)}

    return final_params
|
|
439
|
-
|
|
440
|
-
def _create_processing_context(
    self, initial_sql_for_context: str, final_expr: exp.Expression, final_params: Any, has_placeholders: bool
) -> SQLProcessingContext:
    """Build the SQLProcessingContext handed to the pipeline.

    The context is seeded with the SQL text, the expression (as both the initial and
    the current expression), the parameters and the placeholder flag. It also receives
    the placeholder map and the parameter-conversion state when present, plus the
    parameter info extracted from the context's initial SQL string.
    """
    dialect = self._dialect or self._config.dialect
    context = SQLProcessingContext(
        initial_sql_string=initial_sql_for_context,
        dialect=dialect,
        config=self._config,
        initial_expression=final_expr,
        current_expression=final_expr,
        merged_parameters=final_params,
        input_sql_had_placeholders=has_placeholders or self._config.input_sql_had_placeholders,
    )

    mapping = self._placeholder_mapping
    if mapping:
        context.extra_info["placeholder_map"] = mapping

    conversion_state = self._parameter_conversion_state
    if conversion_state:
        context.parameter_conversion = conversion_state

    context.parameter_info = self._config.parameter_validator.extract_parameters(context.initial_sql_string)
    return context
|
|
465
|
-
|
|
466
|
-
def _run_pipeline(self, context: SQLProcessingContext) -> Any:
    """Execute the configured statement pipeline on ``context``.

    The pipeline result's context is kept on ``_processing_context``, and the result
    itself is returned.
    """
    outcome = self._config.get_statement_pipeline().execute_pipeline(context)
    self._processing_context = outcome.context
    return outcome
|
|
472
|
-
|
|
473
|
-
def _process_pipeline_result(
    self, result: Any, final_params: Any, context: SQLProcessingContext
) -> tuple[str, Any]:
    """Render the pipeline output to SQL text and merge its parameters.

    Args:
        result: Pipeline result; its ``expression`` (and, when denormalizing, its
            ``context``) attributes are read.
        final_params: Parameters that were handed to the pipeline.
        context: The processing context passed to the pipeline.

    Returns:
        ``(processed_sql, merged_parameters)``.
    """
    rendered_expr = result.expression

    if isinstance(rendered_expr, exp.Anonymous):
        # Anonymous nodes are not rendered (presumably sqlglot could not fully model
        # the statement). Reuse the raw SQL, or the context's initial SQL string.
        processed_sql = self._raw_sql or context.initial_sql_string
    else:
        # The pipeline may drop LIMIT/OFFSET. If the pre-pipeline expression has a
        # "limit" arg and the processed one does not, render the pre-pipeline
        # expression instead. This also discards the pipeline's other changes.
        if hasattr(context, "initial_expression") and context.initial_expression != rendered_expr:
            before = context.initial_expression
            limit_before = before is not None and hasattr(before, "args") and "limit" in before.args
            limit_after = hasattr(rendered_expr, "args") and "limit" in rendered_expr.args
            if limit_before and not limit_after:
                rendered_expr = before

        if rendered_expr:
            processed_sql = rendered_expr.sql(dialect=self._dialect or self._config.dialect, comments=False)
        else:
            processed_sql = ""

        if self._placeholder_mapping and self._original_sql:
            # Map normalized placeholders back to the caller's original style.
            processed_sql, result = self._denormalize_sql(processed_sql, result)

    return processed_sql, self._merge_pipeline_parameters(result, final_params)
|
|
508
|
-
|
|
509
|
-
def _denormalize_sql(self, processed_sql: str, result: Any) -> tuple[str, Any]:
    """Map placeholders in ``processed_sql`` back to the style used by ``self._original_sql``.

    Styles are checked in priority order: positional pyformat, named pyformat, then
    positional colon; only the first one present in the original SQL is restored.
    For the named-pyformat and positional-colon cases, dict-valued
    ``result.context.merged_parameters`` are also passed through the matching
    ``_denormalize_*_params`` helper.

    Returns:
        The rewritten SQL and the same ``result`` object.
    """
    validator = self._config.parameter_validator
    original_info = validator.extract_parameters(self._original_sql)
    present_styles = {info.style for info in original_info}

    if ParameterStyle.POSITIONAL_PYFORMAT in present_styles:
        processed_sql = self._config.parameter_converter._convert_sql_placeholders(
            processed_sql, original_info, ParameterStyle.POSITIONAL_PYFORMAT
        )
        return processed_sql, result

    if ParameterStyle.NAMED_PYFORMAT in present_styles:
        processed_sql = self._config.parameter_converter._convert_sql_placeholders(
            processed_sql, original_info, ParameterStyle.NAMED_PYFORMAT
        )
        current_params = result.context.merged_parameters
        if self._placeholder_mapping and current_params and is_dict(current_params):
            result.context.merged_parameters = self._denormalize_pyformat_params(current_params)
        return processed_sql, result

    if ParameterStyle.POSITIONAL_COLON in present_styles:
        # Leave the SQL alone if it still contains pipeline-generated placeholders.
        rendered_info = validator.extract_parameters(processed_sql)
        if not any(info.name and info.name.startswith(PARAM_PREFIX) for info in rendered_info):
            processed_sql = self._config.parameter_converter._convert_sql_placeholders(
                processed_sql, original_info, ParameterStyle.POSITIONAL_COLON
            )
        current_params = result.context.merged_parameters
        if self._placeholder_mapping and current_params and is_dict(current_params):
            result.context.merged_parameters = self._denormalize_colon_params(current_params)

    return processed_sql, result
|
|
545
|
-
|
|
546
|
-
def _denormalize_colon_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
|
|
547
|
-
"""Denormalize colon-style parameters back to numeric format."""
|
|
548
|
-
# For positional colon style, all params should have numeric keys
|
|
549
|
-
# Just return the params as-is if they already have the right format
|
|
550
|
-
if all(key.isdigit() for key in params):
|
|
551
|
-
return params
|
|
552
|
-
|
|
553
|
-
# For positional colon, we need ALL parameters in the final result
|
|
554
|
-
# This includes both user parameters and extracted literals
|
|
555
|
-
# We should NOT filter out extracted parameters (param_0, param_1, etc)
|
|
556
|
-
# because they need to be included in the final parameter conversion
|
|
557
|
-
return params
|
|
558
|
-
|
|
559
|
-
def _denormalize_pyformat_params(self, params: "dict[str, Any]") -> "dict[str, Any]":
    """Rename placeholder keys back to the user's original pyformat names.

    Keys present in ``self._placeholder_mapping`` are renamed to their mapped name
    (stringified). Every key that does not start with ``PARAM_PREFIX`` is then
    copied over as-is and overrides a renamed entry of the same name.
    """
    renamed = {
        str(original_name): params[placeholder_key]
        for placeholder_key, original_name in self._placeholder_mapping.items()
        if placeholder_key in params
    }
    renamed.update({key: value for key, value in params.items() if not key.startswith(PARAM_PREFIX)})
    return renamed
|
|
570
|
-
|
|
571
|
-
def _merge_pipeline_parameters(self, result: Any, final_params: Any) -> Any:
    """Combine the context's merged parameters with literals extracted by the pipeline.

    Args:
        result: Pipeline result; ``result.context.merged_parameters`` and
            ``result.context.extracted_parameters_from_pipeline`` are read.
        final_params: Kept for signature compatibility. The only branch that
            consulted it was unreachable, so it does not influence the result.

    Returns:
        - The merged parameters unchanged when nothing was extracted, or when they
          already equal the extracted list (avoids adding the same values twice).
        - The extracted list when there were no merged parameters.
        - Otherwise a new container: a dict with extracted values under
          ``PARAM_PREFIX<i>`` keys, a list with the extracted values appended, or
          ``[scalar, *extracted]`` for a single scalar parameter.

        The context's own containers are never mutated in place.
    """
    merged_params = result.context.merged_parameters
    extracted = result.context.extracted_parameters_from_pipeline

    if not extracted:
        return merged_params
    if merged_params is None:
        return extracted
    if merged_params == extracted:
        return merged_params

    if is_dict(merged_params):
        merged_dict = dict(merged_params)
        for i, param in enumerate(extracted):
            merged_dict[f"{PARAM_PREFIX}{i}"] = param
        return merged_dict
    if isinstance(merged_params, (list, tuple)):
        return [*merged_params, *extracted]
    # Single scalar parameter: keep it first, followed by the extracted literals.
    return [merged_params, *list(extracted)]
|
|
603
|
-
|
|
604
|
-
def _finalize_processed_state(self, result: Any, processed_sql: str, merged_params: Any) -> None:
    """Store ``self._processed_state`` and surface analysis and validation results.

    Steps:
    - Optionally wraps parameters with type information.
    - Collects metadata entries whose key ends in ``"Analyzer"``.
    - Forwards a ``StatementAnalyzer`` summary to ``analyzer_output_handler`` when one
      is configured.
    - Records the processed state.

    Raises:
        SQLValidationError: When validation errors exist and
            ``parse_errors_as_warnings`` is disabled. The error with the highest
            ``risk_level.value`` (0 for errors without a risk level) is reported.
    """
    config = self._config

    if config.enable_parameter_type_wrapping and merged_params is not None:
        sql_param_info = config.parameter_validator.extract_parameters(processed_sql)
        merged_params = config.parameter_converter.wrap_parameters_with_types(merged_params, sql_param_info)

    metadata = result.context.metadata
    analysis_results = {}
    if metadata:
        analysis_results = {name: value for name, value in metadata.items() if name.endswith("Analyzer")}

    handler = config.analyzer_output_handler
    if handler and analysis_results:
        statement_meta = analysis_results.get("StatementAnalyzer", {})
        if statement_meta:
            summary_keys = (
                "statement_type",
                "complexity_score",
                "table_count",
                "has_subqueries",
                "join_count",
                "duration_ms",
            )
            handler({key: statement_meta.get(key) for key in summary_keys})

    self._processed_state = _ProcessedState(
        processed_expression=result.expression,
        processed_sql=processed_sql,
        merged_parameters=merged_params,
        validation_errors=list(result.context.validation_errors),
        analysis_results=analysis_results,
        transformation_results={},
    )

    errors = self._processed_state.validation_errors
    if config.parse_errors_as_warnings or not errors:
        return

    worst = max(errors, key=lambda err: err.risk_level.value if has_risk_level(err) else 0)
    raise SQLValidationError(
        message=worst.message,
        sql=self._raw_sql or processed_sql,
        risk_level=getattr(worst, "risk_level", RiskLevel.HIGH),
    )
|
|
659
|
-
|
|
660
|
-
def _to_expression(self, statement: "Union[str, exp.Expression]") -> exp.Expression:
    """Convert a SQL string to a sqlglot expression.

    Expressions are returned as-is. Empty or whitespace-only input becomes an empty
    ``exp.Select``. When parsing is disabled the text is wrapped in
    ``exp.Anonymous``.

    Side effects: sets ``self._parameter_conversion_state`` on every parsed string.
    When placeholders use sqlglot-incompatible styles, also sets
    ``self._original_sql`` and ``self._placeholder_mapping``.

    Raises:
        SQLParsingError: If sqlglot fails to parse and ``parse_errors_as_warnings``
            is not enabled (then an ``exp.Anonymous`` is returned with a warning).
    """
    if is_expression(statement):
        return statement

    if not statement or (isinstance(statement, str) and not statement.strip()):
        return exp.Select()

    if not self._config.enable_parsing:
        return exp.Anonymous(this=statement)

    if not isinstance(statement, str):
        return exp.Anonymous(this="")
    validator = self._config.parameter_validator
    param_info = validator.extract_parameters(statement)

    # sqlglot cannot parse some placeholder styles; those are rewritten before parsing.
    needs_conversion = any(p.style in SQLGLOT_INCOMPATIBLE_STYLES for p in param_info)

    converted_sql = statement
    placeholder_mapping: dict[str, Any] = {}

    if needs_conversion:
        converter = self._config.parameter_converter
        converted_sql, placeholder_mapping = converter._transform_sql_for_parsing(statement, param_info)
        self._original_sql = statement
        self._placeholder_mapping = placeholder_mapping

        from sqlspec.statement.parameters import ParameterStyleTransformationState

        self._parameter_conversion_state = ParameterStyleTransformationState(
            was_transformed=True,
            # Deduplicate in order of first appearance. Building the list from a set
            # would make the order hash-dependent, and ``compile`` treats
            # ``original_styles[0]`` of the context's conversion state (presumably
            # this object) as the statement's original style.
            original_styles=list(dict.fromkeys(p.style for p in param_info)),
            transformation_style=ParameterStyle.NAMED_COLON,
            placeholder_map=placeholder_mapping,
            original_param_info=param_info,
        )
    else:
        self._parameter_conversion_state = None

    try:
        expressions = sqlglot.parse(converted_sql, dialect=self._dialect)  # pyright: ignore
        if not expressions:
            return exp.Anonymous(this=statement)
        first_expr = expressions[0]
        if first_expr is None:
            return exp.Anonymous(this=statement)

    except ParseError as e:
        if getattr(self._config, "parse_errors_as_warnings", False):
            logger.warning(
                "Failed to parse SQL, returning Anonymous expression.", extra={"sql": statement, "error": str(e)}
            )
            return exp.Anonymous(this=statement)

        msg = f"Failed to parse SQL: {statement}"
        raise SQLParsingError(msg) from e
    return first_expr
|
|
719
|
-
|
|
720
|
-
@staticmethod
def _extract_filter_parameters(filter_obj: StatementFilter) -> tuple[list[Any], dict[str, Any]]:
    """Return the ``(positional, named)`` parameters contributed by ``filter_obj``.

    Filters rejected by ``can_extract_parameters`` contribute an empty list and an
    empty dict.
    """
    if not can_extract_parameters(filter_obj):
        return [], {}
    return filter_obj.extract_parameters()
|
|
726
|
-
|
|
727
|
-
def copy(
    self,
    statement: "Optional[Union[str, exp.Expression]]" = None,
    parameters: "Optional[Any]" = None,
    dialect: "DialectType" = None,
    config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "SQL":
    """Create a copy with optional modifications.

    This is the primary method for creating modified SQL objects. Arguments left as
    ``None`` fall back to this instance's statement, dialect and config.

    - If ``parameters`` is given, the copy starts fresh with only those parameters.
    - Otherwise it inherits this instance's positional/named parameters, filters,
      flags, raw SQL and original parameters.
    """
    target_statement = self._statement if statement is None else statement
    target_dialect = self._dialect if dialect is None else dialect
    target_config = self._config if config is None else config

    if parameters is not None:
        # Fresh parameters replace all accumulated state; nothing is carried over.
        return SQL(
            target_statement,
            parameters,
            _dialect=target_dialect,
            _config=target_config,
            _builder_result_type=self._builder_result_type,
            _existing_state=None,
            **kwargs,
        )

    carried_state = {
        "positional_params": list(self._positional_params),
        "named_params": dict(self._named_params),
        "filters": list(self._filters),
        "is_many": self._is_many,
        "is_script": self._is_script,
        "raw_sql": self._raw_sql,
        "original_parameters": self._original_parameters,
    }
    return SQL(
        target_statement,
        _dialect=target_dialect,
        _config=target_config,
        _builder_result_type=self._builder_result_type,
        _existing_state=carried_state,
        **kwargs,
    )
|
|
774
|
-
|
|
775
|
-
def add_named_parameter(self, name: "str", value: Any) -> "SQL":
|
|
776
|
-
"""Add a named parameter and return a new SQL instance."""
|
|
777
|
-
new_obj = self.copy()
|
|
778
|
-
new_obj._named_params[name] = value
|
|
779
|
-
return new_obj
|
|
780
|
-
|
|
781
|
-
def get_unique_parameter_name(
    self, base_name: "str", namespace: "Optional[str]" = None, preserve_original: bool = False
) -> str:
    """Generate a parameter name not already used by this statement's named parameters.

    Args:
        base_name: The base parameter name.
        namespace: Optional namespace prefix (e.g., 'cte', 'subquery'); joined to
            ``base_name`` with an underscore.
        preserve_original: Accepted for API compatibility. The candidate is returned
            unchanged whenever it is free regardless of this flag, so it does not
            affect the result.

    Returns:
        The candidate name if unused, otherwise the first free ``<candidate>_<n>``
        for n = 1, 2, ...
    """
    taken = set(self._named_params)
    candidate = base_name if not namespace else f"{namespace}_{base_name}"
    if candidate not in taken:
        return candidate

    suffix = 1
    while f"{candidate}_{suffix}" in taken:
        suffix += 1
    return f"{candidate}_{suffix}"
|
|
810
|
-
|
|
811
|
-
def where(self, condition: "Union[str, exp.Expression, exp.Condition]") -> "SQL":
    """Apply a WHERE clause and return a new SQL instance.

    String conditions are parsed with ``_to_expression``. That helper records
    placeholder-conversion state (``_original_sql``, ``_placeholder_mapping``,
    ``_parameter_conversion_state``) on the instance it runs on. Those attributes are
    snapshotted and restored so parsing a condition does not clobber the state
    belonging to this statement.

    Statements without ``.where`` support are wrapped as
    ``SELECT * FROM <statement> WHERE <condition>``.
    """
    if isinstance(condition, str):
        missing = object()
        state_attrs = ("_original_sql", "_placeholder_mapping", "_parameter_conversion_state")
        saved_state = {attr: getattr(self, attr, missing) for attr in state_attrs}
        try:
            condition_expr = self._to_expression(condition)
        finally:
            for attr, value in saved_state.items():
                if value is not missing:
                    setattr(self, attr, value)
    else:
        condition_expr = condition

    if supports_where(self._statement):
        new_statement = self._statement.where(condition_expr)  # pyright: ignore
    else:
        # Without an explicit projection sqlglot renders an invalid "SELECT FROM ...".
        new_statement = exp.Select().select("*").from_(self._statement).where(condition_expr)  # pyright: ignore

    return self.copy(statement=new_statement)
|
|
821
|
-
|
|
822
|
-
def filter(self, filter_obj: StatementFilter) -> "SQL":
|
|
823
|
-
"""Apply a filter and return a new SQL instance."""
|
|
824
|
-
new_obj = self.copy()
|
|
825
|
-
new_obj._filters.append(filter_obj)
|
|
826
|
-
pos_params, named_params = self._extract_filter_parameters(filter_obj)
|
|
827
|
-
new_obj._positional_params.extend(pos_params)
|
|
828
|
-
new_obj._named_params.update(named_params)
|
|
829
|
-
return new_obj
|
|
830
|
-
|
|
831
|
-
def as_many(self, parameters: "Optional[list[Any]]" = None) -> "SQL":
    """Return a copy flagged for ``executemany``.

    When ``parameters`` is provided it becomes the copy's original parameter sets,
    and the copy's accumulated positional and named parameters are cleared.
    """
    clone = self.copy()
    clone._is_many = True
    if parameters is None:
        return clone
    clone._positional_params = []
    clone._named_params = {}
    clone._original_parameters = parameters
    return clone
|
|
840
|
-
|
|
841
|
-
def as_script(self) -> "SQL":
    """Return a copy flagged for script execution; the receiver is not modified here."""
    clone = self.copy()
    clone._is_script = True
    return clone
|
|
846
|
-
|
|
847
|
-
def _build_final_state(self) -> tuple[exp.Expression, Any]:
    """Build the final expression and parameters after applying filters.

    Each filter that can append to a statement is applied in order. It receives a
    temporary ``SQL`` built from the current expression, this instance's config and
    dialect, and the parameters accumulated so far. If the filter returns an
    ``SQL``, its expression and parameters replace the accumulated ones; otherwise
    the return value becomes the new expression.

    Returns:
        ``(final_expression, final_params)``, where ``final_params`` is:

        - a dict when only named parameters exist;
        - a list when only positional parameters exist;
        - a dict with positional values stored under ``arg_<i>`` keys (uniquified on
          collision) when both exist;
        - ``None`` when there are no parameters.
    """
    final_expr = self._statement

    # Accumulate parameters from both the original SQL and filters
    accumulated_positional = list(self._positional_params)
    accumulated_named = dict(self._named_params)

    for filter_obj in self._filters:
        if can_append_to_statement(filter_obj):
            # Same private keyword names as ``copy`` so the temporary statement really
            # carries this statement's config and dialect.
            temp_sql = SQL(final_expr, _dialect=self._dialect, _config=self._config)
            temp_sql._positional_params = list(accumulated_positional)
            temp_sql._named_params = dict(accumulated_named)
            result = filter_obj.append_to_statement(temp_sql)

            if isinstance(result, SQL):
                final_expr = result._statement
                # Keep any parameters the filter added.
                accumulated_positional = list(result._positional_params)
                accumulated_named = dict(result._named_params)
            else:
                final_expr = result

    final_params: Any
    if accumulated_named and not accumulated_positional:
        final_params = dict(accumulated_named)
    elif accumulated_positional and not accumulated_named:
        final_params = list(accumulated_positional)
    elif accumulated_positional and accumulated_named:
        final_params = dict(accumulated_named)
        for i, param in enumerate(accumulated_positional):
            param_name = f"arg_{i}"
            if param_name in final_params:
                # Fall back to an identity-based name, then add a numeric suffix until
                # free; retrying the same fallback would never terminate if it were taken.
                fallback = f"arg_{i}_{id(param)}"
                param_name = fallback
                suffix = 1
                while param_name in final_params:
                    param_name = f"{fallback}_{suffix}"
                    suffix += 1
            final_params[param_name] = param
    else:
        final_params = None

    return final_expr, final_params
|
|
887
|
-
|
|
888
|
-
@property
def sql(self) -> str:
    """Get the SQL string.

    Returns:
        - ``""`` when the raw SQL is missing or whitespace-only;
        - the raw SQL for scripts or when parsing is disabled;
        - otherwise the processed SQL.

    Raises:
        RuntimeError: If processing did not produce a processed state.
    """
    raw = self._raw_sql
    if not raw or not raw.strip():
        return ""

    if self._is_script or not self._config.enable_parsing:
        return raw

    self._ensure_processed()
    state = self._processed_state
    if state is None:
        msg = "Failed to process SQL statement"
        raise RuntimeError(msg)
    return state.processed_sql
|
|
904
|
-
|
|
905
|
-
@property
def config(self) -> "SQLConfig":
    """The ``SQLConfig`` this statement was created with."""
    current_config = self._config
    return current_config
|
|
909
|
-
|
|
910
|
-
@property
def expression(self) -> "Optional[exp.Expression]":
    """The processed sqlglot expression, or ``None`` when parsing is disabled.

    Raises:
        RuntimeError: If processing did not produce a processed state.
    """
    if self._config.enable_parsing:
        self._ensure_processed()
        state = self._processed_state
        if state is None:
            msg = "Failed to process SQL statement"
            raise RuntimeError(msg)
        return state.processed_expression
    return None
|
|
920
|
-
|
|
921
|
-
@property
|
|
922
|
-
def parameters(self) -> Any:
|
|
923
|
-
"""Get merged parameters."""
|
|
924
|
-
if self._is_many and self._original_parameters is not None:
|
|
925
|
-
return self._original_parameters
|
|
926
|
-
|
|
927
|
-
if (
|
|
928
|
-
self._original_parameters is not None
|
|
929
|
-
and isinstance(self._original_parameters, tuple)
|
|
930
|
-
and not self._named_params
|
|
931
|
-
):
|
|
932
|
-
return self._original_parameters
|
|
933
|
-
|
|
934
|
-
self._ensure_processed()
|
|
935
|
-
if self._processed_state is None:
|
|
936
|
-
msg = "Failed to process SQL statement"
|
|
937
|
-
raise RuntimeError(msg)
|
|
938
|
-
params = self._processed_state.merged_parameters
|
|
939
|
-
if params is None:
|
|
940
|
-
return {}
|
|
941
|
-
return params
|
|
942
|
-
|
|
943
|
-
@property
def is_many(self) -> bool:
    """Whether this statement is flagged for ``executemany``."""
    many_flag = self._is_many
    return many_flag
|
|
947
|
-
|
|
948
|
-
@property
def is_script(self) -> bool:
    """Whether this statement is flagged for script execution."""
    script_flag = self._is_script
    return script_flag
|
|
952
|
-
|
|
953
|
-
@property
def dialect(self) -> "Optional[DialectType]":
    """The SQL dialect stored on this statement (may be ``None``)."""
    current_dialect = self._dialect
    return current_dialect
|
|
957
|
-
|
|
958
|
-
def to_sql(self, placeholder_style: "Optional[str]" = None) -> "str":
    """Return the SQL text, converted to ``placeholder_style`` when one is given.

    Scripts bypass compilation and return ``self.sql`` directly.
    """
    if not self._is_script:
        compiled_sql, _params = self.compile(placeholder_style=placeholder_style)
        return compiled_sql
    return self.sql
|
|
964
|
-
|
|
965
|
-
def get_parameters(self, style: "Optional[str]" = None) -> Any:
|
|
966
|
-
"""Get parameters in the requested style."""
|
|
967
|
-
_, params = self.compile(placeholder_style=style)
|
|
968
|
-
return params
|
|
969
|
-
|
|
970
|
-
def _compile_execute_many(self, placeholder_style: "Optional[str]") -> "tuple[str, Any]":
|
|
971
|
-
"""Compile for execute_many operations.
|
|
972
|
-
|
|
973
|
-
The pipeline processed the first parameter set to extract literals.
|
|
974
|
-
Now we need to apply those extracted literals to all parameter sets.
|
|
975
|
-
"""
|
|
976
|
-
sql = self.sql
|
|
977
|
-
self._ensure_processed()
|
|
978
|
-
|
|
979
|
-
# Get the original parameter sets
|
|
980
|
-
param_sets = self._original_parameters or []
|
|
981
|
-
|
|
982
|
-
# Get any literals extracted during pipeline processing
|
|
983
|
-
if self._processed_state and self._processing_context:
|
|
984
|
-
extracted_literals = self._processing_context.extracted_parameters_from_pipeline
|
|
985
|
-
|
|
986
|
-
if extracted_literals:
|
|
987
|
-
# Apply extracted literals to each parameter set
|
|
988
|
-
enhanced_params: list[Any] = []
|
|
989
|
-
for param_set in param_sets:
|
|
990
|
-
if isinstance(param_set, (list, tuple)):
|
|
991
|
-
# Add extracted literals to the parameter tuple
|
|
992
|
-
enhanced_set = list(param_set) + [
|
|
993
|
-
p.value if hasattr(p, "value") else p for p in extracted_literals
|
|
994
|
-
]
|
|
995
|
-
enhanced_params.append(tuple(enhanced_set))
|
|
996
|
-
elif isinstance(param_set, dict):
|
|
997
|
-
# For dict params, add extracted literals with generated names
|
|
998
|
-
enhanced_dict = dict(param_set)
|
|
999
|
-
for i, literal in enumerate(extracted_literals):
|
|
1000
|
-
param_name = f"_literal_{i}"
|
|
1001
|
-
enhanced_dict[param_name] = literal.value if hasattr(literal, "value") else literal
|
|
1002
|
-
enhanced_params.append(enhanced_dict)
|
|
1003
|
-
else:
|
|
1004
|
-
# Single parameter - convert to tuple with literals
|
|
1005
|
-
literals = [p.value if hasattr(p, "value") else p for p in extracted_literals]
|
|
1006
|
-
enhanced_params.append((param_set, *literals))
|
|
1007
|
-
param_sets = enhanced_params
|
|
1008
|
-
|
|
1009
|
-
if placeholder_style:
|
|
1010
|
-
sql, param_sets = self._convert_placeholder_style(sql, param_sets, placeholder_style)
|
|
1011
|
-
|
|
1012
|
-
return sql, param_sets
|
|
1013
|
-
|
|
1014
|
-
def _get_extracted_parameters(self) -> "list[Any]":
|
|
1015
|
-
"""Get extracted parameters from pipeline processing."""
|
|
1016
|
-
extracted_params = []
|
|
1017
|
-
if self._processed_state and self._processed_state.merged_parameters:
|
|
1018
|
-
merged = self._processed_state.merged_parameters
|
|
1019
|
-
if isinstance(merged, list):
|
|
1020
|
-
if merged and not isinstance(merged[0], (tuple, list)):
|
|
1021
|
-
extracted_params = merged
|
|
1022
|
-
elif self._processing_context and self._processing_context.extracted_parameters_from_pipeline:
|
|
1023
|
-
extracted_params = self._processing_context.extracted_parameters_from_pipeline
|
|
1024
|
-
return extracted_params
|
|
1025
|
-
|
|
1026
|
-
def _merge_extracted_params_with_sets(self, params: Any, extracted_params: "list[Any]") -> "list[tuple[Any, ...]]":
    """Append the extracted parameter values to every parameter set.

    Each extracted entry contributes ``entry.value`` when ``has_parameter_value``
    accepts it, otherwise the entry itself. Sequence sets are extended; any other set
    is treated as a single leading value. Every resulting set is a tuple.
    """
    merged_sets = []
    for param_set in params:
        extra_values = [item.value if has_parameter_value(item) else item for item in extracted_params]
        head = tuple(param_set) if isinstance(param_set, (list, tuple)) else (param_set,)
        merged_sets.append((*head, *extra_values))
    return merged_sets
|
|
1048
|
-
|
|
1049
|
-
def compile(self, placeholder_style: "Optional[str]" = None) -> "tuple[str, Any]":
    """Compile to ``(sql, parameters)``.

    Special cases, checked in order:

    - Scripts return ``(self.sql, None)``.
    - ``executemany`` statements delegate to ``_compile_execute_many``.
    - With parsing disabled, the raw SQL and raw parameters are returned.

    Otherwise the processed SQL and parameters are used:

    1. Parameters are reordered when the context has a position mapping.
    2. Placeholders are converted back to the original sqlglot-incompatible style
       when no style (or that same style) is requested.
    3. Typed parameters are unwrapped.
    4. ``placeholder_style`` is applied when it is truthy.

    Raises:
        RuntimeError: If processing did not produce a processed state.
    """
    if self._is_script:
        return self.sql, None

    if self._is_many and self._original_parameters is not None:
        return self._compile_execute_many(placeholder_style)

    if not self._config.enable_parsing and self._raw_sql:
        return self._raw_sql, self._raw_parameters

    self._ensure_processed()
    state = self._processed_state
    if state is None:
        msg = "Failed to process SQL statement"
        raise RuntimeError(msg)

    sql = state.processed_sql
    params = state.merged_parameters
    context = self._processing_context

    if params is not None and context:
        position_map = context.metadata.get("parameter_position_mapping")
        if position_map:
            params = self._reorder_parameters(params, position_map)

    conversion = context.parameter_conversion if context else None
    if conversion and conversion.was_transformed and conversion.original_styles:
        original_style = conversion.original_styles[0]
        wants_original = placeholder_style is None or (
            placeholder_style and ParameterStyle(placeholder_style) == original_style
        )
        if wants_original and original_style in SQLGLOT_INCOMPATIBLE_STYLES:
            sql = self._config.parameter_converter._convert_sql_placeholders(
                sql, conversion.original_param_info, original_style
            )
            if original_style == ParameterStyle.POSITIONAL_COLON and is_dict(params):
                params = self._denormalize_colon_params(params)

    params = self._unwrap_typed_parameters(params)

    if not placeholder_style:
        return sql, params

    sql, params = self._apply_placeholder_style(sql, params, placeholder_style)
    return sql, params
|
|
1103
|
-
|
|
1104
|
-
def _apply_placeholder_style(self, sql: "str", params: Any, placeholder_style: "str") -> "tuple[str, Any]":
|
|
1105
|
-
"""Apply placeholder style conversion to SQL and parameters."""
|
|
1106
|
-
# Just use the params passed in - they've already been processed
|
|
1107
|
-
sql, params = self._convert_placeholder_style(sql, params, placeholder_style)
|
|
1108
|
-
return sql, params
|
|
1109
|
-
|
|
1110
|
-
@staticmethod
def _unwrap_typed_parameters(params: Any) -> Any:
    """Replace TypedParameter-like objects with their underlying values.

    Args:
        params: ``None``, a dict, a list/tuple, or a single value. Items accepted by
            ``has_parameter_value`` are replaced by their ``.value``.

    Returns:
        ``None`` for ``None``. Otherwise a new dict, a new sequence of the same type,
        or the unwrapped single value.
    """

    def _plain(value: Any) -> Any:
        return value.value if has_parameter_value(value) else value

    if params is None:
        return None
    if is_dict(params):
        return {key: _plain(value) for key, value in params.items()}
    if isinstance(params, (list, tuple)):
        return type(params)([_plain(value) for value in params])
    return _plain(params)
|
|
1145
|
-
|
|
1146
|
-
@staticmethod
|
|
1147
|
-
def _reorder_parameters(params: Any, mapping: dict[int, int]) -> Any:
|
|
1148
|
-
"""Reorder parameters based on the position mapping.
|
|
1149
|
-
|
|
1150
|
-
Args:
|
|
1151
|
-
params: Original parameters (list, tuple, or dict)
|
|
1152
|
-
mapping: Dict mapping new positions to original positions
|
|
1153
|
-
|
|
1154
|
-
Returns:
|
|
1155
|
-
Reordered parameters in the same format as input
|
|
1156
|
-
"""
|
|
1157
|
-
if isinstance(params, (list, tuple)):
|
|
1158
|
-
reordered_list = [None] * len(params) # pyright: ignore
|
|
1159
|
-
for new_pos, old_pos in mapping.items():
|
|
1160
|
-
if old_pos < len(params):
|
|
1161
|
-
reordered_list[new_pos] = params[old_pos] # pyright: ignore
|
|
1162
|
-
|
|
1163
|
-
for i, val in enumerate(reordered_list):
|
|
1164
|
-
if val is None and i < len(params) and i not in mapping:
|
|
1165
|
-
reordered_list[i] = params[i] # pyright: ignore
|
|
1166
|
-
|
|
1167
|
-
return tuple(reordered_list) if isinstance(params, tuple) else reordered_list
|
|
1168
|
-
|
|
1169
|
-
if is_dict(params):
|
|
1170
|
-
if all(key.startswith(PARAM_PREFIX) and key[len(PARAM_PREFIX) :].isdigit() for key in params):
|
|
1171
|
-
reordered_dict: dict[str, Any] = {}
|
|
1172
|
-
for new_pos, old_pos in mapping.items():
|
|
1173
|
-
old_key = f"{PARAM_PREFIX}{old_pos}"
|
|
1174
|
-
new_key = f"{PARAM_PREFIX}{new_pos}"
|
|
1175
|
-
if old_key in params:
|
|
1176
|
-
reordered_dict[new_key] = params[old_key]
|
|
1177
|
-
|
|
1178
|
-
for key, value in params.items():
|
|
1179
|
-
if key not in reordered_dict and key.startswith(PARAM_PREFIX):
|
|
1180
|
-
idx = int(key[6:])
|
|
1181
|
-
if idx not in mapping:
|
|
1182
|
-
reordered_dict[key] = value
|
|
1183
|
-
|
|
1184
|
-
return reordered_dict
|
|
1185
|
-
return params
|
|
1186
|
-
return params
|
|
1187
|
-
|
|
1188
|
-
def _convert_placeholder_style(self, sql: str, params: Any, placeholder_style: str) -> tuple[str, Any]:
    """Convert SQL and parameters to the requested placeholder style.

    For execute_many (a list whose first item is a list/tuple), only the SQL
    placeholders are rewritten and the rows are returned unchanged. For the
    STATIC style the values are embedded directly into the SQL.

    Args:
        sql: The SQL string to convert.
        params: The parameters to convert.
        placeholder_style: Target placeholder style (a ``ParameterStyle`` or its string value).

    Returns:
        Tuple of (converted_sql, converted_params).
    """
    converter = self._config.parameter_converter

    if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
        param_info = converter.validator.extract_parameters(sql)
        if param_info:
            target_style = (
                ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style
            )
            sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
        return sql, params

    target_style = ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style

    if (
        target_style == ParameterStyle.POSITIONAL_COLON
        and self._processing_context
        and self._processing_context.parameter_conversion
        and self._processing_context.parameter_conversion.original_param_info
    ):
        # Prefer the pre-conversion info so original numeric :N identifiers are preserved.
        param_info = self._processing_context.parameter_conversion.original_param_info
        # Transformers (like ParameterizeLiterals) may have added placeholders after
        # that info was captured. If the current SQL has more of them, use the current
        # info so every placeholder gets a parameter.
        current_param_info = converter.validator.extract_parameters(sql)
        if len(current_param_info) > len(param_info):
            param_info = current_param_info
    else:
        # Already extracted from the current SQL, so a second extraction cannot find more.
        param_info = converter.validator.extract_parameters(sql)

    if not param_info:
        return sql, params

    if target_style == ParameterStyle.STATIC:
        return self._embed_static_parameters(sql, params, param_info)

    if all(p.style == target_style for p in param_info):
        # Placeholders already use the target style; only the parameter container needs converting.
        return sql, self._convert_parameters_format(params, param_info, target_style)

    sql = self._replace_placeholders_in_sql(sql, param_info, target_style)
    params = self._convert_parameters_format(params, param_info, target_style)
    return sql, params
|
|
1251
|
-
|
|
1252
|
-
def _embed_static_parameters(self, sql: str, params: Any, param_info: list[Any]) -> tuple[str, Any]:
    """Inline parameter values into the SQL text as literals (STATIC style).

    Used where values cannot travel separately from the statement (for example
    scripts): each placeholder is replaced by a literal rendered for the
    statement's dialect.

    Args:
        sql: SQL text containing placeholders.
        params: Parameter values as a dict, a list/tuple, a single scalar, or None.
        param_info: Placeholder descriptors from extraction; ``position``,
            ``placeholder_text``, ``ordinal`` and ``name`` are read.

    Returns:
        Tuple of (SQL with the values embedded, None).
    """

    def lookup(info: Any) -> Any:
        # Dict parameters: the placeholder's own name wins, then the generated ordinal keys.
        if info.name and info.name in params:
            return params[info.name]
        for candidate in (f"{PARAM_PREFIX}{info.ordinal}", f"arg_{info.ordinal}"):
            if candidate in params:
                return params[candidate]
        return params.get(str(info.ordinal), None)

    def render(value: Any) -> str:
        # Python value -> SQL literal text for the current dialect.
        if value is None:
            return "NULL"
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if isinstance(value, str):
            return sqlglot.exp.Literal.string(value).sql(dialect=self._dialect)
        if isinstance(value, (int, float)):
            return sqlglot.exp.Literal.number(value).sql(dialect=self._dialect)
        return sqlglot.exp.Literal.string(str(value)).sql(dialect=self._dialect)

    if is_dict(params):
        values: list[Any] = [lookup(info) for info in param_info]
    elif isinstance(params, (list, tuple)):
        values = list(params)
    elif params is not None:
        values = [params]
    else:
        values = []

    # Splice from the end of the string backwards so earlier offsets stay valid.
    for info in sorted(param_info, key=lambda item: item.position, reverse=True):
        if info.ordinal >= len(values):
            continue
        value = values[info.ordinal]
        if has_parameter_value(value):
            value = value.value
        literal_text = render(value)
        start = info.position
        sql = sql[:start] + literal_text + sql[start + len(info.placeholder_text) :]

    return sql, None
|
|
1310
|
-
|
|
1311
|
-
def _replace_placeholders_in_sql(self, sql: str, param_info: list[Any], target_style: ParameterStyle) -> str:
    """Rewrite every placeholder in ``sql`` into ``target_style``.

    Args:
        sql: The SQL string.
        param_info: Placeholder descriptors; ``position`` and ``placeholder_text``
            locate each placeholder in ``sql``.
        target_style: Style the placeholders are rewritten to.

    Returns:
        The SQL string with each placeholder replaced.
    """
    rewritten = sql
    # Walk from the last placeholder to the first so splicing never shifts an
    # offset that is still to be processed.
    for info in sorted(param_info, key=lambda item: item.position, reverse=True):
        replacement = self._generate_placeholder(info, target_style)
        head = rewritten[: info.position]
        tail = rewritten[info.position + len(info.placeholder_text) :]
        rewritten = head + replacement + tail
    return rewritten
|
|
1331
|
-
|
|
1332
|
-
@staticmethod
def _generate_placeholder(param: Any, target_style: ParameterStyle) -> str:
    """Render the placeholder text for ``param`` in ``target_style``.

    Args:
        param: Parameter descriptor; ``ordinal``, ``name``, ``style`` and
            ``placeholder_text`` are read as each style needs them.
        target_style: Style to render.

    Returns:
        The placeholder string. STATIC and QMARK both yield ``?``. Any style not
        listed here falls back to the descriptor's original ``placeholder_text``.
    """
    if target_style in {ParameterStyle.STATIC, ParameterStyle.QMARK}:
        return "?"
    if target_style == ParameterStyle.NUMERIC:
        one_based = param.ordinal + 1
        return f"${one_based}"
    if target_style == ParameterStyle.NAMED_COLON:
        name = param.name
        if name and not name.isdigit():
            return f":{name}"
        return f":arg_{param.ordinal}"
    if target_style == ParameterStyle.NAMED_AT:
        label = param.name or f"param_{param.ordinal}"
        return f"@{label}"
    if target_style == ParameterStyle.POSITIONAL_COLON:
        # Keep an already-numbered ":N" placeholder's own number (Oracle style).
        keeps_number = (
            hasattr(param, "style")
            and param.style == ParameterStyle.POSITIONAL_COLON
            and hasattr(param, "name")
            and param.name
            and param.name.isdigit()
        )
        if keeps_number:
            return f":{param.name}"
        return f":{param.ordinal + 1}"
    if target_style == ParameterStyle.POSITIONAL_PYFORMAT:
        return "%s"
    if target_style == ParameterStyle.NAMED_PYFORMAT:
        label = param.name or f"arg_{param.ordinal}"
        return f"%({label})s"
    return str(param.placeholder_text)
|
|
1369
|
-
|
|
1370
|
-
def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style: ParameterStyle) -> Any:
    """Reshape ``params`` into the container the target placeholder style expects.

    Args:
        params: Original parameters.
        param_info: Placeholder descriptors for the SQL.
        target_style: Style the SQL placeholders use.

    Returns:
        The converted parameters. Styles without a converter here (for example
        NAMED_AT or STATIC) get ``params`` back unchanged.
    """
    if target_style == ParameterStyle.POSITIONAL_COLON:
        converted = self._convert_to_positional_colon_format(params, param_info)
    elif target_style in {ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_PYFORMAT}:
        converted = self._convert_to_positional_format(params, param_info)
    elif target_style == ParameterStyle.NAMED_COLON:
        converted = self._convert_to_named_colon_format(params, param_info)
    elif target_style == ParameterStyle.NAMED_PYFORMAT:
        converted = self._convert_to_named_pyformat_format(params, param_info)
    else:
        converted = params
    return converted
|
|
1390
|
-
|
|
1391
|
-
def _convert_list_to_colon_dict(
|
|
1392
|
-
self, params: "Union[list[Any], tuple[Any, ...]]", param_info: "list[Any]"
|
|
1393
|
-
) -> "dict[str, Any]":
|
|
1394
|
-
"""Convert list/tuple parameters to colon-style dict format."""
|
|
1395
|
-
result_dict: dict[str, Any] = {}
|
|
1396
|
-
|
|
1397
|
-
if param_info:
|
|
1398
|
-
all_numeric = all(p.name and p.name.isdigit() for p in param_info)
|
|
1399
|
-
if all_numeric:
|
|
1400
|
-
for i, value in enumerate(params):
|
|
1401
|
-
result_dict[str(i + 1)] = value
|
|
1402
|
-
else:
|
|
1403
|
-
for i, value in enumerate(params):
|
|
1404
|
-
if i < len(param_info):
|
|
1405
|
-
param_name = param_info[i].name or str(i + 1)
|
|
1406
|
-
result_dict[param_name] = value
|
|
1407
|
-
else:
|
|
1408
|
-
result_dict[str(i + 1)] = value
|
|
1409
|
-
else:
|
|
1410
|
-
for i, value in enumerate(params):
|
|
1411
|
-
result_dict[str(i + 1)] = value
|
|
1412
|
-
|
|
1413
|
-
return result_dict
|
|
1414
|
-
|
|
1415
|
-
def _convert_single_value_to_colon_dict(self, params: Any, param_info: "list[Any]") -> "dict[str, Any]":
|
|
1416
|
-
"""Convert single value parameter to colon-style dict format."""
|
|
1417
|
-
result_dict: dict[str, Any] = {}
|
|
1418
|
-
if param_info and param_info[0].name and param_info[0].name.isdigit():
|
|
1419
|
-
result_dict[param_info[0].name] = params
|
|
1420
|
-
else:
|
|
1421
|
-
result_dict["1"] = params
|
|
1422
|
-
return result_dict
|
|
1423
|
-
|
|
1424
|
-
def _process_mixed_colon_params(self, params: "dict[str, Any]", param_info: "list[Any]") -> "dict[str, Any]":
    """Process mixed colon-style numeric and converted parameters.

    Splits ``params`` into two groups:
      - "extracted" values: typed wrappers (whatever their key), ``param_<N>``
        keys, and any other non-numeric keys;
      - user numeric keys ("1", "2", ...).
    The groups are concatenated as extracted-then-user. The combined values are
    then assigned in order to the placeholders sorted by ordinal, keyed
    ``str(ordinal + 1)``. Typed wrappers are unwrapped to their ``.value``.

    NOTE(review): values are assigned purely by list order, not by where each
    placeholder appears in the SQL. This is only correct if the extracted
    literals precede the user binds in ordinal order -- confirm with callers.

    Args:
        params: Dict mixing numeric keys, ``param_<N>`` keys and/or typed values.
        param_info: Placeholder descriptors; ``ordinal`` is read.

    Returns:
        Dict keyed "1", "2", ... Placeholders beyond the number of available
        values are left out.
    """
    result_dict: dict[str, Any] = {}

    # When we have mixed parameters (extracted literals + user oracle params),
    # we need to be careful about the ordering. The extracted literals should
    # fill positions based on where they appear in the SQL, not based on
    # matching parameter names.

    # Separate extracted parameters and user oracle parameters
    extracted_params: list[tuple[str, Any]] = []
    user_oracle_params: dict[str, Any] = {}
    # (index N, original key, value) for "param_<N>" entries; sorted by N below.
    extracted_keys_sorted: list[tuple[int, str, Any]] = []

    for key, value in params.items():
        if has_parameter_value(value):
            extracted_params.append((key, value))
        elif key.isdigit():
            user_oracle_params[key] = value
        elif key.startswith("param_") and key[6:].isdigit():  # 6 == len("param_")
            param_idx = int(key[6:])
            oracle_key = str(param_idx + 1)
            # NOTE(review): this dedup only sees numeric keys iterated *before*
            # this entry, so the result depends on dict insertion order. A
            # "param_N" whose matching numeric key was already seen is dropped.
            if oracle_key not in user_oracle_params:
                extracted_keys_sorted.append((param_idx, key, value))
        else:
            extracted_params.append((key, value))

    # param_<N> values go after typed/other extracted values, in ascending N.
    extracted_keys_sorted.sort(key=operator.itemgetter(0))
    for _, key, value in extracted_keys_sorted:
        extracted_params.append((key, value))

    # Build lists of parameter values in order
    extracted_values: list[Any] = []
    for _, value in extracted_params:
        if has_parameter_value(value):
            extracted_values.append(value.value)
        else:
            extracted_values.append(value)

    # User numeric keys are ordered numerically ("10" after "2").
    user_values = [user_oracle_params[key] for key in sorted(user_oracle_params.keys(), key=int)]

    # Now assign parameters based on position
    # Extracted parameters go first (they were literals in original positions)
    # User parameters follow
    all_values = extracted_values + user_values

    for i, p in enumerate(sorted(param_info, key=lambda x: x.ordinal)):
        oracle_key = str(p.ordinal + 1)
        if i < len(all_values):
            result_dict[oracle_key] = all_values[i]

    return result_dict
|
|
1476
|
-
|
|
1477
|
-
def _convert_to_positional_colon_format(self, params: Any, param_info: list[Any]) -> Any:
    """Convert to dict format for positional colon style.

    Positional colon style uses :1, :2, etc. placeholders and expects
    parameters as a dict with string keys "1", "2", etc.

    For execute_many operations (a list of row sequences), the list is returned unchanged.

    Args:
        params: Original parameters.
        param_info: List of parameter information.

    Returns:
        Dict of parameters with string keys "1", "2", etc. Returns the
        execute_many list unchanged, or ``params`` itself when it is neither a
        dict nor a list/tuple and there is no parameter info.
    """
    if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
        return params

    if isinstance(params, (list, tuple)):
        return self._convert_list_to_colon_dict(params, param_info)

    if not is_dict(params):
        if param_info:
            return self._convert_single_value_to_colon_dict(params, param_info)
        return params

    # Keys are already Oracle-style numbers: pass through.
    if all(key.isdigit() for key in params):
        return params

    if all(key.startswith("param_") for key in params):
        param_result_dict: dict[str, Any] = {}
        for p in sorted(param_info, key=lambda x: x.ordinal):
            converted_key = f"param_{p.ordinal}"
            if converted_key in params:
                # Preserve an original numeric name (":3" stays "3"); otherwise number sequentially.
                out_key = p.name if p.name and p.name.isdigit() else str(p.ordinal + 1)
                param_result_dict[out_key] = params[converted_key]
        return param_result_dict

    has_oracle_numeric = any(key.isdigit() for key in params)
    has_param_converted = any(key.startswith("param_") for key in params)
    has_typed_params = any(has_parameter_value(v) for v in params.values())
    if (has_oracle_numeric and has_param_converted) or has_typed_params:
        return self._process_mixed_colon_params(params, param_info)

    # Sentinel for "no value supplied". It cannot be None: None is a legitimate
    # bind value (SQL NULL) and must still produce a key in the result.
    missing = object()
    result_dict: dict[str, Any] = {}
    for p in sorted(param_info, key=lambda x: x.ordinal):
        value = missing
        if p.name and p.name in params:
            value = params[p.name]
        else:
            # Fall back to ordinal-based keys: pipeline "param_N", "arg_N", then a
            # numeric key matching the 1-based position.
            for key in (f"param_{p.ordinal}", f"arg_{p.ordinal}", str(p.ordinal + 1)):
                if key in params:
                    value = params[key]
                    break
        if value is missing:
            continue
        if has_parameter_value(value):
            value = value.value
        result_dict[str(p.ordinal + 1)] = value
    return result_dict
|
|
1565
|
-
|
|
1566
|
-
@staticmethod
def _convert_to_positional_format(params: Any, param_info: list[Any]) -> Any:
    """Convert to list format for positional parameter styles.

    Dict parameters are matched to placeholders in three passes:
      1. by placeholder name;
      2. for unnamed placeholders, by the generated ``arg_<ordinal>`` /
         ``param_<ordinal>`` keys;
      3. leftover entries (not a placeholder name, not ``arg_``/``param_``
         prefixed) fill the still-unmatched placeholders in dict order.
    List/tuple parameters are passed through with typed wrappers unwrapped.
    When any placeholder is NUMERIC style, list values are instead assigned in
    order of appearance in the SQL and emitted ordered by ordinal.

    Args:
        params: Original parameters.
        param_info: List of parameter information.

    Returns:
        List of parameters. An empty list/tuple is returned as-is, as is any
        value that is neither a dict nor a list/tuple.
    """

    def unwrap(value: Any) -> Any:
        # Typed wrappers carry the real bind value in ``.value``; None passes through.
        if value is not None and has_parameter_value(value):
            return value.value
        return value

    if is_dict(params):
        values_by_ordinal: dict[int, Any] = {}

        for p in param_info:
            if p.name and p.name in params:
                values_by_ordinal[p.ordinal] = params[p.name]

        for p in param_info:
            if p.name is None and p.ordinal not in values_by_ordinal:
                arg_key = f"arg_{p.ordinal}"
                param_key = f"param_{p.ordinal}"
                if arg_key in params:
                    values_by_ordinal[p.ordinal] = params[arg_key]
                elif param_key in params:
                    values_by_ordinal[p.ordinal] = params[param_key]

        # Built once, outside the comprehension, so each membership test is O(1)
        # instead of rebuilding the set for every dict entry.
        placeholder_names = {p.name for p in param_info if p.name}
        remaining_values = [
            value
            for key, value in params.items()
            if key not in placeholder_names and not key.startswith(("arg_", "param_"))
        ]
        unmatched_ordinals = [p.ordinal for p in param_info if p.ordinal not in values_by_ordinal]
        for ordinal, value in zip(unmatched_ordinals, remaining_values):
            values_by_ordinal[ordinal] = value

        return [unwrap(values_by_ordinal.get(p.ordinal)) for p in param_info]

    if isinstance(params, (list, tuple)):
        # Preserve empty params instead of inventing None values; this matters for
        # execute_many with empty parameter lists.
        if not params:
            return params

        if param_info and any(p.style == ParameterStyle.NUMERIC for p in param_info):
            # Mixed styles: assign values by order of appearance, not by numeric reference.
            by_ordinal: dict[int, Any] = {}
            appearance_order = sorted(param_info, key=lambda p: p.position)
            for p, value in zip(appearance_order, params):
                by_ordinal[p.ordinal] = value
            # NOTE(review): assumes ordinals run 0..len(param_info)-1.
            return [unwrap(by_ordinal.get(i)) for i in range(len(param_info))]

        return [unwrap(value) for value in params]

    return params
|
|
1657
|
-
|
|
1658
|
-
@staticmethod
def _convert_to_named_colon_format(params: Any, param_info: list[Any]) -> Any:
    """Convert to dict format for named colon style.

    Each placeholder is bound under its own name, or ``arg_<ordinal>`` when it
    has none.

    Args:
        params: Original parameters.
        param_info: List of parameter information.

    Returns:
        The input dict unchanged when it already has a value for every
        placeholder's key; otherwise a new dict with the generated names. A
        list/tuple is keyed by placeholder name or ``arg_<index>``. Any other
        value is returned unchanged.
    """
    result_dict: dict[str, Any] = {}
    if is_dict(params):
        target_keys = [p.name or f"arg_{p.ordinal}" for p in param_info]
        # Unnamed placeholders must be checked too. Checking only named ones is
        # vacuously true when none are named, which would pass through a dict
        # keyed "param_N" even though nothing is bound under "arg_N".
        if all(key in params for key in target_keys):
            return params
        for p, key in zip(param_info, target_keys):
            converted_key = f"param_{p.ordinal}"
            if p.name and p.name in params:
                result_dict[p.name] = params[p.name]
            elif converted_key in params:
                result_dict[key] = params[converted_key]
            elif key in params:
                # An unnamed placeholder already supplied under its arg_N key.
                result_dict[key] = params[key]
        return result_dict
    if isinstance(params, (list, tuple)):
        for i, value in enumerate(params):
            if has_parameter_value(value):
                value = value.value
            name = param_info[i].name if i < len(param_info) else None
            result_dict[name or f"arg_{i}"] = value
        return result_dict
    return params
|
|
1693
|
-
|
|
1694
|
-
@staticmethod
|
|
1695
|
-
def _convert_to_named_pyformat_format(params: Any, param_info: list[Any]) -> Any:
|
|
1696
|
-
"""Convert to dict format for named pyformat style.
|
|
1697
|
-
|
|
1698
|
-
Args:
|
|
1699
|
-
params: Original parameters
|
|
1700
|
-
param_info: List of parameter information
|
|
1701
|
-
|
|
1702
|
-
Returns:
|
|
1703
|
-
Dict of parameters with names
|
|
1704
|
-
"""
|
|
1705
|
-
if isinstance(params, (list, tuple)):
|
|
1706
|
-
result_dict: dict[str, Any] = {}
|
|
1707
|
-
for i, p in enumerate(param_info):
|
|
1708
|
-
if i < len(params):
|
|
1709
|
-
param_name = p.name or f"param_{i}"
|
|
1710
|
-
result_dict[param_name] = params[i]
|
|
1711
|
-
return result_dict
|
|
1712
|
-
return params
|
|
1713
|
-
|
|
1714
|
-
@property
def validation_errors(self) -> list[Any]:
    """Return the validation errors recorded while processing this statement.

    Returns an empty list, without processing, when validation is disabled in
    the config.

    Raises:
        RuntimeError: If processing leaves no processed state.
    """
    if self._config.enable_validation:
        self._ensure_processed()
        state = self._processed_state
        if state:
            return state.validation_errors
        msg = "Failed to process SQL statement"
        raise RuntimeError(msg)
    return []
|
|
1724
|
-
|
|
1725
|
-
@property
def has_errors(self) -> bool:
    """Whether validation reported at least one error."""
    errors = self.validation_errors
    return bool(errors)
|
|
1729
|
-
|
|
1730
|
-
@property
def is_safe(self) -> bool:
    """Whether the statement has no validation errors."""
    errors_found = self.has_errors
    return not errors_found
|
|
1734
|
-
|
|
1735
|
-
def validate(self) -> list[Any]:
    """Validate the SQL statement and return its validation errors.

    Equivalent to reading the ``validation_errors`` property.
    """
    errors = self.validation_errors
    return errors
|
|
1738
|
-
|
|
1739
|
-
@property
def parameter_info(self) -> list[Any]:
    """Get parameter information from the SQL statement.

    When raw SQL text is available, it is parsed with the configured parameter
    validator, which yields the placeholders as originally written (before any
    conversion). Otherwise the statement is processed and the processing
    context's parameter info is returned, or an empty list if there is no
    context.
    """
    validator = self._config.parameter_validator
    raw_sql = self._raw_sql
    if raw_sql:
        return validator.extract_parameters(raw_sql)

    self._ensure_processed()
    context = self._processing_context
    return context.parameter_info if context else []
|
|
1755
|
-
|
|
1756
|
-
@property
|
|
1757
|
-
def _raw_parameters(self) -> Any:
|
|
1758
|
-
"""Get raw parameters for compatibility."""
|
|
1759
|
-
return self._original_parameters
|
|
1760
|
-
|
|
1761
|
-
@property
|
|
1762
|
-
def _sql(self) -> str:
|
|
1763
|
-
"""Get SQL string for compatibility."""
|
|
1764
|
-
return self.sql
|
|
1765
|
-
|
|
1766
|
-
@property
|
|
1767
|
-
def _expression(self) -> "Optional[exp.Expression]":
|
|
1768
|
-
"""Get expression for compatibility."""
|
|
1769
|
-
return self.expression
|
|
1770
|
-
|
|
1771
|
-
@property
def statement(self) -> exp.Expression:
    """Compatibility accessor returning ``self._statement``."""
    stmt = self._statement
    return stmt
|