sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Replaces literals in SQL with placeholders and extracts them using SQLGlot AST."""
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Optional
|
|
4
|
+
from typing import Any, Optional, Union
|
|
5
5
|
|
|
6
6
|
from sqlglot import exp
|
|
7
7
|
from sqlglot.expressions import Array, Binary, Boolean, DataType, Func, Literal, Null
|
|
8
8
|
|
|
9
|
-
from sqlspec.
|
|
10
|
-
from sqlspec.statement.
|
|
9
|
+
from sqlspec.protocols import ProcessorProtocol
|
|
10
|
+
from sqlspec.statement.parameters import ParameterStyle, TypedParameter
|
|
11
11
|
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
12
12
|
|
|
13
13
|
__all__ = ("ParameterizationContext", "ParameterizeLiterals")
|
|
@@ -24,6 +24,12 @@ DEFAULT_MAX_ARRAY_LENGTH = 100
|
|
|
24
24
|
DEFAULT_MAX_IN_LIST_SIZE = 50
|
|
25
25
|
"""Default maximum IN clause list size before parameterization."""
|
|
26
26
|
|
|
27
|
+
MAX_ENUM_LENGTH = 50
|
|
28
|
+
"""Maximum length for enum-like string values."""
|
|
29
|
+
|
|
30
|
+
MIN_ENUM_LENGTH = 2
|
|
31
|
+
"""Minimum length for enum-like string values to be meaningful."""
|
|
32
|
+
|
|
27
33
|
|
|
28
34
|
@dataclass
|
|
29
35
|
class ParameterizationContext:
|
|
@@ -34,7 +40,13 @@ class ParameterizationContext:
|
|
|
34
40
|
in_case_when: bool = False
|
|
35
41
|
in_array: bool = False
|
|
36
42
|
in_in_clause: bool = False
|
|
43
|
+
in_recursive_cte: bool = False
|
|
44
|
+
in_subquery: bool = False
|
|
45
|
+
in_select_list: bool = False
|
|
46
|
+
in_join_condition: bool = False
|
|
37
47
|
function_depth: int = 0
|
|
48
|
+
cte_depth: int = 0
|
|
49
|
+
subquery_depth: int = 0
|
|
38
50
|
|
|
39
51
|
|
|
40
52
|
class ParameterizeLiterals(ProcessorProtocol):
|
|
@@ -53,6 +65,7 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
53
65
|
preserve_boolean: Whether to preserve boolean literals as-is.
|
|
54
66
|
preserve_numbers_in_limit: Whether to preserve numbers in LIMIT/OFFSET clauses.
|
|
55
67
|
preserve_in_functions: List of function names where literals should be preserved.
|
|
68
|
+
preserve_in_recursive_cte: Whether to preserve literals in recursive CTEs (default True to avoid type inference issues).
|
|
56
69
|
parameterize_arrays: Whether to parameterize array literals.
|
|
57
70
|
parameterize_in_lists: Whether to parameterize IN clause lists.
|
|
58
71
|
max_string_length: Maximum string length to parameterize.
|
|
@@ -68,6 +81,7 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
68
81
|
preserve_boolean: bool = True,
|
|
69
82
|
preserve_numbers_in_limit: bool = True,
|
|
70
83
|
preserve_in_functions: Optional[list[str]] = None,
|
|
84
|
+
preserve_in_recursive_cte: bool = True,
|
|
71
85
|
parameterize_arrays: bool = True,
|
|
72
86
|
parameterize_in_lists: bool = True,
|
|
73
87
|
max_string_length: int = DEFAULT_MAX_STRING_LENGTH,
|
|
@@ -79,7 +93,19 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
79
93
|
self.preserve_null = preserve_null
|
|
80
94
|
self.preserve_boolean = preserve_boolean
|
|
81
95
|
self.preserve_numbers_in_limit = preserve_numbers_in_limit
|
|
82
|
-
self.
|
|
96
|
+
self.preserve_in_recursive_cte = preserve_in_recursive_cte
|
|
97
|
+
self.preserve_in_functions = preserve_in_functions or [
|
|
98
|
+
"COALESCE",
|
|
99
|
+
"IFNULL",
|
|
100
|
+
"NVL",
|
|
101
|
+
"ISNULL",
|
|
102
|
+
# Array functions that take dimension arguments
|
|
103
|
+
"ARRAYSIZE", # SQLglot converts array_length to ArraySize
|
|
104
|
+
"ARRAY_UPPER",
|
|
105
|
+
"ARRAY_LOWER",
|
|
106
|
+
"ARRAY_NDIMS",
|
|
107
|
+
"ROUND",
|
|
108
|
+
]
|
|
83
109
|
self.parameterize_arrays = parameterize_arrays
|
|
84
110
|
self.parameterize_in_lists = parameterize_in_lists
|
|
85
111
|
self.max_string_length = max_string_length
|
|
@@ -89,20 +115,64 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
89
115
|
self.extracted_parameters: list[Any] = []
|
|
90
116
|
self._parameter_counter = 0
|
|
91
117
|
self._parameter_metadata: list[dict[str, Any]] = [] # Track parameter types and context
|
|
118
|
+
self._preserve_dict_format = False # Track whether to preserve dict format
|
|
92
119
|
|
|
93
120
|
def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
|
|
94
121
|
"""Advanced literal parameterization with context-aware AST analysis."""
|
|
95
|
-
if expression is None or context.current_expression is None
|
|
122
|
+
if expression is None or context.current_expression is None:
|
|
123
|
+
return expression
|
|
124
|
+
|
|
125
|
+
# For named parameters (like BigQuery @param), don't reorder to avoid breaking name mapping
|
|
126
|
+
if (
|
|
127
|
+
context.config.input_sql_had_placeholders
|
|
128
|
+
and context.parameter_info
|
|
129
|
+
and any(p.name for p in context.parameter_info)
|
|
130
|
+
):
|
|
96
131
|
return expression
|
|
97
132
|
|
|
98
133
|
self.extracted_parameters = []
|
|
99
|
-
self._parameter_counter = 0
|
|
100
134
|
self._parameter_metadata = []
|
|
101
135
|
|
|
136
|
+
# When reordering is needed (SQL already has placeholders), we need to start
|
|
137
|
+
# our counter at the number of existing parameters to avoid conflicts
|
|
138
|
+
if context.config.input_sql_had_placeholders and context.parameter_info:
|
|
139
|
+
# Find the highest ordinal among existing parameters
|
|
140
|
+
max_ordinal = max(p.ordinal for p in context.parameter_info)
|
|
141
|
+
self._parameter_counter = max_ordinal + 1
|
|
142
|
+
else:
|
|
143
|
+
self._parameter_counter = 0
|
|
144
|
+
|
|
145
|
+
# Track original user parameters for proper merging
|
|
146
|
+
self._original_params = context.merged_parameters
|
|
147
|
+
self._user_param_index = 0
|
|
148
|
+
# If original params are dict and we have named placeholders, preserve dict format
|
|
149
|
+
if isinstance(context.merged_parameters, dict) and context.parameter_info:
|
|
150
|
+
# Check if we have named placeholders
|
|
151
|
+
has_named = any(p.name for p in context.parameter_info)
|
|
152
|
+
if has_named:
|
|
153
|
+
self._final_params: Union[dict[str, Any], list[Any]] = {}
|
|
154
|
+
self._preserve_dict_format = True
|
|
155
|
+
else:
|
|
156
|
+
self._final_params = []
|
|
157
|
+
self._preserve_dict_format = False
|
|
158
|
+
else:
|
|
159
|
+
self._final_params = []
|
|
160
|
+
self._preserve_dict_format = False
|
|
161
|
+
self._is_reordering_needed = context.config.input_sql_had_placeholders
|
|
162
|
+
|
|
102
163
|
param_context = ParameterizationContext(parent_stack=[])
|
|
103
164
|
transformed_expression = self._transform_with_context(context.current_expression.copy(), param_context)
|
|
104
165
|
context.current_expression = transformed_expression
|
|
105
|
-
|
|
166
|
+
|
|
167
|
+
# If we're reordering, update the merged parameters with the reordered result
|
|
168
|
+
# In this case, we don't need to add to extracted_parameters_from_pipeline
|
|
169
|
+
# because the parameters are already in _final_params
|
|
170
|
+
if self._is_reordering_needed and self._final_params:
|
|
171
|
+
context.merged_parameters = self._final_params
|
|
172
|
+
else:
|
|
173
|
+
# Only add extracted parameters to the pipeline if we're not reordering
|
|
174
|
+
# This prevents duplication when parameters are already in merged_parameters
|
|
175
|
+
context.extracted_parameters_from_pipeline.extend(self.extracted_parameters)
|
|
106
176
|
|
|
107
177
|
context.metadata["parameter_metadata"] = self._parameter_metadata
|
|
108
178
|
|
|
@@ -123,6 +193,32 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
123
193
|
result = self._process_array(node, context)
|
|
124
194
|
elif isinstance(node, exp.In) and self.parameterize_in_lists:
|
|
125
195
|
result = self._process_in_clause(node, context)
|
|
196
|
+
elif isinstance(node, exp.Placeholder) and self._is_reordering_needed:
|
|
197
|
+
# Handle existing placeholders when reordering is needed
|
|
198
|
+
result = self._process_existing_placeholder(node, context)
|
|
199
|
+
elif isinstance(node, exp.Parameter) and self._is_reordering_needed:
|
|
200
|
+
# Handle PostgreSQL-style parameters ($1, $2) when reordering is needed
|
|
201
|
+
result = self._process_existing_parameter(node, context)
|
|
202
|
+
elif isinstance(node, exp.Column) and self._is_reordering_needed:
|
|
203
|
+
# Check if this column looks like a PostgreSQL parameter ($1, $2, etc.)
|
|
204
|
+
column_name = str(node.this) if hasattr(node, "this") else ""
|
|
205
|
+
if column_name.startswith("$") and column_name[1:].isdigit():
|
|
206
|
+
# This is a PostgreSQL-style parameter parsed as a column
|
|
207
|
+
result = self._process_postgresql_column_parameter(node, context)
|
|
208
|
+
else:
|
|
209
|
+
# Regular column - process children
|
|
210
|
+
for key, value in node.args.items():
|
|
211
|
+
if isinstance(value, exp.Expression):
|
|
212
|
+
node.set(key, self._transform_with_context(value, context))
|
|
213
|
+
elif isinstance(value, list):
|
|
214
|
+
node.set(
|
|
215
|
+
key,
|
|
216
|
+
[
|
|
217
|
+
self._transform_with_context(v, context) if isinstance(v, exp.Expression) else v
|
|
218
|
+
for v in value
|
|
219
|
+
],
|
|
220
|
+
)
|
|
221
|
+
result = node
|
|
126
222
|
else:
|
|
127
223
|
# Recursively process children
|
|
128
224
|
for key, value in node.args.items():
|
|
@@ -146,36 +242,136 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
146
242
|
def _update_context(self, node: exp.Expression, context: ParameterizationContext, entering: bool) -> None:
|
|
147
243
|
"""Update parameterization context based on current AST node."""
|
|
148
244
|
if entering:
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
if isinstance(node, Func):
|
|
152
|
-
context.function_depth += 1
|
|
153
|
-
# Get function name from class name or node.name
|
|
154
|
-
func_name = node.__class__.__name__.upper()
|
|
155
|
-
if func_name in self.preserve_in_functions or (
|
|
156
|
-
node.name and node.name.upper() in self.preserve_in_functions
|
|
157
|
-
):
|
|
158
|
-
context.in_function_args = True
|
|
159
|
-
elif isinstance(node, exp.Case):
|
|
160
|
-
context.in_case_when = True
|
|
161
|
-
elif isinstance(node, Array):
|
|
162
|
-
context.in_array = True
|
|
163
|
-
elif isinstance(node, exp.In):
|
|
164
|
-
context.in_in_clause = True
|
|
245
|
+
self._update_context_entering(node, context)
|
|
165
246
|
else:
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
247
|
+
self._update_context_leaving(node, context)
|
|
248
|
+
|
|
249
|
+
def _update_context_entering(self, node: exp.Expression, context: ParameterizationContext) -> None:
|
|
250
|
+
"""Update context when entering a node."""
|
|
251
|
+
context.parent_stack.append(node)
|
|
252
|
+
|
|
253
|
+
if isinstance(node, Func):
|
|
254
|
+
self._update_context_entering_func(node, context)
|
|
255
|
+
elif isinstance(node, exp.Case):
|
|
256
|
+
context.in_case_when = True
|
|
257
|
+
elif isinstance(node, Array):
|
|
258
|
+
context.in_array = True
|
|
259
|
+
elif isinstance(node, exp.In):
|
|
260
|
+
context.in_in_clause = True
|
|
261
|
+
elif isinstance(node, exp.CTE):
|
|
262
|
+
self._update_context_entering_cte(node, context)
|
|
263
|
+
elif isinstance(node, exp.Subquery):
|
|
264
|
+
context.subquery_depth += 1
|
|
265
|
+
context.in_subquery = True
|
|
266
|
+
elif isinstance(node, exp.Select):
|
|
267
|
+
self._update_context_entering_select(node, context)
|
|
268
|
+
elif isinstance(node, exp.Join):
|
|
269
|
+
context.in_join_condition = True
|
|
270
|
+
|
|
271
|
+
def _update_context_entering_func(self, node: Func, context: ParameterizationContext) -> None:
|
|
272
|
+
"""Update context when entering a function node."""
|
|
273
|
+
context.function_depth += 1
|
|
274
|
+
# Get function name from class name or node.name
|
|
275
|
+
func_name = node.__class__.__name__.upper()
|
|
276
|
+
if func_name in self.preserve_in_functions or (node.name and node.name.upper() in self.preserve_in_functions):
|
|
277
|
+
context.in_function_args = True
|
|
278
|
+
|
|
279
|
+
def _update_context_entering_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
|
|
280
|
+
"""Update context when entering a CTE node."""
|
|
281
|
+
context.cte_depth += 1
|
|
282
|
+
# Check if this CTE is recursive:
|
|
283
|
+
# 1. Parent WITH must be RECURSIVE
|
|
284
|
+
# 2. CTE must contain UNION (characteristic of recursive CTEs)
|
|
285
|
+
is_in_recursive_with = any(
|
|
286
|
+
isinstance(parent, exp.With) and parent.args.get("recursive", False)
|
|
287
|
+
for parent in reversed(context.parent_stack)
|
|
288
|
+
)
|
|
289
|
+
if is_in_recursive_with and self._contains_union(node):
|
|
290
|
+
context.in_recursive_cte = True
|
|
291
|
+
|
|
292
|
+
def _update_context_entering_select(self, node: exp.Select, context: ParameterizationContext) -> None:
|
|
293
|
+
"""Update context when entering a SELECT node."""
|
|
294
|
+
# Only track nested SELECT statements as subqueries if they're not part of a recursive CTE
|
|
295
|
+
is_in_recursive_cte = any(
|
|
296
|
+
isinstance(parent, exp.CTE)
|
|
297
|
+
and any(
|
|
298
|
+
isinstance(grandparent, exp.With) and grandparent.args.get("recursive", False)
|
|
299
|
+
for grandparent in context.parent_stack
|
|
300
|
+
)
|
|
301
|
+
for parent in context.parent_stack[:-1]
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
if not is_in_recursive_cte and any(
|
|
305
|
+
isinstance(parent, (exp.Select, exp.Subquery, exp.CTE))
|
|
306
|
+
for parent in context.parent_stack[:-1] # Exclude the current node
|
|
307
|
+
):
|
|
308
|
+
context.subquery_depth += 1
|
|
309
|
+
context.in_subquery = True
|
|
310
|
+
# Check if we're in a SELECT clause expressions list
|
|
311
|
+
if hasattr(node, "expressions"):
|
|
312
|
+
# We'll handle this specifically when processing individual expressions
|
|
313
|
+
context.in_select_list = False # Will be detected by _is_in_select_expressions
|
|
314
|
+
|
|
315
|
+
def _update_context_leaving(self, node: exp.Expression, context: ParameterizationContext) -> None:
|
|
316
|
+
"""Update context when leaving a node."""
|
|
317
|
+
if context.parent_stack:
|
|
318
|
+
context.parent_stack.pop()
|
|
319
|
+
|
|
320
|
+
if isinstance(node, Func):
|
|
321
|
+
self._update_context_leaving_func(node, context)
|
|
322
|
+
elif isinstance(node, exp.Case):
|
|
323
|
+
context.in_case_when = False
|
|
324
|
+
elif isinstance(node, Array):
|
|
325
|
+
context.in_array = False
|
|
326
|
+
elif isinstance(node, exp.In):
|
|
327
|
+
context.in_in_clause = False
|
|
328
|
+
elif isinstance(node, exp.CTE):
|
|
329
|
+
self._update_context_leaving_cte(node, context)
|
|
330
|
+
elif isinstance(node, exp.Subquery):
|
|
331
|
+
self._update_context_leaving_subquery(node, context)
|
|
332
|
+
elif isinstance(node, exp.Select):
|
|
333
|
+
self._update_context_leaving_select(node, context)
|
|
334
|
+
elif isinstance(node, exp.Join):
|
|
335
|
+
context.in_join_condition = False
|
|
336
|
+
|
|
337
|
+
def _update_context_leaving_func(self, node: Func, context: ParameterizationContext) -> None:
|
|
338
|
+
"""Update context when leaving a function node."""
|
|
339
|
+
context.function_depth -= 1
|
|
340
|
+
if context.function_depth == 0:
|
|
341
|
+
context.in_function_args = False
|
|
342
|
+
|
|
343
|
+
def _update_context_leaving_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
|
|
344
|
+
"""Update context when leaving a CTE node."""
|
|
345
|
+
context.cte_depth -= 1
|
|
346
|
+
if context.cte_depth == 0:
|
|
347
|
+
context.in_recursive_cte = False
|
|
348
|
+
|
|
349
|
+
def _update_context_leaving_subquery(self, node: exp.Subquery, context: ParameterizationContext) -> None:
|
|
350
|
+
"""Update context when leaving a subquery node."""
|
|
351
|
+
context.subquery_depth -= 1
|
|
352
|
+
if context.subquery_depth == 0:
|
|
353
|
+
context.in_subquery = False
|
|
354
|
+
|
|
355
|
+
def _update_context_leaving_select(self, node: exp.Select, context: ParameterizationContext) -> None:
|
|
356
|
+
"""Update context when leaving a SELECT node."""
|
|
357
|
+
# Only decrement if this was a nested SELECT (not part of recursive CTE)
|
|
358
|
+
is_in_recursive_cte = any(
|
|
359
|
+
isinstance(parent, exp.CTE)
|
|
360
|
+
and any(
|
|
361
|
+
isinstance(grandparent, exp.With) and grandparent.args.get("recursive", False)
|
|
362
|
+
for grandparent in context.parent_stack
|
|
363
|
+
)
|
|
364
|
+
for parent in context.parent_stack[:-1]
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
if not is_in_recursive_cte and any(
|
|
368
|
+
isinstance(parent, (exp.Select, exp.Subquery, exp.CTE))
|
|
369
|
+
for parent in context.parent_stack[:-1] # Exclude current node
|
|
370
|
+
):
|
|
371
|
+
context.subquery_depth -= 1
|
|
372
|
+
if context.subquery_depth == 0:
|
|
373
|
+
context.in_subquery = False
|
|
374
|
+
context.in_select_list = False
|
|
179
375
|
|
|
180
376
|
def _process_literal_with_context(
|
|
181
377
|
self, literal: exp.Expression, context: ParameterizationContext
|
|
@@ -198,35 +394,384 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
198
394
|
semantic_name=semantic_name,
|
|
199
395
|
)
|
|
200
396
|
|
|
201
|
-
#
|
|
397
|
+
# Always track extracted parameters for proper merging
|
|
202
398
|
self.extracted_parameters.append(typed_param)
|
|
399
|
+
|
|
400
|
+
# If we're reordering, also add to final params directly
|
|
401
|
+
if self._is_reordering_needed:
|
|
402
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
403
|
+
# For dict format, we need a key
|
|
404
|
+
param_key = semantic_name or f"param_{len(self._final_params)}"
|
|
405
|
+
self._final_params[param_key] = typed_param
|
|
406
|
+
elif isinstance(self._final_params, list):
|
|
407
|
+
self._final_params.append(typed_param)
|
|
408
|
+
else:
|
|
409
|
+
# Fallback - this shouldn't happen but handle gracefully
|
|
410
|
+
if not hasattr(self, "_fallback_params"):
|
|
411
|
+
self._fallback_params = []
|
|
412
|
+
self._fallback_params.append(typed_param)
|
|
413
|
+
|
|
203
414
|
self._parameter_metadata.append(
|
|
204
415
|
{
|
|
205
|
-
"index": len(self.extracted_parameters) - 1,
|
|
416
|
+
"index": len(self._final_params if self._is_reordering_needed else self.extracted_parameters) - 1,
|
|
206
417
|
"type": type_hint,
|
|
207
418
|
"semantic_name": semantic_name,
|
|
208
419
|
"context": self._get_context_description(context),
|
|
209
|
-
# Note: We avoid calling literal.sql() for performance
|
|
210
420
|
}
|
|
211
421
|
)
|
|
212
422
|
|
|
213
423
|
# Create appropriate placeholder
|
|
214
424
|
return self._create_placeholder(hint=semantic_name)
|
|
215
425
|
|
|
426
|
+
def _process_existing_placeholder(self, node: exp.Placeholder, context: ParameterizationContext) -> exp.Expression:
|
|
427
|
+
"""Process an existing placeholder when reordering parameters."""
|
|
428
|
+
if self._original_params is None:
|
|
429
|
+
return node
|
|
430
|
+
|
|
431
|
+
if isinstance(self._original_params, (list, tuple)):
|
|
432
|
+
self._handle_list_params_for_placeholder(node)
|
|
433
|
+
elif isinstance(self._original_params, dict):
|
|
434
|
+
self._handle_dict_params_for_placeholder(node)
|
|
435
|
+
else:
|
|
436
|
+
# Single value parameter
|
|
437
|
+
self._handle_single_value_param_for_placeholder(node)
|
|
438
|
+
|
|
439
|
+
return node
|
|
440
|
+
|
|
441
|
+
def _handle_list_params_for_placeholder(self, node: exp.Placeholder) -> None:
|
|
442
|
+
"""Handle list/tuple parameters for placeholder."""
|
|
443
|
+
if isinstance(self._original_params, (list, tuple)) and self._user_param_index < len(self._original_params):
|
|
444
|
+
value = self._original_params[self._user_param_index]
|
|
445
|
+
self._add_to_final_params(value, node)
|
|
446
|
+
self._user_param_index += 1
|
|
447
|
+
else:
|
|
448
|
+
# More placeholders than user parameters
|
|
449
|
+
self._add_to_final_params(None, node)
|
|
450
|
+
|
|
451
|
+
def _handle_dict_params_for_placeholder(self, node: exp.Placeholder) -> None:
|
|
452
|
+
"""Handle dict parameters for placeholder."""
|
|
453
|
+
if not isinstance(self._original_params, dict):
|
|
454
|
+
self._add_to_final_params(None, node)
|
|
455
|
+
return
|
|
456
|
+
|
|
457
|
+
raw_placeholder_name = node.this if hasattr(node, "this") else None
|
|
458
|
+
if not raw_placeholder_name:
|
|
459
|
+
# Unnamed placeholder '?' with dict params is ambiguous
|
|
460
|
+
self._add_to_final_params(None, node)
|
|
461
|
+
return
|
|
462
|
+
|
|
463
|
+
# FIX: Normalize the placeholder name by stripping leading sigils
|
|
464
|
+
placeholder_name = raw_placeholder_name.lstrip(":@")
|
|
465
|
+
|
|
466
|
+
# Debug logging
|
|
467
|
+
|
|
468
|
+
if placeholder_name in self._original_params:
|
|
469
|
+
# Direct match for placeholder name
|
|
470
|
+
self._add_to_final_params(self._original_params[placeholder_name], node)
|
|
471
|
+
elif placeholder_name.isdigit() and self._user_param_index == 0:
|
|
472
|
+
# Oracle-style numeric parameters
|
|
473
|
+
self._handle_oracle_numeric_params()
|
|
474
|
+
self._user_param_index += 1
|
|
475
|
+
elif placeholder_name.isdigit() and self._user_param_index > 0:
|
|
476
|
+
# Already handled Oracle params
|
|
477
|
+
pass
|
|
478
|
+
elif self._user_param_index == 0 and len(self._original_params) > 0:
|
|
479
|
+
# Single dict parameter case
|
|
480
|
+
self._handle_single_dict_param()
|
|
481
|
+
self._user_param_index += 1
|
|
482
|
+
else:
|
|
483
|
+
# No match found
|
|
484
|
+
self._add_to_final_params(None, node)
|
|
485
|
+
|
|
486
|
+
def _handle_single_value_param_for_placeholder(self, node: exp.Placeholder) -> None:
|
|
487
|
+
"""Handle single value parameter for placeholder."""
|
|
488
|
+
if self._user_param_index == 0:
|
|
489
|
+
self._add_to_final_params(self._original_params, node)
|
|
490
|
+
self._user_param_index += 1
|
|
491
|
+
else:
|
|
492
|
+
self._add_to_final_params(None, node)
|
|
493
|
+
|
|
494
|
+
def _handle_oracle_numeric_params(self) -> None:
|
|
495
|
+
"""Handle Oracle-style numeric parameters."""
|
|
496
|
+
if not isinstance(self._original_params, dict):
|
|
497
|
+
return
|
|
498
|
+
|
|
499
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
500
|
+
for k, v in self._original_params.items():
|
|
501
|
+
if k.isdigit():
|
|
502
|
+
self._final_params[k] = v
|
|
503
|
+
else:
|
|
504
|
+
# Convert to positional list
|
|
505
|
+
numeric_keys = [k for k in self._original_params if k.isdigit()]
|
|
506
|
+
if numeric_keys:
|
|
507
|
+
max_index = max(int(k) for k in numeric_keys)
|
|
508
|
+
param_list = [None] * (max_index + 1)
|
|
509
|
+
for k, v in self._original_params.items():
|
|
510
|
+
if k.isdigit():
|
|
511
|
+
param_list[int(k)] = v
|
|
512
|
+
if isinstance(self._final_params, list):
|
|
513
|
+
self._final_params.extend(param_list)
|
|
514
|
+
elif isinstance(self._final_params, dict):
|
|
515
|
+
for i, val in enumerate(param_list):
|
|
516
|
+
self._final_params[str(i)] = val
|
|
517
|
+
|
|
518
|
+
def _handle_single_dict_param(self) -> None:
|
|
519
|
+
"""Handle single dict parameter case."""
|
|
520
|
+
if not isinstance(self._original_params, dict):
|
|
521
|
+
return
|
|
522
|
+
|
|
523
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
524
|
+
for k, v in self._original_params.items():
|
|
525
|
+
self._final_params[k] = v
|
|
526
|
+
elif isinstance(self._final_params, list):
|
|
527
|
+
self._final_params.append(self._original_params)
|
|
528
|
+
elif isinstance(self._final_params, dict):
|
|
529
|
+
param_name = f"param_{len(self._final_params)}"
|
|
530
|
+
self._final_params[param_name] = self._original_params
|
|
531
|
+
|
|
532
|
+
def _add_to_final_params(self, value: Any, node: exp.Placeholder) -> None:
|
|
533
|
+
"""Add a value to final params with proper type handling."""
|
|
534
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
535
|
+
placeholder_name = node.this if hasattr(node, "this") else f"param_{self._user_param_index}"
|
|
536
|
+
self._final_params[placeholder_name] = value
|
|
537
|
+
elif isinstance(self._final_params, list):
|
|
538
|
+
self._final_params.append(value)
|
|
539
|
+
elif isinstance(self._final_params, dict):
|
|
540
|
+
param_name = f"param_{len(self._final_params)}"
|
|
541
|
+
self._final_params[param_name] = value
|
|
542
|
+
|
|
543
|
+
def _process_existing_parameter(self, node: exp.Parameter, context: ParameterizationContext) -> exp.Expression:
|
|
544
|
+
"""Process existing parameters (both numeric and named) when reordering parameters."""
|
|
545
|
+
# First try to get parameter name for named parameters (like BigQuery @param_name)
|
|
546
|
+
param_name = self._extract_parameter_name(node)
|
|
547
|
+
|
|
548
|
+
if param_name and isinstance(self._original_params, dict) and param_name in self._original_params:
|
|
549
|
+
value = self._original_params[param_name]
|
|
550
|
+
self._add_param_value_to_finals(value)
|
|
551
|
+
return node
|
|
552
|
+
|
|
553
|
+
# Fall back to numeric parameter handling for PostgreSQL-style parameters ($1, $2)
|
|
554
|
+
param_index = self._extract_parameter_index(node)
|
|
555
|
+
|
|
556
|
+
if self._original_params is None:
|
|
557
|
+
self._add_none_to_final_params()
|
|
558
|
+
elif isinstance(self._original_params, (list, tuple)):
|
|
559
|
+
self._handle_list_params_for_parameter_node(param_index)
|
|
560
|
+
elif isinstance(self._original_params, dict):
|
|
561
|
+
self._handle_dict_params_for_parameter_node(param_index)
|
|
562
|
+
elif param_index == 0:
|
|
563
|
+
# Single parameter case
|
|
564
|
+
self._add_param_value_to_finals(self._original_params)
|
|
565
|
+
else:
|
|
566
|
+
self._add_none_to_final_params()
|
|
567
|
+
|
|
568
|
+
# Return the parameter unchanged
|
|
569
|
+
return node
|
|
570
|
+
|
|
571
|
+
@staticmethod
|
|
572
|
+
def _extract_parameter_name(node: exp.Parameter) -> Optional[str]:
|
|
573
|
+
"""Extract parameter name from a Parameter node for named parameters."""
|
|
574
|
+
if hasattr(node, "this"):
|
|
575
|
+
if isinstance(node.this, exp.Var):
|
|
576
|
+
# Named parameter like @min_value -> min_value
|
|
577
|
+
return str(node.this.this)
|
|
578
|
+
if hasattr(node.this, "this"):
|
|
579
|
+
# Handle other node types that might contain the name
|
|
580
|
+
return str(node.this.this)
|
|
581
|
+
return None
|
|
582
|
+
|
|
583
|
+
@staticmethod
|
|
584
|
+
def _extract_parameter_index(node: exp.Parameter) -> Optional[int]:
|
|
585
|
+
"""Extract parameter index from a Parameter node."""
|
|
586
|
+
if hasattr(node, "this") and isinstance(node.this, Literal):
|
|
587
|
+
import contextlib
|
|
588
|
+
|
|
589
|
+
with contextlib.suppress(ValueError, TypeError):
|
|
590
|
+
return int(node.this.this) - 1 # Convert to 0-based index
|
|
591
|
+
return None
|
|
592
|
+
|
|
593
|
+
def _handle_list_params_for_parameter_node(self, param_index: Optional[int]) -> None:
|
|
594
|
+
"""Handle list/tuple parameters for Parameter node."""
|
|
595
|
+
if (
|
|
596
|
+
isinstance(self._original_params, (list, tuple))
|
|
597
|
+
and param_index is not None
|
|
598
|
+
and 0 <= param_index < len(self._original_params)
|
|
599
|
+
):
|
|
600
|
+
# Use the parameter at the specified index
|
|
601
|
+
self._add_param_value_to_finals(self._original_params[param_index])
|
|
602
|
+
else:
|
|
603
|
+
# More parameters than user provided
|
|
604
|
+
self._add_none_to_final_params()
|
|
605
|
+
|
|
606
|
+
def _handle_dict_params_for_parameter_node(self, param_index: Optional[int]) -> None:
|
|
607
|
+
"""Handle dict parameters for Parameter node."""
|
|
608
|
+
if param_index is not None:
|
|
609
|
+
self._handle_dict_param_with_index(param_index)
|
|
610
|
+
else:
|
|
611
|
+
self._add_none_to_final_params()
|
|
612
|
+
|
|
613
|
+
def _handle_dict_param_with_index(self, param_index: int) -> None:
|
|
614
|
+
"""Handle dict parameter when we have an index."""
|
|
615
|
+
if not isinstance(self._original_params, dict):
|
|
616
|
+
self._add_none_to_final_params()
|
|
617
|
+
return
|
|
618
|
+
|
|
619
|
+
# Try param_N key first
|
|
620
|
+
param_key = f"param_{param_index}"
|
|
621
|
+
if param_key in self._original_params:
|
|
622
|
+
self._add_dict_value_to_finals(param_key)
|
|
623
|
+
return
|
|
624
|
+
|
|
625
|
+
# Try direct numeric key (1-based)
|
|
626
|
+
numeric_key = str(param_index + 1)
|
|
627
|
+
if numeric_key in self._original_params:
|
|
628
|
+
self._add_dict_value_to_finals(numeric_key)
|
|
629
|
+
else:
|
|
630
|
+
self._add_none_to_final_params()
|
|
631
|
+
|
|
632
|
+
def _add_dict_value_to_finals(self, key: str) -> None:
|
|
633
|
+
"""Add a value from dict params to final params."""
|
|
634
|
+
if isinstance(self._original_params, dict) and key in self._original_params:
|
|
635
|
+
value = self._original_params[key]
|
|
636
|
+
if isinstance(self._final_params, list):
|
|
637
|
+
self._final_params.append(value)
|
|
638
|
+
elif isinstance(self._final_params, dict):
|
|
639
|
+
self._final_params[key] = value
|
|
640
|
+
|
|
641
|
+
def _add_param_value_to_finals(self, value: Any) -> None:
|
|
642
|
+
"""Add a parameter value to final params."""
|
|
643
|
+
if isinstance(self._final_params, list):
|
|
644
|
+
self._final_params.append(value)
|
|
645
|
+
elif isinstance(self._final_params, dict):
|
|
646
|
+
param_name = f"param_{len(self._final_params)}"
|
|
647
|
+
self._final_params[param_name] = value
|
|
648
|
+
|
|
649
|
+
def _add_none_to_final_params(self) -> None:
|
|
650
|
+
"""Add None to final params."""
|
|
651
|
+
if isinstance(self._final_params, list):
|
|
652
|
+
self._final_params.append(None)
|
|
653
|
+
elif isinstance(self._final_params, dict):
|
|
654
|
+
param_name = f"param_{len(self._final_params)}"
|
|
655
|
+
self._final_params[param_name] = None
|
|
656
|
+
|
|
657
|
+
def _process_postgresql_column_parameter(
|
|
658
|
+
self, node: exp.Column, context: ParameterizationContext
|
|
659
|
+
) -> exp.Expression:
|
|
660
|
+
"""Process PostgreSQL-style parameters that were parsed as columns ($1, $2)."""
|
|
661
|
+
# Extract the numeric part from $1, $2, etc.
|
|
662
|
+
column_name = str(node.this) if hasattr(node, "this") else ""
|
|
663
|
+
param_index = None
|
|
664
|
+
|
|
665
|
+
if column_name.startswith("$") and column_name[1:].isdigit():
|
|
666
|
+
import contextlib
|
|
667
|
+
|
|
668
|
+
with contextlib.suppress(ValueError, TypeError):
|
|
669
|
+
param_index = int(column_name[1:]) - 1 # Convert to 0-based index
|
|
670
|
+
|
|
671
|
+
if self._original_params is None:
|
|
672
|
+
# No user parameters provided - don't add None
|
|
673
|
+
return node
|
|
674
|
+
if isinstance(self._original_params, (list, tuple)):
|
|
675
|
+
# When we have mixed parameter styles and reordering is needed,
|
|
676
|
+
# use sequential assignment based on _user_param_index
|
|
677
|
+
if self._is_reordering_needed:
|
|
678
|
+
# For mixed styles, parameters should be assigned sequentially
|
|
679
|
+
# regardless of the numeric value in the placeholder
|
|
680
|
+
if self._user_param_index < len(self._original_params):
|
|
681
|
+
param_value = self._original_params[self._user_param_index]
|
|
682
|
+
self._user_param_index += 1
|
|
683
|
+
else:
|
|
684
|
+
param_value = None
|
|
685
|
+
else:
|
|
686
|
+
# Non-mixed styles - use the numeric value from the placeholder
|
|
687
|
+
param_value = (
|
|
688
|
+
self._original_params[param_index]
|
|
689
|
+
if param_index is not None and 0 <= param_index < len(self._original_params)
|
|
690
|
+
else None
|
|
691
|
+
)
|
|
692
|
+
|
|
693
|
+
if param_value is not None:
|
|
694
|
+
# Add the parameter value to final params
|
|
695
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
696
|
+
param_key = f"param_{len(self._final_params)}"
|
|
697
|
+
self._final_params[param_key] = param_value
|
|
698
|
+
elif isinstance(self._final_params, list):
|
|
699
|
+
self._final_params.append(param_value)
|
|
700
|
+
elif isinstance(self._final_params, dict):
|
|
701
|
+
param_name = f"param_{len(self._final_params)}"
|
|
702
|
+
self._final_params[param_name] = param_value
|
|
703
|
+
# More parameters than user provided - don't add None
|
|
704
|
+
elif isinstance(self._original_params, dict):
|
|
705
|
+
# For dict parameters with numeric placeholders, try to map by index
|
|
706
|
+
if param_index is not None:
|
|
707
|
+
param_key = f"param_{param_index}"
|
|
708
|
+
if param_key in self._original_params:
|
|
709
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
710
|
+
self._final_params[param_key] = self._original_params[param_key]
|
|
711
|
+
elif isinstance(self._final_params, list):
|
|
712
|
+
self._final_params.append(self._original_params[param_key])
|
|
713
|
+
elif isinstance(self._final_params, dict):
|
|
714
|
+
self._final_params[param_key] = self._original_params[param_key]
|
|
715
|
+
else:
|
|
716
|
+
# Try direct numeric key
|
|
717
|
+
numeric_key = str(param_index + 1) # 1-based
|
|
718
|
+
if numeric_key in self._original_params:
|
|
719
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
720
|
+
self._final_params[numeric_key] = self._original_params[numeric_key]
|
|
721
|
+
elif isinstance(self._final_params, list):
|
|
722
|
+
self._final_params.append(self._original_params[numeric_key])
|
|
723
|
+
elif isinstance(self._final_params, dict):
|
|
724
|
+
self._final_params[numeric_key] = self._original_params[numeric_key]
|
|
725
|
+
# Single parameter case
|
|
726
|
+
elif param_index == 0:
|
|
727
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
728
|
+
param_key = f"param_{len(self._final_params)}"
|
|
729
|
+
self._final_params[param_key] = self._original_params
|
|
730
|
+
elif isinstance(self._final_params, list):
|
|
731
|
+
self._final_params.append(self._original_params)
|
|
732
|
+
elif isinstance(self._final_params, dict):
|
|
733
|
+
param_name = f"param_{len(self._final_params)}"
|
|
734
|
+
self._final_params[param_name] = self._original_params
|
|
735
|
+
|
|
736
|
+
# Return the column unchanged - it represents the parameter placeholder
|
|
737
|
+
return node
|
|
738
|
+
|
|
216
739
|
def _should_preserve_literal_in_context(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
|
|
217
|
-
"""
|
|
218
|
-
#
|
|
740
|
+
"""Enhanced context-aware decision on literal preservation."""
|
|
741
|
+
# Existing preservation rules (maintain compatibility)
|
|
219
742
|
if self.preserve_null and isinstance(literal, Null):
|
|
220
743
|
return True
|
|
221
744
|
|
|
222
|
-
# Check for boolean values
|
|
223
745
|
if self.preserve_boolean and isinstance(literal, Boolean):
|
|
224
746
|
return True
|
|
225
747
|
|
|
748
|
+
# NEW: Context-based preservation rules
|
|
749
|
+
|
|
750
|
+
# Rule 4: Preserve enum-like literals in subquery lookups (the main fix we need)
|
|
751
|
+
if context.in_subquery and self._is_scalar_lookup_pattern(literal, context):
|
|
752
|
+
return self._is_enum_like_literal(literal)
|
|
753
|
+
|
|
754
|
+
# Existing preservation rules continue...
|
|
755
|
+
|
|
226
756
|
# Check if in preserved function arguments
|
|
227
757
|
if context.in_function_args:
|
|
228
758
|
return True
|
|
229
759
|
|
|
760
|
+
# ENHANCED: Intelligent recursive CTE literal preservation
|
|
761
|
+
if self.preserve_in_recursive_cte and context.in_recursive_cte:
|
|
762
|
+
return self._should_preserve_literal_in_recursive_cte(literal, context)
|
|
763
|
+
|
|
764
|
+
# Check if this literal is being used as an alias value in SELECT
|
|
765
|
+
# e.g., 'computed' as process_status should be preserved
|
|
766
|
+
if hasattr(literal, "parent") and literal.parent:
|
|
767
|
+
parent = literal.parent
|
|
768
|
+
# Check if it's an Alias node and the literal is the expression (not the alias name)
|
|
769
|
+
if isinstance(parent, exp.Alias) and parent.this == literal:
|
|
770
|
+
# Check if this alias is in a SELECT clause
|
|
771
|
+
for ancestor in context.parent_stack:
|
|
772
|
+
if isinstance(ancestor, exp.Select):
|
|
773
|
+
return True
|
|
774
|
+
|
|
230
775
|
# Check parent context more intelligently
|
|
231
776
|
for parent in context.parent_stack:
|
|
232
777
|
# Preserve in schema/DDL contexts
|
|
@@ -255,6 +800,76 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
255
800
|
|
|
256
801
|
return False
|
|
257
802
|
|
|
803
|
+
def _is_in_select_expressions(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
|
|
804
|
+
"""Check if literal is in SELECT clause expressions (critical for type inference)."""
|
|
805
|
+
for parent in reversed(context.parent_stack):
|
|
806
|
+
if isinstance(parent, exp.Select):
|
|
807
|
+
if hasattr(parent, "expressions") and parent.expressions:
|
|
808
|
+
return any(self._literal_is_in_expression_tree(literal, expr) for expr in parent.expressions)
|
|
809
|
+
elif isinstance(parent, (exp.Where, exp.Having, exp.Join)):
|
|
810
|
+
return False
|
|
811
|
+
return False
|
|
812
|
+
|
|
813
|
+
def _is_recursive_computation(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
|
|
814
|
+
"""Check if literal is part of recursive computation logic."""
|
|
815
|
+
# Look for arithmetic operations that are part of recursive logic
|
|
816
|
+
for parent in reversed(context.parent_stack):
|
|
817
|
+
if isinstance(parent, exp.Binary) and parent.key in ("ADD", "SUB", "MUL", "DIV"):
|
|
818
|
+
# Check if this arithmetic is in a SELECT clause of a recursive part
|
|
819
|
+
return self._is_in_select_expressions(literal, context)
|
|
820
|
+
return False
|
|
821
|
+
|
|
822
|
+
def _should_preserve_literal_in_recursive_cte(
|
|
823
|
+
self, literal: exp.Expression, context: ParameterizationContext
|
|
824
|
+
) -> bool:
|
|
825
|
+
"""Intelligent recursive CTE literal preservation based on semantic role."""
|
|
826
|
+
# Preserve SELECT clause literals (type inference critical)
|
|
827
|
+
if self._is_in_select_expressions(literal, context):
|
|
828
|
+
return True
|
|
829
|
+
|
|
830
|
+
# Preserve recursive computation literals (core logic)
|
|
831
|
+
return self._is_recursive_computation(literal, context)
|
|
832
|
+
|
|
833
|
+
def _literal_is_in_expression_tree(self, target_literal: exp.Expression, expr: exp.Expression) -> bool:
|
|
834
|
+
"""Check if target literal is within the given expression tree."""
|
|
835
|
+
if expr == target_literal:
|
|
836
|
+
return True
|
|
837
|
+
# Recursively check child expressions
|
|
838
|
+
return any(child == target_literal for child in expr.iter_expressions())
|
|
839
|
+
|
|
840
|
+
def _is_scalar_lookup_pattern(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
|
|
841
|
+
"""Detect if literal is part of a scalar subquery lookup pattern."""
|
|
842
|
+
# Must be in a subquery for this pattern to apply
|
|
843
|
+
if context.subquery_depth == 0:
|
|
844
|
+
return False
|
|
845
|
+
|
|
846
|
+
# Check if we're in a WHERE clause of a subquery that returns a single column
|
|
847
|
+
# and the literal is being compared against a column
|
|
848
|
+
for parent in reversed(context.parent_stack):
|
|
849
|
+
if isinstance(parent, exp.Where):
|
|
850
|
+
# Look for pattern: WHERE column = 'literal'
|
|
851
|
+
if isinstance(parent.this, exp.Binary) and parent.this.right == literal:
|
|
852
|
+
return isinstance(parent.this.left, exp.Column)
|
|
853
|
+
# Also check for literal on the left side: WHERE 'literal' = column
|
|
854
|
+
if isinstance(parent.this, exp.Binary) and parent.this.left == literal:
|
|
855
|
+
return isinstance(parent.this.right, exp.Column)
|
|
856
|
+
return False
|
|
857
|
+
|
|
858
|
+
def _is_enum_like_literal(self, literal: exp.Expression) -> bool:
    """Heuristically decide whether a literal looks like an enum/identifier constant.

    Only string literals qualify. The text must be longer than
    ``MIN_ENUM_LENGTH`` and at most ``MAX_ENUM_LENGTH`` characters, consist
    solely of alphanumerics and underscores, and must not be a pure number.

    Args:
        literal: The literal node to examine.

    Returns:
        ``True`` if every heuristic holds, ``False`` otherwise.
    """
    if not (isinstance(literal, exp.Literal) and self._is_string_literal(literal)):
        return False

    text = str(literal.this)
    if not MIN_ENUM_LENGTH < len(text) <= MAX_ENUM_LENGTH:
        # Too short to be meaningful, or too long to be an enum value.
        return False

    return text.replace("_", "").isalnum() and not text.isdigit()
|
|
872
|
+
|
|
258
873
|
def _extract_literal_value_and_type(self, literal: exp.Expression) -> tuple[Any, str]:
|
|
259
874
|
"""Extract the Python value and type info from a SQLGlot literal."""
|
|
260
875
|
if isinstance(literal, Null) or literal.this is None:
|
|
@@ -507,7 +1122,6 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
507
1122
|
array_sqlglot_type = exp.DataType.build("ARRAY", expressions=[element_sqlglot_type])
|
|
508
1123
|
|
|
509
1124
|
# Create TypedParameter for the entire array
|
|
510
|
-
from sqlspec.statement.parameters import TypedParameter
|
|
511
1125
|
|
|
512
1126
|
typed_param = TypedParameter(
|
|
513
1127
|
value=array_values,
|
|
@@ -616,6 +1230,16 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
616
1230
|
"""
|
|
617
1231
|
return self._parameter_metadata.copy()
|
|
618
1232
|
|
|
1233
|
+
def _contains_union(self, cte_node: exp.CTE) -> bool:
|
|
1234
|
+
"""Check if a CTE contains a UNION (characteristic of recursive CTEs)."""
|
|
1235
|
+
|
|
1236
|
+
def has_union(node: exp.Expression) -> bool:
|
|
1237
|
+
if isinstance(node, exp.Union):
|
|
1238
|
+
return True
|
|
1239
|
+
return any(has_union(child) for child in node.iter_expressions())
|
|
1240
|
+
|
|
1241
|
+
return cte_node.this and has_union(cte_node.this)
|
|
1242
|
+
|
|
619
1243
|
def clear_parameters(self) -> None:
|
|
620
1244
|
"""Clear the extracted parameters list."""
|
|
621
1245
|
self.extracted_parameters = []
|