sqlspec 0.14.1__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sqlspec has been flagged as potentially problematic. Refer to the original registry listing for more details.
- sqlspec/__init__.py +50 -25
- sqlspec/__main__.py +1 -1
- sqlspec/__metadata__.py +1 -3
- sqlspec/_serialization.py +1 -2
- sqlspec/_sql.py +480 -121
- sqlspec/_typing.py +278 -142
- sqlspec/adapters/adbc/__init__.py +4 -3
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +115 -260
- sqlspec/adapters/adbc/driver.py +462 -367
- sqlspec/adapters/aiosqlite/__init__.py +18 -3
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +199 -129
- sqlspec/adapters/aiosqlite/driver.py +230 -269
- sqlspec/adapters/asyncmy/__init__.py +18 -3
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +80 -168
- sqlspec/adapters/asyncmy/driver.py +260 -225
- sqlspec/adapters/asyncpg/__init__.py +19 -4
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +82 -181
- sqlspec/adapters/asyncpg/driver.py +285 -383
- sqlspec/adapters/bigquery/__init__.py +17 -3
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +191 -258
- sqlspec/adapters/bigquery/driver.py +474 -646
- sqlspec/adapters/duckdb/__init__.py +14 -3
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +415 -351
- sqlspec/adapters/duckdb/driver.py +343 -413
- sqlspec/adapters/oracledb/__init__.py +19 -5
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +123 -379
- sqlspec/adapters/oracledb/driver.py +507 -560
- sqlspec/adapters/psqlpy/__init__.py +13 -3
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +93 -254
- sqlspec/adapters/psqlpy/driver.py +505 -234
- sqlspec/adapters/psycopg/__init__.py +19 -5
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +143 -403
- sqlspec/adapters/psycopg/driver.py +706 -872
- sqlspec/adapters/sqlite/__init__.py +14 -3
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +202 -118
- sqlspec/adapters/sqlite/driver.py +264 -303
- sqlspec/base.py +105 -9
- sqlspec/{statement/builder → builder}/__init__.py +12 -14
- sqlspec/{statement/builder → builder}/_base.py +120 -55
- sqlspec/{statement/builder → builder}/_column.py +17 -6
- sqlspec/{statement/builder → builder}/_ddl.py +46 -79
- sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
- sqlspec/{statement/builder → builder}/_delete.py +6 -25
- sqlspec/{statement/builder → builder}/_insert.py +18 -65
- sqlspec/builder/_merge.py +56 -0
- sqlspec/{statement/builder → builder}/_parsing_utils.py +8 -11
- sqlspec/{statement/builder → builder}/_select.py +11 -56
- sqlspec/{statement/builder → builder}/_update.py +12 -18
- sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
- sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
- sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +34 -18
- sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
- sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +19 -9
- sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
- sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
- sqlspec/{statement/builder → builder}/mixins/_select_operations.py +25 -38
- sqlspec/{statement/builder → builder}/mixins/_update_operations.py +15 -16
- sqlspec/{statement/builder → builder}/mixins/_where_clause.py +210 -137
- sqlspec/cli.py +4 -5
- sqlspec/config.py +180 -133
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.py +873 -0
- sqlspec/core/compiler.py +396 -0
- sqlspec/core/filters.py +830 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.py +1209 -0
- sqlspec/core/result.py +664 -0
- sqlspec/{statement → core}/splitter.py +321 -191
- sqlspec/core/statement.py +666 -0
- sqlspec/driver/__init__.py +7 -10
- sqlspec/driver/_async.py +387 -176
- sqlspec/driver/_common.py +527 -289
- sqlspec/driver/_sync.py +390 -172
- sqlspec/driver/mixins/__init__.py +2 -19
- sqlspec/driver/mixins/_result_tools.py +164 -0
- sqlspec/driver/mixins/_sql_translator.py +6 -3
- sqlspec/exceptions.py +5 -252
- sqlspec/extensions/aiosql/adapter.py +93 -96
- sqlspec/extensions/litestar/cli.py +1 -1
- sqlspec/extensions/litestar/config.py +0 -1
- sqlspec/extensions/litestar/handlers.py +15 -26
- sqlspec/extensions/litestar/plugin.py +18 -16
- sqlspec/extensions/litestar/providers.py +17 -52
- sqlspec/loader.py +424 -105
- sqlspec/migrations/__init__.py +12 -0
- sqlspec/migrations/base.py +92 -68
- sqlspec/migrations/commands.py +24 -106
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +49 -51
- sqlspec/migrations/tracker.py +31 -44
- sqlspec/migrations/utils.py +64 -24
- sqlspec/protocols.py +7 -183
- sqlspec/storage/__init__.py +1 -1
- sqlspec/storage/backends/base.py +37 -40
- sqlspec/storage/backends/fsspec.py +136 -112
- sqlspec/storage/backends/obstore.py +138 -160
- sqlspec/storage/capabilities.py +5 -4
- sqlspec/storage/registry.py +57 -106
- sqlspec/typing.py +136 -115
- sqlspec/utils/__init__.py +2 -3
- sqlspec/utils/correlation.py +0 -3
- sqlspec/utils/deprecation.py +6 -6
- sqlspec/utils/fixtures.py +6 -6
- sqlspec/utils/logging.py +0 -2
- sqlspec/utils/module_loader.py +7 -12
- sqlspec/utils/singleton.py +0 -1
- sqlspec/utils/sync_tools.py +17 -38
- sqlspec/utils/text.py +12 -51
- sqlspec/utils/type_guards.py +443 -232
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/METADATA +7 -2
- sqlspec-0.16.0.dist-info/RECORD +134 -0
- sqlspec/adapters/adbc/transformers.py +0 -108
- sqlspec/driver/connection.py +0 -207
- sqlspec/driver/mixins/_cache.py +0 -114
- sqlspec/driver/mixins/_csv_writer.py +0 -91
- sqlspec/driver/mixins/_pipeline.py +0 -508
- sqlspec/driver/mixins/_query_tools.py +0 -796
- sqlspec/driver/mixins/_result_utils.py +0 -138
- sqlspec/driver/mixins/_storage.py +0 -912
- sqlspec/driver/mixins/_type_coercion.py +0 -128
- sqlspec/driver/parameters.py +0 -138
- sqlspec/statement/__init__.py +0 -21
- sqlspec/statement/builder/_merge.py +0 -95
- sqlspec/statement/cache.py +0 -50
- sqlspec/statement/filters.py +0 -625
- sqlspec/statement/parameters.py +0 -956
- sqlspec/statement/pipelines/__init__.py +0 -210
- sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
- sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
- sqlspec/statement/pipelines/context.py +0 -109
- sqlspec/statement/pipelines/transformers/__init__.py +0 -7
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
- sqlspec/statement/pipelines/validators/__init__.py +0 -23
- sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
- sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
- sqlspec/statement/pipelines/validators/_performance.py +0 -714
- sqlspec/statement/pipelines/validators/_security.py +0 -967
- sqlspec/statement/result.py +0 -435
- sqlspec/statement/sql.py +0 -1774
- sqlspec/utils/cached_property.py +0 -25
- sqlspec/utils/statement_hashing.py +0 -203
- sqlspec-0.14.1.dist-info/RECORD +0 -145
- /sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,1247 +0,0 @@
|
|
|
1
|
-
"""Replaces literals in SQL with placeholders and extracts them using SQLGlot AST."""
|
|
2
|
-
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Optional, Union
|
|
5
|
-
|
|
6
|
-
from sqlglot import exp
|
|
7
|
-
from sqlglot.expressions import Array, Binary, Boolean, DataType, Func, Literal, Null
|
|
8
|
-
|
|
9
|
-
from sqlspec.protocols import ProcessorProtocol
|
|
10
|
-
from sqlspec.statement.parameters import ParameterStyle, TypedParameter
|
|
11
|
-
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
12
|
-
|
|
13
|
-
# Public names exported by ``from <this module> import *``.
__all__ = ("ParameterizationContext", "ParameterizeLiterals")
|
|
14
|
-
|
|
15
|
-
# Thresholds and limits consulted when deciding which literals to parameterize.
MAX_DECIMAL_PRECISION = 6  # NOTE(review): not referenced in the visible code; presumably a decimal-digit cut-off — confirm.
MAX_INT32_VALUE = 2**31 - 1  # Largest signed 32-bit integer, i.e. 2147483647.

DEFAULT_MAX_STRING_LENGTH = 1_000
"""Default maximum string length for literal parameterization."""

DEFAULT_MAX_ARRAY_LENGTH = 100
"""Default maximum array length for literal parameterization."""

DEFAULT_MAX_IN_LIST_SIZE = 50
"""Default maximum IN clause list size before parameterization."""

# Length bounds for strings that look like enum values.
MAX_ENUM_LENGTH = 50
"""Maximum length for enum-like string values."""

MIN_ENUM_LENGTH = 2
"""Minimum length for enum-like string values to be meaningful."""
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
@dataclass
class ParameterizationContext:
    """Context for tracking parameterization state during AST traversal.

    ``ParameterizeLiterals.process`` creates one instance per call (with an
    empty ``parent_stack``), and the enter/leave hooks of the transformer
    mutate it in place while the expression tree is walked.
    """

    # Nodes currently open in the traversal, outermost first: a node is pushed
    # when it is entered and popped when it is left.
    parent_stack: list[exp.Expression]
    # Set when a function listed in ``preserve_in_functions`` is entered.
    in_function_args: bool = False
    # Intended to mark traversal inside a CASE expression.
    in_case_when: bool = False
    # Intended to mark traversal inside an ARRAY expression.
    in_array: bool = False
    # Set while inside an IN predicate.
    in_in_clause: bool = False
    # Set inside a CTE of a RECURSIVE WITH whose body contains a UNION; cleared
    # when the CTE nesting depth returns to zero.
    in_recursive_cte: bool = False
    # Set when subquery_depth is incremented; cleared when it returns to zero.
    in_subquery: bool = False
    # Reset to False by the SELECT enter/leave hooks; per-expression detection is
    # presumably done elsewhere (the hooks mention ``_is_in_select_expressions``).
    in_select_list: bool = False
    # Set while inside a JOIN clause.
    in_join_condition: bool = False
    # Number of enclosing function (``Func``) nodes.
    function_depth: int = 0
    # Number of enclosing CTE nodes.
    cte_depth: int = 0
    # Nesting level of subqueries and nested SELECT statements.
    subquery_depth: int = 0
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
class ParameterizeLiterals(ProcessorProtocol):
|
|
53
|
-
"""Advanced literal parameterization using SQLGlot AST analysis.
|
|
54
|
-
|
|
55
|
-
This enhanced version provides:
|
|
56
|
-
- Context-aware parameterization based on AST position
|
|
57
|
-
- Smart handling of arrays, IN clauses, and function arguments
|
|
58
|
-
- Type-preserving parameter extraction
|
|
59
|
-
- Configurable parameterization strategies
|
|
60
|
-
- Performance optimization for query plan caching
|
|
61
|
-
|
|
62
|
-
Args:
|
|
63
|
-
placeholder_style: Style of placeholder to use ("?", ":name", "$1", etc.).
|
|
64
|
-
preserve_null: Whether to preserve NULL literals as-is.
|
|
65
|
-
preserve_boolean: Whether to preserve boolean literals as-is.
|
|
66
|
-
preserve_numbers_in_limit: Whether to preserve numbers in LIMIT/OFFSET clauses.
|
|
67
|
-
preserve_in_functions: List of function names where literals should be preserved.
|
|
68
|
-
preserve_in_recursive_cte: Whether to preserve literals in recursive CTEs (default True to avoid type inference issues).
|
|
69
|
-
parameterize_arrays: Whether to parameterize array literals.
|
|
70
|
-
parameterize_in_lists: Whether to parameterize IN clause lists.
|
|
71
|
-
max_string_length: Maximum string length to parameterize.
|
|
72
|
-
max_array_length: Maximum array length to parameterize.
|
|
73
|
-
max_in_list_size: Maximum IN list size to parameterize.
|
|
74
|
-
type_preservation: Whether to preserve exact literal types.
|
|
75
|
-
"""
|
|
76
|
-
|
|
77
|
-
def __init__(
    self,
    placeholder_style: str = "?",
    preserve_null: bool = True,
    preserve_boolean: bool = True,
    preserve_numbers_in_limit: bool = True,
    preserve_in_functions: Optional[list[str]] = None,
    preserve_in_recursive_cte: bool = True,
    parameterize_arrays: bool = True,
    parameterize_in_lists: bool = True,
    max_string_length: int = DEFAULT_MAX_STRING_LENGTH,
    max_array_length: int = DEFAULT_MAX_ARRAY_LENGTH,
    max_in_list_size: int = DEFAULT_MAX_IN_LIST_SIZE,
    type_preservation: bool = True,
) -> None:
    """Store the parameterization options and initialise per-call state.

    See the class docstring for the meaning of each option.

    ``preserve_in_functions`` falls back to the built-in list only when it is
    ``None``; an explicit empty list disables function-based preservation.
    Entries are copied and upper-cased because they are compared against
    upper-cased function names.
    """
    self.placeholder_style = placeholder_style
    self.preserve_null = preserve_null
    self.preserve_boolean = preserve_boolean
    self.preserve_numbers_in_limit = preserve_numbers_in_limit
    self.preserve_in_recursive_cte = preserve_in_recursive_cte

    if preserve_in_functions is None:
        preserve_in_functions = [
            "COALESCE",
            "IFNULL",
            "NVL",
            "ISNULL",
            # Array functions that take dimension arguments
            "ARRAYSIZE",  # SQLglot converts array_length to ArraySize
            "ARRAY_UPPER",
            "ARRAY_LOWER",
            "ARRAY_NDIMS",
            "ROUND",
        ]
    # Copy (so the caller's list is never aliased) and normalise case so that
    # user-supplied lowercase names can actually match.
    self.preserve_in_functions = [str(name).upper() for name in preserve_in_functions]

    self.parameterize_arrays = parameterize_arrays
    self.parameterize_in_lists = parameterize_in_lists
    self.max_string_length = max_string_length
    self.max_array_length = max_array_length
    self.max_in_list_size = max_in_list_size
    self.type_preservation = type_preservation

    self.extracted_parameters: list[Any] = []
    self._parameter_counter = 0
    self._parameter_metadata: list[dict[str, Any]] = []  # Track parameter types and context
    self._preserve_dict_format = False  # Track whether to preserve dict format

    # Reordering state. ``process`` overwrites these before each traversal; they
    # are initialised here so the placeholder helpers never hit AttributeError.
    self._original_params: Any = None
    self._user_param_index = 0
    self._final_params: Union[dict[str, Any], list[Any]] = []
    self._is_reordering_needed = False
|
|
119
|
-
|
|
120
|
-
def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
    """Replace literals in ``context.current_expression`` with placeholders.

    The current expression is copied, transformed, stored back on
    ``context.current_expression`` and returned.

    If the input SQL already contained placeholders, their user-supplied values
    are merged with the newly extracted literals into ``_final_params``. That
    collection then replaces ``context.merged_parameters`` when non-empty.
    Otherwise the extracted literals are appended to
    ``context.extracted_parameters_from_pipeline``.

    Args:
        expression: Expression supplied by the pipeline; only checked for ``None``.
        context: Processing context, read and updated in place.

    Returns:
        The transformed expression. ``expression`` is returned unchanged when
        there is nothing to process, or when the input SQL used named
        placeholders (reordering would break the name mapping).
    """
    if expression is None or context.current_expression is None:
        return expression

    had_placeholders = context.config.input_sql_had_placeholders
    param_info = context.parameter_info
    has_named = bool(param_info) and any(info.name for info in param_info)

    # Named parameters (e.g. BigQuery @param) are bound by name; leave them alone.
    if had_placeholders and has_named:
        return expression

    self.extracted_parameters = []
    self._parameter_metadata = []

    # New placeholders are numbered after the highest existing ordinal.
    if had_placeholders and param_info:
        self._parameter_counter = max(info.ordinal for info in param_info) + 1
    else:
        self._parameter_counter = 0

    self._original_params = context.merged_parameters
    self._user_param_index = 0
    self._preserve_dict_format = isinstance(context.merged_parameters, dict) and has_named
    self._final_params: Union[dict[str, Any], list[Any]] = {} if self._preserve_dict_format else []
    self._is_reordering_needed = had_placeholders

    root_context = ParameterizationContext(parent_stack=[])
    transformed = self._transform_with_context(context.current_expression.copy(), root_context)
    context.current_expression = transformed

    if self._is_reordering_needed and self._final_params:
        # User values and extracted literals were already merged in placeholder order.
        context.merged_parameters = self._final_params
    else:
        context.extracted_parameters_from_pipeline.extend(self.extracted_parameters)

    context.metadata["parameter_metadata"] = self._parameter_metadata
    return transformed
|
|
180
|
-
|
|
181
|
-
def _transform_with_context(self, node: exp.Expression, context: ParameterizationContext) -> exp.Expression:
    """Recursively transform ``node`` while tracking its position in the AST.

    Dispatch, in order:
      * ``Literal``/``Boolean``/``Null`` nodes go to ``_process_literal_with_context``.
      * ``Array`` nodes go to ``_process_array`` when ``parameterize_arrays`` is set.
      * ``In`` nodes go to ``_process_in_clause`` when ``parameterize_in_lists`` is set.
      * When reordering existing placeholders: ``Placeholder``, ``Parameter`` and
        ``$<digits>`` columns go to their respective reordering handlers.
      * Every other node has its child expressions transformed in place.

    Args:
        node: Expression to transform; non-leaf nodes are mutated in place.
        context: Traversal state, updated on entry to and exit from ``node``.

    Returns:
        The replacement for ``node`` (``node`` itself when it is kept).
    """
    self._update_context(node, context, entering=True)

    if isinstance(node, (Literal, Boolean, Null)):
        # Boolean and Null are not Literal subclasses, hence the tuple.
        result = self._process_literal_with_context(node, context)
    elif isinstance(node, Array) and self.parameterize_arrays:
        result = self._process_array(node, context)
    elif isinstance(node, exp.In) and self.parameterize_in_lists:
        result = self._process_in_clause(node, context)
    elif isinstance(node, exp.Placeholder) and self._is_reordering_needed:
        result = self._process_existing_placeholder(node, context)
    elif isinstance(node, exp.Parameter) and self._is_reordering_needed:
        # PostgreSQL-style parameters ($1, $2) parsed as Parameter nodes.
        result = self._process_existing_parameter(node, context)
    elif (
        isinstance(node, exp.Column)
        and self._is_reordering_needed
        and self._column_is_dollar_parameter(node)
    ):
        # PostgreSQL-style parameter that the parser produced as a Column.
        result = self._process_postgresql_column_parameter(node, context)
    else:
        self._transform_child_expressions(node, context)
        result = node

    self._update_context(node, context, entering=False)
    return result

def _column_is_dollar_parameter(self, node: exp.Column) -> bool:
    """Return True when a Column node's text is ``$`` followed by digits."""
    column_name = str(node.this) if hasattr(node, "this") else ""
    return column_name.startswith("$") and column_name[1:].isdigit()

def _transform_child_expressions(self, node: exp.Expression, context: ParameterizationContext) -> None:
    """Transform every child expression of ``node`` and write the results back."""
    # Iterate over a snapshot: ``node.set`` writes into ``node.args`` during the loop.
    for key, value in list(node.args.items()):
        if isinstance(value, exp.Expression):
            node.set(key, self._transform_with_context(value, context))
        elif isinstance(value, list):
            node.set(
                key,
                [self._transform_with_context(v, context) if isinstance(v, exp.Expression) else v for v in value],
            )
|
|
241
|
-
|
|
242
|
-
def _update_context(self, node: exp.Expression, context: ParameterizationContext, entering: bool) -> None:
    """Update ``context`` as the traversal enters or leaves ``node``.

    Args:
        node: The AST node being entered or left.
        context: Mutable traversal state.
        entering: ``True`` before ``node``'s children are visited, ``False`` after.
    """
    if entering:
        self._update_context_entering(node, context)
    else:
        self._update_context_leaving(node, context)

def _update_context_entering(self, node: exp.Expression, context: ParameterizationContext) -> None:
    """Push ``node`` onto the parent stack and set the flags it opens."""
    context.parent_stack.append(node)

    # Function tracking is deliberately separate from the dispatch below. In
    # sqlglot ``Case`` and ``Array`` subclass ``Func``, so an if/elif chain that
    # tests ``Func`` first can never reach their branches.
    if isinstance(node, Func):
        self._update_context_entering_func(node, context)

    if isinstance(node, exp.Case):
        context.in_case_when = True
    elif isinstance(node, Array):
        context.in_array = True
    elif isinstance(node, exp.In):
        context.in_in_clause = True
    elif isinstance(node, exp.CTE):
        self._update_context_entering_cte(node, context)
    elif isinstance(node, exp.Subquery):
        context.subquery_depth += 1
        context.in_subquery = True
    elif isinstance(node, exp.Select):
        self._update_context_entering_select(node, context)
    elif isinstance(node, exp.Join):
        context.in_join_condition = True

def _update_context_entering_func(self, node: Func, context: ParameterizationContext) -> None:
    """Count function nesting and flag functions whose literals are preserved."""
    context.function_depth += 1
    if self._is_preserved_function(node):
        context.in_function_args = True

def _is_preserved_function(self, node: Func) -> bool:
    """Return True when ``node`` is listed in ``preserve_in_functions``.

    Built-in sqlglot functions are identified by their upper-cased class name
    (e.g. ``COALESCE``, ``ARRAYSIZE``). Only ``exp.Anonymous`` carries the SQL
    function name in ``node.name``. For other functions ``node.name`` is derived
    from the first (``this``) argument, so ``UPPER('round')`` would otherwise
    look like ``ROUND``.
    """
    if node.__class__.__name__.upper() in self.preserve_in_functions:
        return True
    if isinstance(node, exp.Anonymous):
        function_name = node.name
        return bool(function_name) and function_name.upper() in self.preserve_in_functions
    return False

def _has_recursive_with(self, nodes: list[exp.Expression]) -> bool:
    """Return True when any node in ``nodes`` is a ``WITH RECURSIVE``."""
    return any(isinstance(n, exp.With) and n.args.get("recursive", False) for n in nodes)

def _has_ancestor(self, context: ParameterizationContext, kind: Any) -> bool:
    """Return True when any node on the parent stack is an instance of ``kind``."""
    return any(isinstance(n, kind) for n in context.parent_stack)

def _update_context_entering_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
    """Count CTE nesting and detect recursive CTEs."""
    context.cte_depth += 1
    # A CTE is treated as recursive when an enclosing WITH is RECURSIVE and
    # the CTE body contains a UNION (the characteristic recursive shape).
    if self._has_recursive_with(context.parent_stack) and self._contains_union(node):
        context.in_recursive_cte = True

def _is_nested_select(self, ancestors: list[exp.Expression]) -> bool:
    """Return True when a SELECT with these ancestors counts as a subquery.

    A SELECT with a CTE ancestor under a ``WITH RECURSIVE`` is not counted.
    Otherwise any enclosing SELECT, Subquery or CTE makes it nested.
    """
    if any(isinstance(a, exp.CTE) for a in ancestors) and self._has_recursive_with(ancestors):
        return False
    return any(isinstance(a, (exp.Select, exp.Subquery, exp.CTE)) for a in ancestors)

def _update_context_entering_select(self, node: exp.Select, context: ParameterizationContext) -> None:
    """Track nested SELECT statements as subqueries."""
    # The SELECT was just pushed, so its ancestors are everything before it.
    if self._is_nested_select(context.parent_stack[:-1]):
        context.subquery_depth += 1
        context.in_subquery = True
    if hasattr(node, "expressions"):
        # Select-list membership is resolved per expression (_is_in_select_expressions).
        context.in_select_list = False

def _update_context_leaving(self, node: exp.Expression, context: ParameterizationContext) -> None:
    """Pop ``node`` and restore the flags it opened.

    After the pop, ``parent_stack`` holds exactly the ancestors of ``node``.
    Boolean flags are recomputed from it, so closing a nested CASE/ARRAY/IN/JOIN
    does not clear a flag that an enclosing node of the same kind still needs.
    """
    if context.parent_stack:
        context.parent_stack.pop()

    if isinstance(node, Func):
        self._update_context_leaving_func(node, context)

    if isinstance(node, exp.Case):
        context.in_case_when = self._has_ancestor(context, exp.Case)
    elif isinstance(node, Array):
        context.in_array = self._has_ancestor(context, Array)
    elif isinstance(node, exp.In):
        context.in_in_clause = self._has_ancestor(context, exp.In)
    elif isinstance(node, exp.CTE):
        self._update_context_leaving_cte(node, context)
    elif isinstance(node, exp.Subquery):
        self._update_context_leaving_subquery(node, context)
    elif isinstance(node, exp.Select):
        self._update_context_leaving_select(node, context)
    elif isinstance(node, exp.Join):
        context.in_join_condition = self._has_ancestor(context, exp.Join)

def _update_context_leaving_func(self, node: Func, context: ParameterizationContext) -> None:
    """Undo function nesting and recompute ``in_function_args`` from the ancestors."""
    context.function_depth -= 1
    # Recompute rather than clearing only at depth 0. Otherwise, in
    # CONCAT(COALESCE(a, 'x'), 'y'), leaving COALESCE would keep the flag set
    # while 'y' (an argument of CONCAT) is visited.
    context.in_function_args = any(
        isinstance(ancestor, Func) and self._is_preserved_function(ancestor)
        for ancestor in context.parent_stack
    )

def _update_context_leaving_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
    """Undo CTE nesting; clear the recursive flag once outside all CTEs."""
    context.cte_depth -= 1
    if context.cte_depth == 0:
        context.in_recursive_cte = False

def _update_context_leaving_subquery(self, node: exp.Subquery, context: ParameterizationContext) -> None:
    """Undo subquery nesting; clear the flag once outside all subqueries."""
    context.subquery_depth -= 1
    if context.subquery_depth == 0:
        context.in_subquery = False

def _update_context_leaving_select(self, node: exp.Select, context: ParameterizationContext) -> None:
    """Undo the subquery accounting done by ``_update_context_entering_select``."""
    # The SELECT has already been popped, so parent_stack is the same ancestor
    # list the entering hook inspected. Do not slice off another element here,
    # or the immediate parent is dropped and the depth gets out of balance.
    if self._is_nested_select(context.parent_stack):
        context.subquery_depth -= 1
        if context.subquery_depth == 0:
            context.in_subquery = False
    context.in_select_list = False
|
|
375
|
-
|
|
376
|
-
def _process_literal_with_context(
    self, literal: exp.Expression, context: ParameterizationContext
) -> exp.Expression:
    """Replace a literal-like node with a placeholder unless its context preserves it.

    The extracted value is wrapped in a ``TypedParameter`` and always appended
    to ``extracted_parameters``. When existing placeholders are being
    reordered, it is also merged into ``_final_params``. A metadata record is
    appended to ``_parameter_metadata``.

    Args:
        literal: ``Literal``, ``Boolean`` or ``Null`` node to process.
        context: Current traversal context.

    Returns:
        ``literal`` unchanged when it must be preserved, otherwise a new placeholder node.
    """
    if self._should_preserve_literal_in_context(literal, context):
        return literal

    value, type_hint, sqlglot_type, semantic_name = self._extract_literal_value_and_type_optimized(literal, context)

    typed_param = TypedParameter(
        value=value,
        sqlglot_type=sqlglot_type or exp.DataType.build("VARCHAR"),  # Fallback when no type was inferred
        type_hint=type_hint,
        semantic_name=semantic_name,
    )

    # Always tracked so callers can merge extracted parameters later.
    self.extracted_parameters.append(typed_param)

    if self._is_reordering_needed:
        if self._preserve_dict_format and isinstance(self._final_params, dict):
            # Mapping output needs a key: prefer the semantic name, else a positional key.
            # NOTE(review): a repeated semantic name overwrites the earlier entry.
            param_key = semantic_name or f"param_{len(self._final_params)}"
            self._final_params[param_key] = typed_param
        elif isinstance(self._final_params, list):
            self._final_params.append(typed_param)
        else:
            # ``process`` always sets ``_final_params`` to a dict or list; keep a side
            # buffer instead of failing if that invariant is ever broken.
            if not hasattr(self, "_fallback_params"):
                self._fallback_params: list[Any] = []
            self._fallback_params.append(typed_param)

    # Position of this parameter within the collection that will actually be bound.
    bound_params = self._final_params if self._is_reordering_needed else self.extracted_parameters
    self._parameter_metadata.append(
        {
            "index": len(bound_params) - 1,
            "type": type_hint,
            "semantic_name": semantic_name,
            "context": self._get_context_description(context),
        }
    )

    return self._create_placeholder(hint=semantic_name)
|
|
425
|
-
|
|
426
|
-
def _process_existing_placeholder(self, node: exp.Placeholder, context: ParameterizationContext) -> exp.Expression:
    """Bind the user value for an existing placeholder into ``_final_params``.

    The handler is chosen by the shape of the original user parameters:
    sequence, mapping, or a single scalar value. Nothing is bound when there
    were no user parameters. The placeholder node itself is always returned
    unchanged.
    """
    original = self._original_params
    if original is not None:
        if isinstance(original, (list, tuple)):
            bind = self._handle_list_params_for_placeholder
        elif isinstance(original, dict):
            bind = self._handle_dict_params_for_placeholder
        else:
            bind = self._handle_single_value_param_for_placeholder
        bind(node)
    return node
|
|
440
|
-
|
|
441
|
-
def _handle_list_params_for_placeholder(self, node: exp.Placeholder) -> None:
    """Bind the next positional user value to ``node``.

    Once the user's list or tuple is exhausted (more placeholders than values),
    ``None`` is bound instead and the cursor is not advanced.
    """
    params = self._original_params
    cursor = self._user_param_index
    if not (isinstance(params, (list, tuple)) and cursor < len(params)):
        self._add_to_final_params(None, node)
        return
    self._add_to_final_params(params[cursor], node)
    self._user_param_index += 1
|
|
450
|
-
|
|
451
|
-
def _handle_dict_params_for_placeholder(self, node: exp.Placeholder) -> None:
    """Resolve a placeholder against mapping-style user parameters.

    Resolution order, using the placeholder name with leading ``:``/``@`` stripped:
      1. An exact key match binds that value.
      2. A numeric name on the first lookup copies every numeric entry via
         ``_handle_oracle_numeric_params``. Numeric names on later lookups
         bind nothing.
      3. Otherwise, on the first lookup with a non-empty mapping, the mapping
         is handed to ``_handle_single_dict_param``.
      4. Anything else binds ``None``.

    A non-mapping parameter set, or an unnamed placeholder (ambiguous with a
    mapping), binds ``None``.
    """
    params = self._original_params
    if not isinstance(params, dict):
        self._add_to_final_params(None, node)
        return

    raw_name = getattr(node, "this", None)
    if not raw_name:
        self._add_to_final_params(None, node)
        return

    name = raw_name.lstrip(":@")
    if name in params:
        self._add_to_final_params(params[name], node)
        return

    is_numeric = name.isdigit()
    first_lookup = self._user_param_index == 0
    if is_numeric and first_lookup:
        self._handle_oracle_numeric_params()
        self._user_param_index += 1
    elif is_numeric and self._user_param_index > 0:
        return  # the first numeric lookup already copied every numeric entry
    elif first_lookup and len(params) > 0:
        self._handle_single_dict_param()
        self._user_param_index += 1
    else:
        self._add_to_final_params(None, node)
|
|
485
|
-
|
|
486
|
-
def _handle_single_value_param_for_placeholder(self, node: exp.Placeholder) -> None:
    """Bind a scalar user parameter (neither sequence nor mapping) to a placeholder.

    The scalar is consumed by the first placeholder only; every later
    placeholder records ``None``.
    """
    if self._user_param_index != 0:
        self._add_to_final_params(None, node)
        return
    self._add_to_final_params(self._original_params, node)
    self._user_param_index += 1
|
|
493
|
-
|
|
494
|
-
def _handle_oracle_numeric_params(self) -> None:
    """Copy the digit-named keys of a dict parameter set into the final params.

    * Dict-preserving mode: each numeric key/value pair is copied verbatim.
    * Otherwise the values are laid out positionally. Each key's integer value
      is its slot, and gaps are filled with ``None``. The slots are appended
      to list final params, or written under ``"0"``, ``"1"``, ... keys into
      dict final params.

    Does nothing if the original params are not a dict, or (positional mode)
    if no key is numeric.
    """
    source = self._original_params
    if not isinstance(source, dict):
        return

    target = self._final_params
    if self._preserve_dict_format and isinstance(target, dict):
        for key, value in source.items():
            if key.isdigit():
                target[key] = value
        return

    digit_keys = [key for key in source if key.isdigit()]
    if not digit_keys:
        return

    # NOTE(review): keys are used as 0-based slots, so 1-based binds such as
    # {"1": a, "2": b} produce [None, a, b] - confirm this is intended.
    slots = [None] * (max(int(key) for key in digit_keys) + 1)
    for key in digit_keys:
        slots[int(key)] = source[key]

    if isinstance(target, list):
        target.extend(slots)
    elif isinstance(target, dict):
        target.update((str(position), value) for position, value in enumerate(slots))
|
|
517
|
-
|
|
518
|
-
def _handle_single_dict_param(self) -> None:
    """Emit a dict parameter set as a single unit.

    In dict-preserving mode the mapping is merged key-by-key into the final
    dict. Otherwise the whole mapping is one entry: appended to list final
    params, or stored under the next ``param_<n>`` key (n = current size) of
    dict final params. Non-dict originals are ignored.
    """
    mapping = self._original_params
    if not isinstance(mapping, dict):
        return

    target = self._final_params
    if self._preserve_dict_format and isinstance(target, dict):
        target.update(mapping)
    elif isinstance(target, list):
        target.append(mapping)
    elif isinstance(target, dict):
        target[f"param_{len(target)}"] = mapping
|
|
531
|
-
|
|
532
|
-
def _add_to_final_params(self, value: Any, node: exp.Placeholder) -> None:
    """Record ``value`` as the bound value for placeholder ``node``.

    Args:
        value: The value to bind (may be ``None``).
        node: The placeholder being bound. Its name (``node.this``) is used as
            the key in dict-preserving mode.

    In dict-preserving mode, a placeholder without a name (``node.this`` empty,
    e.g. a bare ``?``) is keyed ``param_<user_param_index>``. Otherwise the
    value is appended to list final params, or stored under the next
    ``param_<n>`` key (n = current size) of dict final params.
    """
    target = self._final_params
    if self._preserve_dict_format and isinstance(target, dict):
        # sqlglot nodes always expose ``.this`` (None for unnamed placeholders),
        # so a hasattr() test never reached the fallback and every unnamed
        # placeholder was stored under the key ``None``. Test the value instead.
        key = getattr(node, "this", None) or f"param_{self._user_param_index}"
        target[key] = value
    elif isinstance(target, list):
        target.append(value)
    elif isinstance(target, dict):
        target[f"param_{len(target)}"] = value
|
|
542
|
-
|
|
543
|
-
def _process_existing_parameter(self, node: exp.Parameter, context: ParameterizationContext) -> exp.Expression:
    """Record the user value bound to a ``Parameter`` node already in the SQL.

    Named parameters (e.g. BigQuery ``@name``) are first looked up by name in
    dict params. Failing that, the node's 0-based index (from PostgreSQL-style
    ``$1``, ``$2``) is resolved against list/tuple, dict, or scalar params.
    ``None`` is recorded when no params exist or nothing matches.

    Returns:
        ``node`` itself, unchanged.
    """
    params = self._original_params
    name = self._extract_parameter_name(node)
    if name and isinstance(params, dict) and name in params:
        self._add_param_value_to_finals(params[name])
        return node

    index = self._extract_parameter_index(node)
    if params is None:
        self._add_none_to_final_params()
    elif isinstance(params, (list, tuple)):
        self._handle_list_params_for_parameter_node(index)
    elif isinstance(params, dict):
        self._handle_dict_params_for_parameter_node(index)
    elif index == 0:
        # A lone scalar can only satisfy the first parameter.
        self._add_param_value_to_finals(params)
    else:
        self._add_none_to_final_params()

    return node
|
|
570
|
-
|
|
571
|
-
@staticmethod
def _extract_parameter_name(node: exp.Parameter) -> Optional[str]:
    """Return the name wrapped by a ``Parameter`` node, or ``None``.

    ``@min_value`` parses as ``Parameter(this=Var(this="min_value"))`` and
    yields ``"min_value"``. Any other inner node that exposes ``.this`` is
    treated the same way (its ``.this`` is stringified).
    """
    if not hasattr(node, "this"):
        return None
    inner = node.this
    # NOTE(review): an inner node whose own ``.this`` is None yields the string
    # "None", and a numeric Literal yields e.g. "1" - confirm callers expect this.
    if isinstance(inner, exp.Var) or hasattr(inner, "this"):
        return str(inner.this)
    return None
|
|
582
|
-
|
|
583
|
-
@staticmethod
def _extract_parameter_index(node: exp.Parameter) -> Optional[int]:
    """Return the 0-based index encoded by a numeric ``Parameter`` node.

    ``$1`` (a ``Parameter`` wrapping a ``Literal`` with value ``"1"``) maps to
    ``0``. Returns ``None`` if the node does not wrap a ``Literal`` or the
    literal's value is not an integer.
    """
    if not (hasattr(node, "this") and isinstance(node.this, Literal)):
        return None
    try:
        # Placeholders are 1-based; callers index Python sequences.
        return int(node.this.this) - 1
    except (ValueError, TypeError):
        return None
|
|
592
|
-
|
|
593
|
-
def _handle_list_params_for_parameter_node(self, param_index: Optional[int]) -> None:
    """Record the positional user value at 0-based ``param_index``.

    ``None`` is recorded instead when the params are not a list/tuple, or the
    index is missing, negative, or past the end of the supplied values.
    """
    values = self._original_params
    in_range = (
        isinstance(values, (list, tuple))
        and param_index is not None
        and 0 <= param_index < len(values)
    )
    if not in_range:
        # More placeholders than user-supplied values.
        self._add_none_to_final_params()
        return
    self._add_param_value_to_finals(values[param_index])
|
|
605
|
-
|
|
606
|
-
def _handle_dict_params_for_parameter_node(self, param_index: Optional[int]) -> None:
    """Resolve a numeric ``Parameter`` node against dict-style user params.

    Without an index there is nothing to look up, so ``None`` is recorded.
    """
    if param_index is None:
        self._add_none_to_final_params()
        return
    self._handle_dict_param_with_index(param_index)
|
|
612
|
-
|
|
613
|
-
def _handle_dict_param_with_index(self, param_index: int) -> None:
    """Look up the dict value for 0-based ``param_index``.

    The keys tried, in order, are ``param_<index>`` and the 1-based numeric
    key ``str(index + 1)``. ``None`` is recorded if neither exists or the
    original params are not a dict.
    """
    mapping = self._original_params
    if not isinstance(mapping, dict):
        self._add_none_to_final_params()
        return

    for candidate in (f"param_{param_index}", str(param_index + 1)):
        if candidate in mapping:
            self._add_dict_value_to_finals(candidate)
            return

    self._add_none_to_final_params()
|
|
631
|
-
|
|
632
|
-
def _add_dict_value_to_finals(self, key: str) -> None:
    """Copy ``self._original_params[key]`` into the final params.

    List final params get the value appended; dict final params store it
    under the same ``key``. Nothing happens if the original params are not a
    dict or do not contain ``key``.
    """
    source = self._original_params
    if not isinstance(source, dict) or key not in source:
        return

    value = source[key]
    target = self._final_params
    if isinstance(target, list):
        target.append(value)
    elif isinstance(target, dict):
        target[key] = value
|
|
640
|
-
|
|
641
|
-
def _add_param_value_to_finals(self, value: Any) -> None:
    """Record ``value`` in the final params.

    List final params get it appended; dict final params store it under the
    next ``param_<n>`` key, where n is the dict's current size. Any other
    container type is left alone.
    """
    target = self._final_params
    if isinstance(target, list):
        target.append(value)
        return
    if isinstance(target, dict):
        target[f"param_{len(target)}"] = value
|
|
648
|
-
|
|
649
|
-
def _add_none_to_final_params(self) -> None:
    """Record a ``None`` value in the final params.

    Appended for list final params; stored under the next ``param_<n>`` key
    (n = current size) for dict final params.
    """
    target = self._final_params
    if isinstance(target, list):
        target.append(None)
        return
    if isinstance(target, dict):
        target[f"param_{len(target)}"] = None
|
|
656
|
-
|
|
657
|
-
def _process_postgresql_column_parameter(
    self, node: exp.Column, context: ParameterizationContext
) -> exp.Expression:
    """Record the user value for a ``$N`` placeholder that was parsed as a column.

    Args:
        node: Column node whose name is expected to look like ``$1``, ``$2``, ...
        context: Current parameterization context (not used here).

    Returns:
        ``node`` unchanged; it still stands for the placeholder in the SQL.

    Binding rules:
        * No user params: nothing is recorded.
        * list/tuple params: when parameter reordering is active, values are
          consumed in order of appearance via ``_user_param_index``. Otherwise
          ``$N`` selects element ``N - 1``. If the position is out of range
          nothing is recorded, but an in-range user value of ``None`` IS
          recorded. Previously it was silently dropped, which shifted every
          later positional binding.
        * dict params: ``param_<N-1>`` is tried first, then the numeric key ``"N"``.
        * Any other (scalar) params: bound only to ``$1``.
    """
    import contextlib

    column_name = str(node.this) if hasattr(node, "this") else ""
    param_index: Optional[int] = None
    if column_name.startswith("$") and column_name[1:].isdigit():
        # str.isdigit() accepts characters (e.g. "²") that int() rejects.
        with contextlib.suppress(ValueError, TypeError):
            param_index = int(column_name[1:]) - 1  # 0-based

    params = self._original_params
    if params is None:
        return node

    if isinstance(params, (list, tuple)):
        if self._is_reordering_needed:
            # Mixed placeholder styles: bind sequentially, ignoring the number in $N.
            position = self._user_param_index
            if position < len(params):
                self._user_param_index += 1
                self._add_param_value_to_finals(params[position])
        elif param_index is not None and 0 <= param_index < len(params):
            self._add_param_value_to_finals(params[param_index])
        # Out of range: more placeholders than user values - record nothing.
    elif isinstance(params, dict):
        if param_index is not None:
            for key in (f"param_{param_index}", str(param_index + 1)):
                if key in params:
                    self._add_dict_value_to_finals(key)
                    break
    elif param_index == 0:
        self._add_param_value_to_finals(params)

    return node
|
|
738
|
-
|
|
739
|
-
def _should_preserve_literal_in_context(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Decide whether ``literal`` stays inline instead of becoming a parameter.

    Rules are checked in order, and the first one that applies decides:

    1. NULL / boolean literals, when ``preserve_null`` / ``preserve_boolean`` are set.
    2. Inside a subquery scalar-lookup pattern: the verdict of ``_is_enum_like_literal``.
    3. Anything inside function arguments is preserved.
    4. In a recursive CTE (with ``preserve_in_recursive_cte``): the verdict of
       ``_should_preserve_literal_in_recursive_cte``.
    5. A literal that is the expression of an ``Alias`` while a ``Select`` is on
       the parent stack (e.g. ``'computed' AS process_status``) is preserved.
    6. For each ancestor on the parent stack:
       * DDL/schema nodes preserve;
       * numeric literals under LIMIT/OFFSET preserve when
         ``preserve_numbers_in_limit`` is set;
       * a ``Case`` ancestor while inside CASE/WHEN preserves unless the
         literal's direct parent is a ``Binary``.
    7. String literals longer than ``max_string_length`` are preserved.
    """
    if self.preserve_null and isinstance(literal, Null):
        return True
    if self.preserve_boolean and isinstance(literal, Boolean):
        return True

    if context.in_subquery and self._is_scalar_lookup_pattern(literal, context):
        return self._is_enum_like_literal(literal)

    if context.in_function_args:
        return True

    if self.preserve_in_recursive_cte and context.in_recursive_cte:
        return self._should_preserve_literal_in_recursive_cte(literal, context)

    owner = literal.parent if hasattr(literal, "parent") else None
    if owner and isinstance(owner, exp.Alias) and owner.this == literal:
        if any(isinstance(ancestor, exp.Select) for ancestor in context.parent_stack):
            return True

    for ancestor in context.parent_stack:
        if isinstance(ancestor, (DataType, exp.ColumnDef, exp.Create, exp.Schema)):
            return True
        if (
            self.preserve_numbers_in_limit
            and isinstance(ancestor, (exp.Limit, exp.Offset))
            and isinstance(literal, exp.Literal)
            and self._is_number_literal(literal)
        ):
            return True
        if isinstance(ancestor, exp.Case) and context.in_case_when:
            # Keep simple CASE values readable, but not comparison operands.
            return not isinstance(literal.parent, Binary)

    if isinstance(literal, exp.Literal) and self._is_string_literal(literal):
        return len(str(literal.this)) > self.max_string_length
    return False
|
|
802
|
-
|
|
803
|
-
def _is_in_select_expressions(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Return True if ``literal`` belongs to the projection list of the nearest SELECT.

    The parent stack is scanned innermost-first. Reaching a WHERE, HAVING or
    JOIN before a SELECT with projections means the literal is not a
    projection.
    """
    for ancestor in reversed(context.parent_stack):
        if isinstance(ancestor, (exp.Where, exp.Having, exp.Join)):
            return False
        if isinstance(ancestor, exp.Select):
            projections = getattr(ancestor, "expressions", None)
            if projections:
                return any(self._literal_is_in_expression_tree(literal, item) for item in projections)
    return False
|
|
812
|
-
|
|
813
|
-
def _is_recursive_computation(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Return True if ``literal`` feeds arithmetic that is part of a SELECT projection.

    The innermost arithmetic ancestor (``+``, ``-``, ``*``, ``/``) decides: the
    literal counts as recursive-computation logic when it is also within the
    SELECT expressions (see ``_is_in_select_expressions``).
    """
    # sqlglot's ``Expression.key`` is the lower-cased class name ("add", "sub",
    # ...). The previous upper-case comparison never matched, so this function
    # always returned False.
    arithmetic_keys = {"add", "sub", "mul", "div"}
    for ancestor in reversed(context.parent_stack):
        if isinstance(ancestor, exp.Binary) and str(ancestor.key).lower() in arithmetic_keys:
            return self._is_in_select_expressions(literal, context)
    return False
|
|
821
|
-
|
|
822
|
-
def _should_preserve_literal_in_recursive_cte(
    self, literal: exp.Expression, context: ParameterizationContext
) -> bool:
    """Keep recursive-CTE literals that carry query semantics.

    SELECT-list literals (needed for type inference) are always preserved.
    Otherwise the verdict of ``_is_recursive_computation`` applies.
    """
    return (
        True
        if self._is_in_select_expressions(literal, context)
        else self._is_recursive_computation(literal, context)
    )
|
|
832
|
-
|
|
833
|
-
def _literal_is_in_expression_tree(self, target_literal: exp.Expression, expr: exp.Expression) -> bool:
    """Return True if ``target_literal`` is ``expr`` or any descendant of it.

    The whole subtree is searched, not just the direct children as before, so
    a literal nested e.g. inside ``CAST(1 AS INT) AS x`` is found. Nodes are
    still compared with ``==``, as before.
    """
    # Iterative depth-first walk: avoids recursion limits on deep trees.
    pending = [expr]
    while pending:
        node = pending.pop()
        if node == target_literal:
            return True
        pending.extend(node.iter_expressions())
    return False
|
|
839
|
-
|
|
840
|
-
def _is_scalar_lookup_pattern(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Detect ``WHERE column = <literal>`` (either operand order) inside a subquery.

    Returns False outside subqueries. WHERE ancestors are examined
    innermost-first. The first one whose condition is a binary node with
    ``literal`` as an operand decides, by whether the other operand is a
    ``Column``.
    """
    if context.subquery_depth == 0:
        return False

    for ancestor in reversed(context.parent_stack):
        if not isinstance(ancestor, exp.Where):
            continue
        condition = ancestor.this
        if not isinstance(condition, exp.Binary):
            continue
        if condition.right == literal:
            return isinstance(condition.left, exp.Column)
        if condition.left == literal:
            return isinstance(condition.right, exp.Column)
    return False
|
|
857
|
-
|
|
858
|
-
def _is_enum_like_literal(self, literal: exp.Expression) -> bool:
    """Heuristically decide whether a string literal looks like an enum constant.

    True only for string literals that satisfy all of the following:
    * longer than ``MIN_ENUM_LENGTH`` and no longer than ``MAX_ENUM_LENGTH``;
    * made only of alphanumerics and underscores;
    * not purely digits.
    """
    if not (isinstance(literal, exp.Literal) and self._is_string_literal(literal)):
        return False

    text = str(literal.this)
    if len(text) > MAX_ENUM_LENGTH or len(text) <= MIN_ENUM_LENGTH:
        return False
    return text.replace("_", "").isalnum() and not text.isdigit()
|
|
872
|
-
|
|
873
|
-
def _extract_literal_value_and_type(self, literal: exp.Expression) -> tuple[Any, str]:
    """Convert a sqlglot literal into a Python value plus a simple type hint.

    Returns:
        ``(value, hint)``. ``hint`` is one of ``"null"``, ``"boolean"``,
        ``"string"``, ``"integer"``, ``"bigint"``, ``"float"``, ``"decimal"``,
        ``"numeric_string"`` or ``"unknown"``.

    With ``type_preservation`` enabled:
        * numbers whose mantissa has more than ``MAX_DECIMAL_PRECISION``
          fractional digits are returned as strings with hint ``"decimal"``;
        * integers above ``MAX_INT32_VALUE`` in magnitude get ``"bigint"``.

    Date/time values arrive here as plain string literals and are returned as
    strings; the database performs the conversion.
    """
    if isinstance(literal, Null) or literal.this is None:
        return None, "null"

    # sqlglot's Boolean is not an ``exp.Literal`` subclass, so it must be
    # handled before the non-Literal fallback below. Previously TRUE/FALSE
    # came back as the strings "TRUE"/"FALSE" with a "string" hint.
    if isinstance(literal, Boolean):
        return literal.this, "boolean"

    if not isinstance(literal, exp.Literal):
        return str(literal), "string"

    if isinstance(literal.this, bool):
        return literal.this, "boolean"

    if self._is_string_literal(literal):
        return str(literal.this), "string"

    if self._is_number_literal(literal):
        text = str(literal.this)
        if not self.type_preservation:
            try:
                if "." in text:
                    return float(literal.this), "float"
                return int(literal.this), "integer"
            except ValueError:
                return text, "numeric_string"

        if "." in text or "e" in text.lower():
            try:
                # Count fractional digits of the mantissa only: an exponent
                # suffix such as "e10" is not part of the decimal scale.
                mantissa = text.lower().split("e", 1)[0]
                decimal_places = len(mantissa.split(".")[1]) if "." in mantissa else 0
                if decimal_places > MAX_DECIMAL_PRECISION:
                    return text, "decimal"
                return float(literal.this), "float"
            except (ValueError, IndexError):
                return text, "numeric_string"

        try:
            value = int(literal.this)
        except ValueError:
            return text, "numeric_string"
        # Values beyond 32-bit range are tagged so a BIGINT can be bound.
        return value, ("bigint" if abs(value) > MAX_INT32_VALUE else "integer")

    return str(literal.this), "unknown"
|
|
926
|
-
|
|
927
|
-
def _extract_literal_value_and_type_optimized(
    self, literal: exp.Expression, context: ParameterizationContext
) -> "tuple[Any, str, Optional[exp.DataType], Optional[str]]":
    """Gather everything needed to parameterize ``literal`` in one call.

    Args:
        literal: The literal expression being parameterized.
        context: Current parameterization context (its parent stack is used
            for naming).

    Returns:
        ``(value, type_hint, sqlglot_type, semantic_name)``:
        * the Python value and hint from ``_extract_literal_value_and_type``;
        * the ``DataType`` that ``_infer_sqlglot_type`` derives from them;
        * a name suggested by ``_generate_semantic_name_from_context``
          (possibly ``None``).
    """
    value, type_hint = self._extract_literal_value_and_type(literal)
    return (
        value,
        type_hint,
        self._infer_sqlglot_type(type_hint, value),
        self._generate_semantic_name_from_context(literal, context),
    )
|
|
952
|
-
|
|
953
|
-
@staticmethod
def _infer_sqlglot_type(type_hint: str, value: Any) -> "Optional[exp.DataType]":
    """Map a simple type hint onto a sqlglot ``DataType`` without parsing SQL.

    Args:
        type_hint: Hint produced by ``_extract_literal_value_and_type``.
        value: The extracted value; used to size DECIMAL and VARCHAR types.

    Returns:
        A ``DataType``; unrecognised hints map to ``VARCHAR``.
        * Decimal strings get ``(precision, scale)`` from the digit counts
          before and after the first ``.``.
        * Non-empty strings get ``VARCHAR(<length>)``.
    """
    type_name = {
        "null": "NULL",
        "boolean": "BOOLEAN",
        "integer": "INT",
        "bigint": "BIGINT",
        "float": "FLOAT",
        "decimal": "DECIMAL",
        "string": "VARCHAR",
        "numeric_string": "VARCHAR",
        "unknown": "VARCHAR",
    }.get(type_hint, "VARCHAR")

    if type_hint == "decimal" and isinstance(value, str):
        pieces = value.split(".")
        integer_digits = len(pieces[0])
        fraction_digits = len(pieces[1]) if len(pieces) > 1 else 0
        precision = integer_digits + fraction_digits
        return exp.DataType.build(
            type_name,
            expressions=[exp.Literal.number(precision), exp.Literal.number(fraction_digits)],
        )

    if type_hint == "string" and isinstance(value, str) and value:
        return exp.DataType.build(type_name, expressions=[exp.Literal.number(len(value))])

    return exp.DataType.build(type_name)
|
|
993
|
-
|
|
994
|
-
@staticmethod
def _generate_semantic_name_from_context(
    literal: exp.Expression, context: ParameterizationContext
) -> "Optional[str]":
    """Suggest a parameter name for ``literal`` from its surrounding AST.

    Ancestors are inspected innermost-first.

    First pass (comparisons):
        * a ``Binary`` node with the literal on one side and a column on the
          other yields that column's name;
        * an ``IN`` whose left side is a column yields ``"<column>_value"``.

    Second pass (clauses): the innermost WHERE / HAVING / JOIN / SELECT yields
    ``"where_value"`` / ``"having_value"`` / ``"join_value"`` /
    ``"select_value"``.

    Returns ``None`` when neither pass finds a match.
    """
    ancestors = list(reversed(context.parent_stack))

    for ancestor in ancestors:
        if isinstance(ancestor, Binary):
            if ancestor.left == literal and isinstance(ancestor.right, exp.Column):
                return ancestor.right.name
            if ancestor.right == literal and isinstance(ancestor.left, exp.Column):
                return ancestor.left.name
        elif isinstance(ancestor, exp.In) and ancestor.this and isinstance(ancestor.this, exp.Column):
            return f"{ancestor.this.name}_value"

    clause_names = (
        (exp.Where, "where_value"),
        (exp.Having, "having_value"),
        (exp.Join, "join_value"),
        (exp.Select, "select_value"),
    )
    for ancestor in ancestors:
        for clause_type, name in clause_names:
            if isinstance(ancestor, clause_type):
                return name

    return None
|
|
1032
|
-
|
|
1033
|
-
def _is_string_literal(self, literal: exp.Literal) -> bool:
    """Return True if ``literal`` is a string.

    A literal counts as a string when either:
    * sqlglot flags it ``is_string``; or
    * its payload is a ``str`` that ``_is_number_literal`` does not accept.
    """
    if getattr(literal, "is_string", False):
        return True
    return isinstance(literal.this, str) and not self._is_number_literal(literal)
|
|
1039
|
-
|
|
1040
|
-
@staticmethod
def _is_number_literal(literal: exp.Literal) -> bool:
    """Return True if ``literal`` is numeric.

    A literal counts as numeric when sqlglot flags it ``is_number``, or when
    its payload converts with ``float()``. Note that this also accepts values
    such as ``"1e3"`` or ``"nan"``.
    """
    if getattr(literal, "is_number", False):
        return True
    payload = literal.this
    if payload is None:
        return False
    try:
        float(str(payload))
    except (ValueError, TypeError):
        return False
    return True
|
|
1054
|
-
|
|
1055
|
-
def _create_placeholder(self, hint: Optional[str] = None) -> exp.Expression:
    """Build the placeholder node for the next extracted parameter.

    The form follows ``self.placeholder_style``, given either as a
    ``ParameterStyle`` member or as its string spelling:

    * qmark / ``"?"``: anonymous ``Placeholder``;
    * ``":name"``: ``Placeholder`` named ``<hint>_<n>`` (or ``param_<n>``);
    * named-colon / any other ``":"``-prefixed style: ``Placeholder`` named ``param_<n>``;
    * numeric / ``"$"``-prefixed: ``Var`` ``$<n + 1>`` (PostgreSQL is 1-based);
    * named-at: ``Placeholder`` named ``param_<n>``;
    * pyformat and anything else: anonymous ``Placeholder``.

    ``n`` is ``self._parameter_counter``, which is incremented after each call.
    """
    style = self.placeholder_style
    counter = self._parameter_counter
    node: exp.Expression

    if style in {"?", ParameterStyle.QMARK, "qmark"}:
        node = exp.Placeholder()
    elif style == ":name":
        node = exp.Placeholder(this=f"{hint}_{counter}" if hint else f"param_{counter}")
    elif style in {ParameterStyle.NAMED_COLON, "named_colon"} or style.startswith(":"):
        node = exp.Placeholder(this=f"param_{counter}")
    elif style in {ParameterStyle.NUMERIC, "numeric"} or style.startswith("$"):
        # A Var keeps the literal "$N" spelling when SQL is generated.
        node = exp.Var(this=f"${counter + 1}")
    elif style in {ParameterStyle.NAMED_AT, "named_at"}:
        # BigQuery @param: the "@" sigil is added during SQL generation.
        node = exp.Placeholder(this=f"param_{counter}")
    else:
        # POSITIONAL_PYFORMAT is emitted as "?" and converted later at compile
        # time; unrecognised styles also fall back to "?".
        node = exp.Placeholder()

    self._parameter_counter = counter + 1
    return node
|
|
1091
|
-
|
|
1092
|
-
def _process_array(self, array_node: Array, context: ParameterizationContext) -> exp.Expression:
    """Parameterize an ARRAY literal.

    * Empty arrays, and arrays longer than ``max_array_length``, are returned
      untouched.
    * If every element is a ``Literal``, the whole array becomes one
      ``TypedParameter``, with the element type taken from the first element,
      and the array is replaced by a single placeholder.
    * Otherwise each element is processed in place.
    """
    elements = array_node.expressions
    if not elements or len(elements) > self.max_array_length:
        return array_node

    values: list[Any] = []
    hints: list[str] = []
    for item in elements:
        if not isinstance(item, Literal):
            break
        value, hint = self._extract_literal_value_and_type(item)
        values.append(value)
        hints.append(hint)
    else:
        # Every element was a literal: bind the array as a single value.
        element_type = hints[0] if hints else "unknown"
        element_sqlglot_type = self._infer_sqlglot_type(element_type, values[0] if values else None)
        self.extracted_parameters.append(
            TypedParameter(
                value=values,
                sqlglot_type=exp.DataType.build("ARRAY", expressions=[element_sqlglot_type]),
                type_hint=f"array<{element_type}>",
                semantic_name="array_values",
            )
        )
        self._parameter_metadata.append(
            {
                "index": len(self.extracted_parameters) - 1,
                "type": f"array<{element_type}>",
                "length": len(values),
                "context": "array_literal",
            }
        )
        return self._create_placeholder("array")

    array_node.set(
        "expressions",
        [
            self._process_literal_with_context(item, context)
            if isinstance(item, Literal)
            else self._transform_with_context(item, context)
            for item in elements
        ],
    )
    return array_node
|
|
1153
|
-
|
|
1154
|
-
def _process_in_clause(self, in_node: exp.In, context: ParameterizationContext) -> exp.Expression:
    """Process an IN predicate for intelligent parameterization.

    Subquery forms (``x IN (SELECT ...)``) are never parameterized as a whole;
    the subquery is transformed recursively instead. For value-list forms, each
    ``Literal`` element is passed to ``_process_literal_with_context`` and every
    other element to ``_transform_with_context``. Lists longer than
    ``self.max_in_list_size`` are returned untouched.

    Args:
        in_node: The IN expression to process. It is modified in place.
        context: The current parameterization context.

    Returns:
        The same ``in_node`` instance.
    """
    query = in_node.args.get("query")
    if query:
        # Don't parameterize subqueries, just process them recursively.
        in_node.set("query", self._transform_with_context(query, context))
        return in_node

    expressions = in_node.args.get("expressions")
    if not expressions:
        return in_node

    if len(expressions) > self.max_in_list_size:
        # Oversized lists are left as-is; alternative strategies could go here.
        return in_node

    # Transform every element whether or not a direct Literal is present, so that
    # literals nested inside other expressions (e.g. UPPER('a')) are handled the
    # same way regardless of what their neighbours in the list are.
    new_expressions = [
        self._process_literal_with_context(expr, context)
        if isinstance(expr, Literal)
        else self._transform_with_context(expr, context)
        for expr in expressions
    ]
    in_node.set("expressions", new_expressions)

    # NOTE(review): in_node.this (the left-hand operand) is not transformed here;
    # confirm the caller handles it.
    return in_node
|
|
1188
|
-
|
|
1189
|
-
def _get_context_description(self, context: ParameterizationContext) -> str:
    """Build a short label describing where the current literal appears.

    Every active context flag contributes a label, in the fixed order
    function_args, case_when, array, in_clause, and the labels are joined with
    ``"_"``. If no flag is active, the first SELECT, WHERE, or JOIN node found
    while walking ``context.parent_stack`` from its last element supplies the
    label instead.

    Args:
        context: The parameterization context being described.

    Returns:
        The joined flag labels, ``"select"``/``"where"``/``"join"`` taken from
        the parent stack, or ``"general"`` when nothing applies.
    """
    flag_table = (
        ("function_args", context.in_function_args),
        ("case_when", context.in_case_when),
        ("array", context.in_array),
        ("in_clause", context.in_in_clause),
    )
    labels = [label for label, enabled in flag_table if enabled]
    if labels:
        return "_".join(labels)

    # No explicit flag set: fall back to the nearest recognizable clause node,
    # scanning from the end of the parent stack.
    for ancestor in reversed(context.parent_stack):
        if isinstance(ancestor, exp.Select):
            return "select"
        if isinstance(ancestor, exp.Where):
            return "where"
        if isinstance(ancestor, exp.Join):
            return "join"
    return "general"
|
|
1216
|
-
|
|
1217
|
-
def get_parameters(self) -> list[Any]:
    """Return the parameter values captured by the most recent ``process()`` call.

    A new list is returned, so appending to or removing from it leaves this
    instance's own list untouched.

    Returns:
        A shallow copy of ``self.extracted_parameters``.
    """
    snapshot = list(self.extracted_parameters)
    return snapshot
|
|
1224
|
-
|
|
1225
|
-
def get_parameter_metadata(self) -> list[dict[str, Any]]:
    """Get metadata about extracted parameters for advanced usage.

    Both the list and each metadata dictionary in it are copied. Callers can
    therefore modify the returned entries without altering the metadata held
    by this instance. A shallow list copy alone would still share the
    dictionaries.

    Returns:
        List of copied parameter metadata dictionaries.
    """
    return [dict(entry) for entry in self._parameter_metadata]
|
|
1232
|
-
|
|
1233
|
-
def _contains_union(self, cte_node: exp.CTE) -> bool:
    """Check if a CTE contains a UNION (characteristic of recursive CTEs).

    The CTE body (``cte_node.this``) and every node reachable from it through
    ``iter_expressions`` are searched for an ``exp.Union``.

    Args:
        cte_node: The CTE whose body is inspected.

    Returns:
        ``True`` if a UNION node is found, ``False`` otherwise, including when
        the CTE has no body.
    """
    body = cte_node.this
    if not body:
        # Always return a real bool, never the falsy body itself (e.g. None).
        return False

    # Explicit stack instead of recursion, so deeply nested query trees cannot
    # exhaust Python's recursion limit.
    pending = [body]
    while pending:
        node = pending.pop()
        if isinstance(node, exp.Union):
            return True
        pending.extend(node.iter_expressions())
    return False
|
|
1242
|
-
|
|
1243
|
-
def clear_parameters(self) -> None:
    """Reset all extraction state held by this instance.

    Fresh empty lists are bound to the parameter and metadata attributes, so
    the previous list objects are left unmodified. The parameter counter goes
    back to zero.
    """
    self.extracted_parameters, self._parameter_counter, self._parameter_metadata = list(), 0, list()
|