sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Replaces literals in SQL with placeholders and extracts them using SQLGlot AST."""
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Optional
|
|
4
|
+
from typing import Any, Optional, Union
|
|
5
5
|
|
|
6
6
|
from sqlglot import exp
|
|
7
7
|
from sqlglot.expressions import Array, Binary, Boolean, DataType, Func, Literal, Null
|
|
8
8
|
|
|
9
|
-
from sqlspec.
|
|
10
|
-
from sqlspec.statement.
|
|
9
|
+
from sqlspec.protocols import ProcessorProtocol
|
|
10
|
+
from sqlspec.statement.parameters import ParameterStyle, TypedParameter
|
|
11
11
|
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
12
12
|
|
|
13
13
|
__all__ = ("ParameterizationContext", "ParameterizeLiterals")
|
|
@@ -24,6 +24,12 @@ DEFAULT_MAX_ARRAY_LENGTH = 100
|
|
|
24
24
|
DEFAULT_MAX_IN_LIST_SIZE = 50
|
|
25
25
|
"""Default maximum IN clause list size before parameterization."""
|
|
26
26
|
|
|
27
|
+
MAX_ENUM_LENGTH = 50
|
|
28
|
+
"""Maximum length for enum-like string values."""
|
|
29
|
+
|
|
30
|
+
MIN_ENUM_LENGTH = 2
|
|
31
|
+
"""Minimum length for enum-like string values to be meaningful."""
|
|
32
|
+
|
|
27
33
|
|
|
28
34
|
@dataclass
|
|
29
35
|
class ParameterizationContext:
|
|
@@ -35,8 +41,12 @@ class ParameterizationContext:
|
|
|
35
41
|
in_array: bool = False
|
|
36
42
|
in_in_clause: bool = False
|
|
37
43
|
in_recursive_cte: bool = False
|
|
44
|
+
in_subquery: bool = False
|
|
45
|
+
in_select_list: bool = False
|
|
46
|
+
in_join_condition: bool = False
|
|
38
47
|
function_depth: int = 0
|
|
39
48
|
cte_depth: int = 0
|
|
49
|
+
subquery_depth: int = 0
|
|
40
50
|
|
|
41
51
|
|
|
42
52
|
class ParameterizeLiterals(ProcessorProtocol):
|
|
@@ -94,6 +104,7 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
94
104
|
"ARRAY_UPPER",
|
|
95
105
|
"ARRAY_LOWER",
|
|
96
106
|
"ARRAY_NDIMS",
|
|
107
|
+
"ROUND",
|
|
97
108
|
]
|
|
98
109
|
self.parameterize_arrays = parameterize_arrays
|
|
99
110
|
self.parameterize_in_lists = parameterize_in_lists
|
|
@@ -104,20 +115,64 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
104
115
|
self.extracted_parameters: list[Any] = []
|
|
105
116
|
self._parameter_counter = 0
|
|
106
117
|
self._parameter_metadata: list[dict[str, Any]] = [] # Track parameter types and context
|
|
118
|
+
self._preserve_dict_format = False # Track whether to preserve dict format
|
|
107
119
|
|
|
108
120
|
def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
|
|
109
121
|
"""Advanced literal parameterization with context-aware AST analysis."""
|
|
110
|
-
if expression is None or context.current_expression is None
|
|
122
|
+
if expression is None or context.current_expression is None:
|
|
123
|
+
return expression
|
|
124
|
+
|
|
125
|
+
# For named parameters (like BigQuery @param), don't reorder to avoid breaking name mapping
|
|
126
|
+
if (
|
|
127
|
+
context.config.input_sql_had_placeholders
|
|
128
|
+
and context.parameter_info
|
|
129
|
+
and any(p.name for p in context.parameter_info)
|
|
130
|
+
):
|
|
111
131
|
return expression
|
|
112
132
|
|
|
113
133
|
self.extracted_parameters = []
|
|
114
|
-
self._parameter_counter = 0
|
|
115
134
|
self._parameter_metadata = []
|
|
116
135
|
|
|
136
|
+
# When reordering is needed (SQL already has placeholders), we need to start
|
|
137
|
+
# our counter at the number of existing parameters to avoid conflicts
|
|
138
|
+
if context.config.input_sql_had_placeholders and context.parameter_info:
|
|
139
|
+
# Find the highest ordinal among existing parameters
|
|
140
|
+
max_ordinal = max(p.ordinal for p in context.parameter_info)
|
|
141
|
+
self._parameter_counter = max_ordinal + 1
|
|
142
|
+
else:
|
|
143
|
+
self._parameter_counter = 0
|
|
144
|
+
|
|
145
|
+
# Track original user parameters for proper merging
|
|
146
|
+
self._original_params = context.merged_parameters
|
|
147
|
+
self._user_param_index = 0
|
|
148
|
+
# If original params are dict and we have named placeholders, preserve dict format
|
|
149
|
+
if isinstance(context.merged_parameters, dict) and context.parameter_info:
|
|
150
|
+
# Check if we have named placeholders
|
|
151
|
+
has_named = any(p.name for p in context.parameter_info)
|
|
152
|
+
if has_named:
|
|
153
|
+
self._final_params: Union[dict[str, Any], list[Any]] = {}
|
|
154
|
+
self._preserve_dict_format = True
|
|
155
|
+
else:
|
|
156
|
+
self._final_params = []
|
|
157
|
+
self._preserve_dict_format = False
|
|
158
|
+
else:
|
|
159
|
+
self._final_params = []
|
|
160
|
+
self._preserve_dict_format = False
|
|
161
|
+
self._is_reordering_needed = context.config.input_sql_had_placeholders
|
|
162
|
+
|
|
117
163
|
param_context = ParameterizationContext(parent_stack=[])
|
|
118
164
|
transformed_expression = self._transform_with_context(context.current_expression.copy(), param_context)
|
|
119
165
|
context.current_expression = transformed_expression
|
|
120
|
-
|
|
166
|
+
|
|
167
|
+
# If we're reordering, update the merged parameters with the reordered result
|
|
168
|
+
# In this case, we don't need to add to extracted_parameters_from_pipeline
|
|
169
|
+
# because the parameters are already in _final_params
|
|
170
|
+
if self._is_reordering_needed and self._final_params:
|
|
171
|
+
context.merged_parameters = self._final_params
|
|
172
|
+
else:
|
|
173
|
+
# Only add extracted parameters to the pipeline if we're not reordering
|
|
174
|
+
# This prevents duplication when parameters are already in merged_parameters
|
|
175
|
+
context.extracted_parameters_from_pipeline.extend(self.extracted_parameters)
|
|
121
176
|
|
|
122
177
|
context.metadata["parameter_metadata"] = self._parameter_metadata
|
|
123
178
|
|
|
@@ -138,6 +193,32 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
138
193
|
result = self._process_array(node, context)
|
|
139
194
|
elif isinstance(node, exp.In) and self.parameterize_in_lists:
|
|
140
195
|
result = self._process_in_clause(node, context)
|
|
196
|
+
elif isinstance(node, exp.Placeholder) and self._is_reordering_needed:
|
|
197
|
+
# Handle existing placeholders when reordering is needed
|
|
198
|
+
result = self._process_existing_placeholder(node, context)
|
|
199
|
+
elif isinstance(node, exp.Parameter) and self._is_reordering_needed:
|
|
200
|
+
# Handle PostgreSQL-style parameters ($1, $2) when reordering is needed
|
|
201
|
+
result = self._process_existing_parameter(node, context)
|
|
202
|
+
elif isinstance(node, exp.Column) and self._is_reordering_needed:
|
|
203
|
+
# Check if this column looks like a PostgreSQL parameter ($1, $2, etc.)
|
|
204
|
+
column_name = str(node.this) if hasattr(node, "this") else ""
|
|
205
|
+
if column_name.startswith("$") and column_name[1:].isdigit():
|
|
206
|
+
# This is a PostgreSQL-style parameter parsed as a column
|
|
207
|
+
result = self._process_postgresql_column_parameter(node, context)
|
|
208
|
+
else:
|
|
209
|
+
# Regular column - process children
|
|
210
|
+
for key, value in node.args.items():
|
|
211
|
+
if isinstance(value, exp.Expression):
|
|
212
|
+
node.set(key, self._transform_with_context(value, context))
|
|
213
|
+
elif isinstance(value, list):
|
|
214
|
+
node.set(
|
|
215
|
+
key,
|
|
216
|
+
[
|
|
217
|
+
self._transform_with_context(v, context) if isinstance(v, exp.Expression) else v
|
|
218
|
+
for v in value
|
|
219
|
+
],
|
|
220
|
+
)
|
|
221
|
+
result = node
|
|
141
222
|
else:
|
|
142
223
|
# Recursively process children
|
|
143
224
|
for key, value in node.args.items():
|
|
@@ -161,51 +242,136 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
161
242
|
def _update_context(self, node: exp.Expression, context: ParameterizationContext, entering: bool) -> None:
|
|
162
243
|
"""Update parameterization context based on current AST node."""
|
|
163
244
|
if entering:
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
if isinstance(node, Func):
|
|
167
|
-
context.function_depth += 1
|
|
168
|
-
# Get function name from class name or node.name
|
|
169
|
-
func_name = node.__class__.__name__.upper()
|
|
170
|
-
if func_name in self.preserve_in_functions or (
|
|
171
|
-
node.name and node.name.upper() in self.preserve_in_functions
|
|
172
|
-
):
|
|
173
|
-
context.in_function_args = True
|
|
174
|
-
elif isinstance(node, exp.Case):
|
|
175
|
-
context.in_case_when = True
|
|
176
|
-
elif isinstance(node, Array):
|
|
177
|
-
context.in_array = True
|
|
178
|
-
elif isinstance(node, exp.In):
|
|
179
|
-
context.in_in_clause = True
|
|
180
|
-
elif isinstance(node, exp.CTE):
|
|
181
|
-
context.cte_depth += 1
|
|
182
|
-
# Check if this CTE is recursive:
|
|
183
|
-
# 1. Parent WITH must be RECURSIVE
|
|
184
|
-
# 2. CTE must contain UNION (characteristic of recursive CTEs)
|
|
185
|
-
is_in_recursive_with = any(
|
|
186
|
-
isinstance(parent, exp.With) and parent.args.get("recursive", False)
|
|
187
|
-
for parent in reversed(context.parent_stack)
|
|
188
|
-
)
|
|
189
|
-
if is_in_recursive_with and self._contains_union(node):
|
|
190
|
-
context.in_recursive_cte = True
|
|
245
|
+
self._update_context_entering(node, context)
|
|
191
246
|
else:
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
247
|
+
self._update_context_leaving(node, context)
|
|
248
|
+
|
|
249
|
+
def _update_context_entering(self, node: exp.Expression, context: ParameterizationContext) -> None:
    """Push ``node`` onto the parent stack and set the flags its type implies.

    The node is appended to ``context.parent_stack`` before anything else, so
    the helpers called from here see it as the last stack entry.  Exactly one
    type check is acted on: the first one that matches, in the order written.

    Args:
        node: The AST node being entered.
        context: Mutable traversal state updated in place.
    """
    context.parent_stack.append(node)

    # NOTE(review): the Func check runs first; if Case/Array are Func
    # subclasses in sqlglot, their dedicated branches below are never
    # reached - confirm against sqlglot.expressions.
    if isinstance(node, Func):
        self._update_context_entering_func(node, context)
        return
    if isinstance(node, exp.Case):
        context.in_case_when = True
        return
    if isinstance(node, Array):
        context.in_array = True
        return
    if isinstance(node, exp.In):
        context.in_in_clause = True
        return
    if isinstance(node, exp.CTE):
        self._update_context_entering_cte(node, context)
        return
    if isinstance(node, exp.Subquery):
        context.subquery_depth += 1
        context.in_subquery = True
        return
    if isinstance(node, exp.Select):
        self._update_context_entering_select(node, context)
        return
    if isinstance(node, exp.Join):
        context.in_join_condition = True
|
|
270
|
+
|
|
271
|
+
def _update_context_entering_func(self, node: Func, context: ParameterizationContext) -> None:
    """Increase the function nesting depth and flag preserved functions.

    ``context.in_function_args`` is switched on when either the node's class
    name or its ``name`` attribute (both upper-cased) is listed in
    ``self.preserve_in_functions``.

    Args:
        node: The function node being entered.
        context: Mutable traversal state updated in place.
    """
    context.function_depth += 1

    preserved = self.preserve_in_functions
    class_label = node.__class__.__name__.upper()
    if class_label in preserved:
        context.in_function_args = True
        return

    # Fall back to the node's own name when the class name is not listed.
    declared_name = node.name
    if declared_name and declared_name.upper() in preserved:
        context.in_function_args = True
|
|
278
|
+
|
|
279
|
+
def _update_context_entering_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
    """Increase the CTE nesting depth and detect a recursive CTE.

    The CTE is flagged as recursive only when both hold:
    1. some entry on the parent stack is a ``WITH`` whose ``recursive`` arg is truthy;
    2. ``self._contains_union(node)`` is truthy.

    Args:
        node: The CTE node being entered.
        context: Mutable traversal state updated in place.
    """
    context.cte_depth += 1

    # Walk from the innermost ancestor outwards; the first recursive WITH decides.
    for ancestor in reversed(context.parent_stack):
        if isinstance(ancestor, exp.With) and ancestor.args.get("recursive", False):
            if self._contains_union(node):
                context.in_recursive_cte = True
            return
|
|
291
|
+
|
|
292
|
+
def _update_context_entering_select(self, node: exp.Select, context: ParameterizationContext) -> None:
    """Track a nested SELECT as a subquery unless it sits in a recursive CTE.

    At this point ``node`` is already the last entry of
    ``context.parent_stack``, so ``parent_stack[:-1]`` is the list of its
    ancestors.

    Args:
        node: The SELECT node being entered.
        context: Mutable traversal state updated in place.
    """
    ancestors = context.parent_stack[:-1]

    # The recursive-WITH lookup does not depend on which ancestor is being
    # examined, so it is evaluated once over the whole stack.
    has_recursive_with = any(
        isinstance(entry, exp.With) and entry.args.get("recursive", False) for entry in context.parent_stack
    )
    inside_recursive_cte = has_recursive_with and any(isinstance(entry, exp.CTE) for entry in ancestors)

    if not inside_recursive_cte:
        nested = any(isinstance(entry, (exp.Select, exp.Subquery, exp.CTE)) for entry in ancestors)
        if nested:
            context.subquery_depth += 1
            context.in_subquery = True

    # The select-list flag is reset here; it is expected to be detected later
    # by _is_in_select_expressions when individual expressions are processed.
    if hasattr(node, "expressions"):
        context.in_select_list = False
|
|
314
|
+
|
|
315
|
+
def _update_context_leaving(self, node: exp.Expression, context: ParameterizationContext) -> None:
    """Pop the parent stack and clear the flags that ``node`` had set.

    The stack is popped (when non-empty) before any type-specific handling,
    so the helpers called from here no longer see ``node`` on the stack.
    Only the first matching type check is acted on.

    Args:
        node: The AST node being left.
        context: Mutable traversal state updated in place.
    """
    stack = context.parent_stack
    if stack:
        stack.pop()

    if isinstance(node, Func):
        self._update_context_leaving_func(node, context)
        return
    if isinstance(node, exp.Case):
        context.in_case_when = False
        return
    if isinstance(node, Array):
        context.in_array = False
        return
    if isinstance(node, exp.In):
        context.in_in_clause = False
        return
    if isinstance(node, exp.CTE):
        self._update_context_leaving_cte(node, context)
        return
    if isinstance(node, exp.Subquery):
        self._update_context_leaving_subquery(node, context)
        return
    if isinstance(node, exp.Select):
        self._update_context_leaving_select(node, context)
        return
    if isinstance(node, exp.Join):
        context.in_join_condition = False
|
|
336
|
+
|
|
337
|
+
def _update_context_leaving_func(self, node: Func, context: ParameterizationContext) -> None:
    """Decrease the function nesting depth; clear the flag at depth zero.

    Args:
        node: The function node being left (not inspected).
        context: Mutable traversal state updated in place.
    """
    remaining = context.function_depth - 1
    context.function_depth = remaining
    if remaining == 0:
        # Outermost function closed: literals are no longer function arguments.
        context.in_function_args = False
|
|
342
|
+
|
|
343
|
+
def _update_context_leaving_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
    """Decrease the CTE nesting depth; clear the recursive flag at depth zero.

    Args:
        node: The CTE node being left (not inspected).
        context: Mutable traversal state updated in place.
    """
    remaining = context.cte_depth - 1
    context.cte_depth = remaining
    if remaining == 0:
        context.in_recursive_cte = False
|
|
348
|
+
|
|
349
|
+
def _update_context_leaving_subquery(self, node: exp.Subquery, context: ParameterizationContext) -> None:
    """Decrease the subquery nesting depth; clear the flag at depth zero.

    Args:
        node: The subquery node being left (not inspected).
        context: Mutable traversal state updated in place.
    """
    remaining = context.subquery_depth - 1
    context.subquery_depth = remaining
    if remaining == 0:
        context.in_subquery = False
|
|
354
|
+
|
|
355
|
+
def _update_context_leaving_select(self, node: exp.Select, context: ParameterizationContext) -> None:
    """Undo the subquery bookkeeping performed when this SELECT was entered.

    ``_update_context_leaving`` pops ``node`` off ``context.parent_stack``
    before dispatching here, so the stack already contains only the
    ancestors of ``node``.  The entering counterpart evaluates its checks
    against ``parent_stack[:-1]`` while the node is still on the stack, i.e.
    against the same ancestor list.  Slicing off the last entry again here
    would drop the immediate parent and could leave ``subquery_depth``
    unbalanced (incremented on entry but not decremented on exit, or the
    reverse), so the full stack is used.

    Args:
        node: The SELECT node being left (not inspected).
        context: Mutable traversal state updated in place.
    """
    ancestors = context.parent_stack

    has_recursive_with = any(
        isinstance(entry, exp.With) and entry.args.get("recursive", False) for entry in ancestors
    )
    inside_recursive_cte = has_recursive_with and any(isinstance(entry, exp.CTE) for entry in ancestors)

    # Only a SELECT that was counted as nested on entry is un-counted here.
    was_nested = not inside_recursive_cte and any(
        isinstance(entry, (exp.Select, exp.Subquery, exp.CTE)) for entry in ancestors
    )
    if was_nested:
        context.subquery_depth -= 1
        if context.subquery_depth == 0:
            context.in_subquery = False

    context.in_select_list = False
|
|
209
375
|
|
|
210
376
|
def _process_literal_with_context(
|
|
211
377
|
self, literal: exp.Expression, context: ParameterizationContext
|
|
@@ -228,11 +394,26 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
228
394
|
semantic_name=semantic_name,
|
|
229
395
|
)
|
|
230
396
|
|
|
231
|
-
#
|
|
397
|
+
# Always track extracted parameters for proper merging
|
|
232
398
|
self.extracted_parameters.append(typed_param)
|
|
399
|
+
|
|
400
|
+
# If we're reordering, also add to final params directly
|
|
401
|
+
if self._is_reordering_needed:
|
|
402
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
403
|
+
# For dict format, we need a key
|
|
404
|
+
param_key = semantic_name or f"param_{len(self._final_params)}"
|
|
405
|
+
self._final_params[param_key] = typed_param
|
|
406
|
+
elif isinstance(self._final_params, list):
|
|
407
|
+
self._final_params.append(typed_param)
|
|
408
|
+
else:
|
|
409
|
+
# Fallback - this shouldn't happen but handle gracefully
|
|
410
|
+
if not hasattr(self, "_fallback_params"):
|
|
411
|
+
self._fallback_params = []
|
|
412
|
+
self._fallback_params.append(typed_param)
|
|
413
|
+
|
|
233
414
|
self._parameter_metadata.append(
|
|
234
415
|
{
|
|
235
|
-
"index": len(self.extracted_parameters) - 1,
|
|
416
|
+
"index": len(self._final_params if self._is_reordering_needed else self.extracted_parameters) - 1,
|
|
236
417
|
"type": type_hint,
|
|
237
418
|
"semantic_name": semantic_name,
|
|
238
419
|
"context": self._get_context_description(context),
|
|
@@ -242,23 +423,343 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
242
423
|
# Create appropriate placeholder
|
|
243
424
|
return self._create_placeholder(hint=semantic_name)
|
|
244
425
|
|
|
426
|
+
def _process_existing_placeholder(self, node: exp.Placeholder, context: ParameterizationContext) -> exp.Expression:
    """Record the user value bound to an existing placeholder during reordering.

    The handler is chosen from the type of the original user parameters:
    list/tuple, dict, or any other single value.  Nothing is recorded when
    no user parameters were supplied.

    Args:
        node: The placeholder found in the SQL.
        context: Traversal state (not used here).

    Returns:
        ``node`` itself, unchanged.
    """
    supplied = self._original_params
    if supplied is None:
        return node

    if isinstance(supplied, (list, tuple)):
        handler = self._handle_list_params_for_placeholder
    elif isinstance(supplied, dict):
        handler = self._handle_dict_params_for_placeholder
    else:
        # Any other type is treated as a single scalar parameter.
        handler = self._handle_single_value_param_for_placeholder

    handler(node)
    return node
|
|
440
|
+
|
|
441
|
+
def _handle_list_params_for_placeholder(self, node: exp.Placeholder) -> None:
    """Bind the next positional user parameter to ``node``.

    Consumes ``self._original_params[self._user_param_index]`` and advances
    the index.  When the user parameters are not a list/tuple, or the index
    is past their end, ``None`` is recorded and the index is left alone.

    Args:
        node: The placeholder being bound.
    """
    supplied = self._original_params
    position = self._user_param_index

    if not isinstance(supplied, (list, tuple)) or position >= len(supplied):
        # More placeholders than user-supplied values.
        self._add_to_final_params(None, node)
        return

    self._add_to_final_params(supplied[position], node)
    self._user_param_index += 1
|
|
450
|
+
|
|
451
|
+
def _handle_dict_params_for_placeholder(self, node: exp.Placeholder) -> None:
    """Resolve ``node`` against dict-style user parameters.

    Resolution order, after stripping leading ``:``/``@`` from the name:
    1. the name is a key of the user dict: record that value;
    2. the name is all digits: on the first placeholder copy the numeric
       keys via ``_handle_oracle_numeric_params`` and advance the index; on
       later ones do nothing;
    3. first placeholder and a non-empty dict: delegate to
       ``_handle_single_dict_param`` and advance the index;
    4. otherwise record ``None``.

    ``None`` is also recorded when the user parameters are not a dict or the
    placeholder has no name.

    Args:
        node: The placeholder being bound.
    """
    supplied = self._original_params
    if not isinstance(supplied, dict):
        self._add_to_final_params(None, node)
        return

    raw_name = node.this if hasattr(node, "this") else None
    if not raw_name:
        # An unnamed "?" placeholder cannot be matched to a dict key.
        self._add_to_final_params(None, node)
        return

    # Drop leading sigils so the name is comparable with the dict keys.
    lookup_name = raw_name.lstrip(":@")

    if lookup_name in supplied:
        self._add_to_final_params(supplied[lookup_name], node)
        return

    at_first_param = self._user_param_index == 0

    if lookup_name.isdigit():
        if at_first_param:
            # Numeric placeholder names: copy all numeric keys in one go.
            self._handle_oracle_numeric_params()
            self._user_param_index += 1
            return
        if self._user_param_index > 0:
            # Numeric keys were already copied on the first placeholder.
            return

    if at_first_param and len(supplied) > 0:
        self._handle_single_dict_param()
        self._user_param_index += 1
        return

    # Nothing matched.
    self._add_to_final_params(None, node)
|
|
485
|
+
|
|
486
|
+
def _handle_single_value_param_for_placeholder(self, node: exp.Placeholder) -> None:
    """Bind a single scalar user parameter to the first placeholder only.

    The first placeholder receives ``self._original_params`` and advances
    the index; every later placeholder receives ``None``.

    Args:
        node: The placeholder being bound.
    """
    is_first = self._user_param_index == 0
    self._add_to_final_params(self._original_params if is_first else None, node)
    if is_first:
        self._user_param_index += 1
|
|
493
|
+
|
|
494
|
+
def _handle_oracle_numeric_params(self) -> None:
|
|
495
|
+
"""Handle Oracle-style numeric parameters."""
|
|
496
|
+
if not isinstance(self._original_params, dict):
|
|
497
|
+
return
|
|
498
|
+
|
|
499
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
500
|
+
for k, v in self._original_params.items():
|
|
501
|
+
if k.isdigit():
|
|
502
|
+
self._final_params[k] = v
|
|
503
|
+
else:
|
|
504
|
+
# Convert to positional list
|
|
505
|
+
numeric_keys = [k for k in self._original_params if k.isdigit()]
|
|
506
|
+
if numeric_keys:
|
|
507
|
+
max_index = max(int(k) for k in numeric_keys)
|
|
508
|
+
param_list = [None] * (max_index + 1)
|
|
509
|
+
for k, v in self._original_params.items():
|
|
510
|
+
if k.isdigit():
|
|
511
|
+
param_list[int(k)] = v
|
|
512
|
+
if isinstance(self._final_params, list):
|
|
513
|
+
self._final_params.extend(param_list)
|
|
514
|
+
elif isinstance(self._final_params, dict):
|
|
515
|
+
for i, val in enumerate(param_list):
|
|
516
|
+
self._final_params[str(i)] = val
|
|
517
|
+
|
|
518
|
+
def _handle_single_dict_param(self) -> None:
|
|
519
|
+
"""Handle single dict parameter case."""
|
|
520
|
+
if not isinstance(self._original_params, dict):
|
|
521
|
+
return
|
|
522
|
+
|
|
523
|
+
if self._preserve_dict_format and isinstance(self._final_params, dict):
|
|
524
|
+
for k, v in self._original_params.items():
|
|
525
|
+
self._final_params[k] = v
|
|
526
|
+
elif isinstance(self._final_params, list):
|
|
527
|
+
self._final_params.append(self._original_params)
|
|
528
|
+
elif isinstance(self._final_params, dict):
|
|
529
|
+
param_name = f"param_{len(self._final_params)}"
|
|
530
|
+
self._final_params[param_name] = self._original_params
|
|
531
|
+
|
|
532
|
+
def _add_to_final_params(self, value: Any, node: exp.Placeholder) -> None:
    """Store ``value`` in the final parameter collection.

    A list target gets the value appended.  A dict target in
    dict-preserving mode stores it under the placeholder's own name; a dict
    target outside that mode stores it under ``param_<current size>``.

    A placeholder whose ``this`` is missing or empty (for example an unnamed
    ``?``) falls back to ``param_<user index>``.  Checking only for the
    existence of the attribute is not enough: the attribute can exist and
    still be ``None``, which would store the value under a ``None`` key.

    Args:
        value: The parameter value to record.
        node: The placeholder the value belongs to.
    """
    target = self._final_params

    if isinstance(target, list):
        target.append(value)
        return
    if not isinstance(target, dict):
        return

    if self._preserve_dict_format:
        own_name = getattr(node, "this", None)
        key = own_name if own_name else f"param_{self._user_param_index}"
    else:
        key = f"param_{len(target)}"
    target[key] = value
|
|
542
|
+
|
|
543
|
+
def _process_existing_parameter(self, node: exp.Parameter, context: ParameterizationContext) -> exp.Expression:
    """Record the user value bound to an existing parameter node.

    A name lookup is tried first: when the extracted name is a key of a
    dict of user parameters, that value is recorded.  Otherwise the node is
    treated as a numbered parameter and resolved by its zero-based index
    against ``None``, list/tuple, dict, or single-value user parameters.

    Args:
        node: The parameter node found in the SQL.
        context: Traversal state (not used here).

    Returns:
        ``node`` itself, unchanged.
    """
    supplied = self._original_params

    name = self._extract_parameter_name(node)
    if name and isinstance(supplied, dict) and name in supplied:
        self._add_param_value_to_finals(supplied[name])
        return node

    # No named match: fall back to positional resolution.
    index = self._extract_parameter_index(node)

    if supplied is None:
        self._add_none_to_final_params()
        return node
    if isinstance(supplied, (list, tuple)):
        self._handle_list_params_for_parameter_node(index)
        return node
    if isinstance(supplied, dict):
        self._handle_dict_params_for_parameter_node(index)
        return node

    # A single scalar value is only bound to the first parameter.
    if index == 0:
        self._add_param_value_to_finals(supplied)
    else:
        self._add_none_to_final_params()
    return node
|
|
570
|
+
|
|
571
|
+
@staticmethod
def _extract_parameter_name(node: exp.Parameter) -> Optional[str]:
    """Return the name carried by a parameter node, if any.

    The name is ``str(node.this.this)`` when ``node.this`` is an ``exp.Var``
    (e.g. ``@min_value`` -> ``min_value``) or any other object exposing a
    ``this`` attribute.

    Returns:
        The name as a string, or ``None`` when it cannot be read.
    """
    if not hasattr(node, "this"):
        return None

    inner = node.this
    if isinstance(inner, exp.Var) or hasattr(inner, "this"):
        return str(inner.this)
    return None
|
|
582
|
+
|
|
583
|
+
@staticmethod
def _extract_parameter_index(node: exp.Parameter) -> Optional[int]:
    """Return the zero-based index encoded in a numbered parameter node.

    Only nodes whose ``this`` is a ``Literal`` are considered; the literal's
    value is parsed as an integer and decremented by one.

    Returns:
        The zero-based index, or ``None`` when the node carries no literal
        or the literal cannot be converted to ``int``.
    """
    if not (hasattr(node, "this") and isinstance(node.this, Literal)):
        return None

    try:
        return int(node.this.this) - 1  # 1-based in SQL, 0-based here
    except (ValueError, TypeError):
        return None
|
|
592
|
+
|
|
593
|
+
def _handle_list_params_for_parameter_node(self, param_index: Optional[int]) -> None:
|
|
594
|
+
"""Handle list/tuple parameters for Parameter node."""
|
|
595
|
+
if (
|
|
596
|
+
isinstance(self._original_params, (list, tuple))
|
|
597
|
+
and param_index is not None
|
|
598
|
+
and 0 <= param_index < len(self._original_params)
|
|
599
|
+
):
|
|
600
|
+
# Use the parameter at the specified index
|
|
601
|
+
self._add_param_value_to_finals(self._original_params[param_index])
|
|
602
|
+
else:
|
|
603
|
+
# More parameters than user provided
|
|
604
|
+
self._add_none_to_final_params()
|
|
605
|
+
|
|
606
|
+
def _handle_dict_params_for_parameter_node(self, param_index: Optional[int]) -> None:
|
|
607
|
+
"""Handle dict parameters for Parameter node."""
|
|
608
|
+
if param_index is not None:
|
|
609
|
+
self._handle_dict_param_with_index(param_index)
|
|
610
|
+
else:
|
|
611
|
+
self._add_none_to_final_params()
|
|
612
|
+
|
|
613
|
+
def _handle_dict_param_with_index(self, param_index: int) -> None:
|
|
614
|
+
"""Handle dict parameter when we have an index."""
|
|
615
|
+
if not isinstance(self._original_params, dict):
|
|
616
|
+
self._add_none_to_final_params()
|
|
617
|
+
return
|
|
618
|
+
|
|
619
|
+
# Try param_N key first
|
|
620
|
+
param_key = f"param_{param_index}"
|
|
621
|
+
if param_key in self._original_params:
|
|
622
|
+
self._add_dict_value_to_finals(param_key)
|
|
623
|
+
return
|
|
624
|
+
|
|
625
|
+
# Try direct numeric key (1-based)
|
|
626
|
+
numeric_key = str(param_index + 1)
|
|
627
|
+
if numeric_key in self._original_params:
|
|
628
|
+
self._add_dict_value_to_finals(numeric_key)
|
|
629
|
+
else:
|
|
630
|
+
self._add_none_to_final_params()
|
|
631
|
+
|
|
632
|
+
def _add_dict_value_to_finals(self, key: str) -> None:
|
|
633
|
+
"""Add a value from dict params to final params."""
|
|
634
|
+
if isinstance(self._original_params, dict) and key in self._original_params:
|
|
635
|
+
value = self._original_params[key]
|
|
636
|
+
if isinstance(self._final_params, list):
|
|
637
|
+
self._final_params.append(value)
|
|
638
|
+
elif isinstance(self._final_params, dict):
|
|
639
|
+
self._final_params[key] = value
|
|
640
|
+
|
|
641
|
+
def _add_param_value_to_finals(self, value: Any) -> None:
|
|
642
|
+
"""Add a parameter value to final params."""
|
|
643
|
+
if isinstance(self._final_params, list):
|
|
644
|
+
self._final_params.append(value)
|
|
645
|
+
elif isinstance(self._final_params, dict):
|
|
646
|
+
param_name = f"param_{len(self._final_params)}"
|
|
647
|
+
self._final_params[param_name] = value
|
|
648
|
+
|
|
649
|
+
def _add_none_to_final_params(self) -> None:
|
|
650
|
+
"""Add None to final params."""
|
|
651
|
+
if isinstance(self._final_params, list):
|
|
652
|
+
self._final_params.append(None)
|
|
653
|
+
elif isinstance(self._final_params, dict):
|
|
654
|
+
param_name = f"param_{len(self._final_params)}"
|
|
655
|
+
self._final_params[param_name] = None
|
|
656
|
+
|
|
657
|
+
def _process_postgresql_column_parameter(
    self, node: exp.Column, context: ParameterizationContext
) -> exp.Expression:
    """Process PostgreSQL-style parameters that were parsed as columns ($1, $2).

    The numeric part of the placeholder is converted to a 0-based index and
    the matching user-supplied value is recorded in ``self._final_params``.
    The node itself is always returned unchanged, because it still
    represents the placeholder in the SQL text.

    Resolution rules, by type of ``self._original_params``:

    * ``None``: nothing is recorded.
    * list/tuple: when ``self._is_reordering_needed`` is set (mixed
      parameter styles, per the original comments) values are consumed
      sequentially via ``self._user_param_index``; otherwise the index from
      the placeholder is used. A placeholder with no corresponding value
      records nothing.
    * dict: ``"param_<index>"`` is tried first, then the 1-based numeric
      string key; the value is recorded under the key that matched.
    * anything else (single scalar): recorded only for ``$1``.

    Args:
        node: Column node whose name looks like ``$N``.
        context: Current traversal context (not used by this method).

    Returns:
        ``node``, unchanged.
    """
    column_name = str(node.this) if hasattr(node, "this") else ""
    param_index = None
    if column_name.startswith("$") and column_name[1:].isdigit():
        # str.isdigit() accepts some characters int() rejects (e.g.
        # superscript digits), so the conversion is still guarded.
        try:
            param_index = int(column_name[1:]) - 1  # Convert to 0-based index
        except (ValueError, TypeError):
            param_index = None

    originals = self._original_params
    if originals is None:
        # No user parameters provided - don't add None
        return node

    def _record(value: Any, key: Optional[str] = None) -> None:
        """Store ``value`` in the final params, generating a key if needed."""
        finals = self._final_params
        if isinstance(finals, dict):
            finals[f"param_{len(finals)}" if key is None else key] = value
        elif isinstance(finals, list):
            finals.append(value)

    if isinstance(originals, (list, tuple)):
        # A private sentinel separates "no value supplied" from an explicit
        # None, which is a legitimate parameter value (SQL NULL) and must
        # still be recorded to keep placeholders and values aligned.
        missing = object()
        value = missing
        if self._is_reordering_needed:
            # Mixed styles: assign sequentially regardless of the number
            # written in the placeholder.
            if self._user_param_index < len(originals):
                value = originals[self._user_param_index]
                self._user_param_index += 1
        elif param_index is not None and 0 <= param_index < len(originals):
            # Non-mixed styles: honour the number in the placeholder.
            value = originals[param_index]
        if value is not missing:
            _record(value)
        # More placeholders than user-provided values - don't add None
    elif isinstance(originals, dict):
        # For dict parameters with numeric placeholders, try to map by index
        if param_index is not None:
            for candidate in (f"param_{param_index}", str(param_index + 1)):
                if candidate in originals:
                    _record(originals[candidate], candidate)
                    break
    elif param_index == 0:
        # Single parameter case: the scalar itself is the value for $1.
        _record(originals)

    # Return the column unchanged - it represents the parameter placeholder
    return node
|
|
738
|
+
|
|
245
739
|
def _should_preserve_literal_in_context(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
|
|
246
|
-
"""
|
|
247
|
-
#
|
|
740
|
+
"""Enhanced context-aware decision on literal preservation."""
|
|
741
|
+
# Existing preservation rules (maintain compatibility)
|
|
248
742
|
if self.preserve_null and isinstance(literal, Null):
|
|
249
743
|
return True
|
|
250
744
|
|
|
251
|
-
# Check for boolean values
|
|
252
745
|
if self.preserve_boolean and isinstance(literal, Boolean):
|
|
253
746
|
return True
|
|
254
747
|
|
|
748
|
+
# NEW: Context-based preservation rules
|
|
749
|
+
|
|
750
|
+
# Rule 4: Preserve enum-like literals in subquery lookups (the main fix we need)
|
|
751
|
+
if context.in_subquery and self._is_scalar_lookup_pattern(literal, context):
|
|
752
|
+
return self._is_enum_like_literal(literal)
|
|
753
|
+
|
|
754
|
+
# Existing preservation rules continue...
|
|
755
|
+
|
|
255
756
|
# Check if in preserved function arguments
|
|
256
757
|
if context.in_function_args:
|
|
257
758
|
return True
|
|
258
759
|
|
|
259
|
-
#
|
|
760
|
+
# ENHANCED: Intelligent recursive CTE literal preservation
|
|
260
761
|
if self.preserve_in_recursive_cte and context.in_recursive_cte:
|
|
261
|
-
return
|
|
762
|
+
return self._should_preserve_literal_in_recursive_cte(literal, context)
|
|
262
763
|
|
|
263
764
|
# Check if this literal is being used as an alias value in SELECT
|
|
264
765
|
# e.g., 'computed' as process_status should be preserved
|
|
@@ -299,6 +800,76 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
299
800
|
|
|
300
801
|
return False
|
|
301
802
|
|
|
803
|
+
def _is_in_select_expressions(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Report whether ``literal`` sits in the projection list of a SELECT.

    ``context.parent_stack`` is walked in reverse. The first ``exp.Select``
    that has a non-empty ``expressions`` list decides the answer: ``True``
    when ``_literal_is_in_expression_tree`` finds the literal in any of
    them. Reaching a ``Where``, ``Having`` or ``Join`` first ends the search
    with ``False``.

    Args:
        literal: The literal being classified.
        context: Traversal context providing ``parent_stack``.

    Returns:
        ``True`` if the literal was found in a SELECT's expressions,
        ``False`` otherwise.
    """
    for ancestor in reversed(context.parent_stack):
        if isinstance(ancestor, exp.Select):
            projections = getattr(ancestor, "expressions", None)
            if projections:
                return any(
                    self._literal_is_in_expression_tree(literal, projection) for projection in projections
                )
            # A SELECT without expressions is inconclusive; keep walking.
            continue
        if isinstance(ancestor, (exp.Where, exp.Having, exp.Join)):
            return False
    return False
|
|
812
|
+
|
|
813
|
+
def _is_recursive_computation(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Check if literal is part of recursive computation logic.

    ``context.parent_stack`` is walked in reverse looking for an arithmetic
    node (``Add``, ``Sub``, ``Mul`` or ``Div``). When one is found, the
    answer is whatever ``_is_in_select_expressions`` reports for the
    literal; with no arithmetic ancestor the result is ``False``.

    Args:
        literal: The literal being classified.
        context: Traversal context providing ``parent_stack``.

    Returns:
        ``True`` if the literal has an arithmetic ancestor and is located in
        a SELECT's expressions, ``False`` otherwise.
    """
    # sqlglot's ``Expression.key`` is the lower-cased class name, so comparing
    # it against upper-case strings can never match; test the node classes
    # directly instead. All four are ``exp.Binary`` subclasses.
    arithmetic_nodes = (exp.Add, exp.Sub, exp.Mul, exp.Div)
    for parent in reversed(context.parent_stack):
        if isinstance(parent, arithmetic_nodes):
            # Check if this arithmetic is in a SELECT clause of a recursive part
            return self._is_in_select_expressions(literal, context)
    return False
|
|
821
|
+
|
|
822
|
+
def _should_preserve_literal_in_recursive_cte(
    self, literal: exp.Expression, context: ParameterizationContext
) -> bool:
    """Decide whether a literal inside a recursive CTE is kept as a literal.

    The literal is preserved when ``_is_in_select_expressions`` reports it
    is in a SELECT's expressions, or failing that when
    ``_is_recursive_computation`` reports it is part of recursive
    arithmetic. The second check runs only if the first one is false.

    Args:
        literal: The literal being classified.
        context: Traversal context forwarded to both checks.

    Returns:
        ``True`` to preserve the literal, ``False`` to allow parameterization.
    """
    return self._is_in_select_expressions(literal, context) or self._is_recursive_computation(
        literal, context
    )
|
|
832
|
+
|
|
833
|
+
def _literal_is_in_expression_tree(self, target_literal: exp.Expression, expr: exp.Expression) -> bool:
    """Check if target literal is within the given expression tree.

    The whole tree rooted at ``expr`` is searched, not just its direct
    children, so a literal nested under intermediate nodes is still found.
    Nodes are compared with ``==``.

    Args:
        target_literal: The literal to look for.
        expr: Root of the expression tree to search.

    Returns:
        ``True`` if ``expr`` or any descendant compares equal to
        ``target_literal``, ``False`` otherwise.
    """
    # Explicit stack instead of recursion keeps deep trees clear of the
    # interpreter's recursion limit.
    pending = [expr]
    while pending:
        current = pending.pop()
        if current == target_literal:
            return True
        pending.extend(current.iter_expressions())
    return False
|
|
839
|
+
|
|
840
|
+
def _is_scalar_lookup_pattern(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Detect a ``WHERE column = literal`` comparison inside a subquery.

    The check only applies when ``context.subquery_depth`` is non-zero.
    ``context.parent_stack`` is walked in reverse; for the first ``Where``
    whose condition is a binary expression with ``literal`` on either side,
    the result is whether the opposite side is a column.

    Args:
        literal: The literal being classified.
        context: Traversal context providing ``subquery_depth`` and
            ``parent_stack``.

    Returns:
        ``True`` when the literal is compared directly against a column in
        such a WHERE clause, ``False`` otherwise.
    """
    if context.subquery_depth == 0:
        # Must be in a subquery for this pattern to apply
        return False

    for ancestor in reversed(context.parent_stack):
        if not isinstance(ancestor, exp.Where):
            continue
        condition = ancestor.this
        if not isinstance(condition, exp.Binary):
            continue
        # Pattern: WHERE column = 'literal'
        if condition.right == literal:
            return isinstance(condition.left, exp.Column)
        # Mirrored pattern: WHERE 'literal' = column
        if condition.left == literal:
            return isinstance(condition.right, exp.Column)
    return False
|
|
857
|
+
|
|
858
|
+
def _is_enum_like_literal(self, literal: exp.Expression) -> bool:
    """Report whether ``literal`` looks like an enum/identifier constant.

    Only string literals (as judged by ``self._is_string_literal``) qualify.
    The text must be longer than ``MIN_ENUM_LENGTH``, no longer than
    ``MAX_ENUM_LENGTH``, consist solely of alphanumerics and underscores,
    and not be a pure number.

    Args:
        literal: The expression to classify.

    Returns:
        ``True`` when every heuristic above holds, ``False`` otherwise.
    """
    if not (isinstance(literal, exp.Literal) and self._is_string_literal(literal)):
        return False

    text = str(literal.this)
    # Conservative heuristics: length window first, then character content.
    if not MIN_ENUM_LENGTH < len(text) <= MAX_ENUM_LENGTH:
        return False
    return text.replace("_", "").isalnum() and not text.isdigit()
|
|
872
|
+
|
|
302
873
|
def _extract_literal_value_and_type(self, literal: exp.Expression) -> tuple[Any, str]:
|
|
303
874
|
"""Extract the Python value and type info from a SQLGlot literal."""
|
|
304
875
|
if isinstance(literal, Null) or literal.this is None:
|
|
@@ -551,7 +1122,6 @@ class ParameterizeLiterals(ProcessorProtocol):
|
|
|
551
1122
|
array_sqlglot_type = exp.DataType.build("ARRAY", expressions=[element_sqlglot_type])
|
|
552
1123
|
|
|
553
1124
|
# Create TypedParameter for the entire array
|
|
554
|
-
from sqlspec.statement.parameters import TypedParameter
|
|
555
1125
|
|
|
556
1126
|
typed_param = TypedParameter(
|
|
557
1127
|
value=array_values,
|