sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's original page for more details.

Files changed (113):
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +18 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,13 +1,13 @@
1
1
  """Replaces literals in SQL with placeholders and extracts them using SQLGlot AST."""
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any, Optional
4
+ from typing import Any, Optional, Union
5
5
 
6
6
  from sqlglot import exp
7
7
  from sqlglot.expressions import Array, Binary, Boolean, DataType, Func, Literal, Null
8
8
 
9
- from sqlspec.statement.parameters import ParameterStyle
10
- from sqlspec.statement.pipelines.base import ProcessorProtocol
9
+ from sqlspec.protocols import ProcessorProtocol
10
+ from sqlspec.statement.parameters import ParameterStyle, TypedParameter
11
11
  from sqlspec.statement.pipelines.context import SQLProcessingContext
12
12
 
13
13
  __all__ = ("ParameterizationContext", "ParameterizeLiterals")
@@ -24,6 +24,12 @@ DEFAULT_MAX_ARRAY_LENGTH = 100
24
24
  DEFAULT_MAX_IN_LIST_SIZE = 50
25
25
  """Default maximum IN clause list size before parameterization."""
26
26
 
27
+ MAX_ENUM_LENGTH = 50
28
+ """Maximum length for enum-like string values."""
29
+
30
+ MIN_ENUM_LENGTH = 2
31
+ """Minimum length for enum-like string values to be meaningful."""
32
+
27
33
 
28
34
  @dataclass
29
35
  class ParameterizationContext:
@@ -35,8 +41,12 @@ class ParameterizationContext:
35
41
  in_array: bool = False
36
42
  in_in_clause: bool = False
37
43
  in_recursive_cte: bool = False
44
+ in_subquery: bool = False
45
+ in_select_list: bool = False
46
+ in_join_condition: bool = False
38
47
  function_depth: int = 0
39
48
  cte_depth: int = 0
49
+ subquery_depth: int = 0
40
50
 
41
51
 
42
52
  class ParameterizeLiterals(ProcessorProtocol):
@@ -94,6 +104,7 @@ class ParameterizeLiterals(ProcessorProtocol):
94
104
  "ARRAY_UPPER",
95
105
  "ARRAY_LOWER",
96
106
  "ARRAY_NDIMS",
107
+ "ROUND",
97
108
  ]
98
109
  self.parameterize_arrays = parameterize_arrays
99
110
  self.parameterize_in_lists = parameterize_in_lists
@@ -104,20 +115,64 @@ class ParameterizeLiterals(ProcessorProtocol):
104
115
  self.extracted_parameters: list[Any] = []
105
116
  self._parameter_counter = 0
106
117
  self._parameter_metadata: list[dict[str, Any]] = [] # Track parameter types and context
118
+ self._preserve_dict_format = False # Track whether to preserve dict format
107
119
 
108
120
  def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
109
121
  """Advanced literal parameterization with context-aware AST analysis."""
110
- if expression is None or context.current_expression is None or context.config.input_sql_had_placeholders:
122
+ if expression is None or context.current_expression is None:
123
+ return expression
124
+
125
+ # For named parameters (like BigQuery @param), don't reorder to avoid breaking name mapping
126
+ if (
127
+ context.config.input_sql_had_placeholders
128
+ and context.parameter_info
129
+ and any(p.name for p in context.parameter_info)
130
+ ):
111
131
  return expression
112
132
 
113
133
  self.extracted_parameters = []
114
- self._parameter_counter = 0
115
134
  self._parameter_metadata = []
116
135
 
136
+ # When reordering is needed (SQL already has placeholders), we need to start
137
+ # our counter at the number of existing parameters to avoid conflicts
138
+ if context.config.input_sql_had_placeholders and context.parameter_info:
139
+ # Find the highest ordinal among existing parameters
140
+ max_ordinal = max(p.ordinal for p in context.parameter_info)
141
+ self._parameter_counter = max_ordinal + 1
142
+ else:
143
+ self._parameter_counter = 0
144
+
145
+ # Track original user parameters for proper merging
146
+ self._original_params = context.merged_parameters
147
+ self._user_param_index = 0
148
+ # If original params are dict and we have named placeholders, preserve dict format
149
+ if isinstance(context.merged_parameters, dict) and context.parameter_info:
150
+ # Check if we have named placeholders
151
+ has_named = any(p.name for p in context.parameter_info)
152
+ if has_named:
153
+ self._final_params: Union[dict[str, Any], list[Any]] = {}
154
+ self._preserve_dict_format = True
155
+ else:
156
+ self._final_params = []
157
+ self._preserve_dict_format = False
158
+ else:
159
+ self._final_params = []
160
+ self._preserve_dict_format = False
161
+ self._is_reordering_needed = context.config.input_sql_had_placeholders
162
+
117
163
  param_context = ParameterizationContext(parent_stack=[])
118
164
  transformed_expression = self._transform_with_context(context.current_expression.copy(), param_context)
119
165
  context.current_expression = transformed_expression
120
- context.extracted_parameters_from_pipeline.extend(self.extracted_parameters)
166
+
167
+ # If we're reordering, update the merged parameters with the reordered result
168
+ # In this case, we don't need to add to extracted_parameters_from_pipeline
169
+ # because the parameters are already in _final_params
170
+ if self._is_reordering_needed and self._final_params:
171
+ context.merged_parameters = self._final_params
172
+ else:
173
+ # Only add extracted parameters to the pipeline if we're not reordering
174
+ # This prevents duplication when parameters are already in merged_parameters
175
+ context.extracted_parameters_from_pipeline.extend(self.extracted_parameters)
121
176
 
122
177
  context.metadata["parameter_metadata"] = self._parameter_metadata
123
178
 
@@ -138,6 +193,32 @@ class ParameterizeLiterals(ProcessorProtocol):
138
193
  result = self._process_array(node, context)
139
194
  elif isinstance(node, exp.In) and self.parameterize_in_lists:
140
195
  result = self._process_in_clause(node, context)
196
+ elif isinstance(node, exp.Placeholder) and self._is_reordering_needed:
197
+ # Handle existing placeholders when reordering is needed
198
+ result = self._process_existing_placeholder(node, context)
199
+ elif isinstance(node, exp.Parameter) and self._is_reordering_needed:
200
+ # Handle PostgreSQL-style parameters ($1, $2) when reordering is needed
201
+ result = self._process_existing_parameter(node, context)
202
+ elif isinstance(node, exp.Column) and self._is_reordering_needed:
203
+ # Check if this column looks like a PostgreSQL parameter ($1, $2, etc.)
204
+ column_name = str(node.this) if hasattr(node, "this") else ""
205
+ if column_name.startswith("$") and column_name[1:].isdigit():
206
+ # This is a PostgreSQL-style parameter parsed as a column
207
+ result = self._process_postgresql_column_parameter(node, context)
208
+ else:
209
+ # Regular column - process children
210
+ for key, value in node.args.items():
211
+ if isinstance(value, exp.Expression):
212
+ node.set(key, self._transform_with_context(value, context))
213
+ elif isinstance(value, list):
214
+ node.set(
215
+ key,
216
+ [
217
+ self._transform_with_context(v, context) if isinstance(v, exp.Expression) else v
218
+ for v in value
219
+ ],
220
+ )
221
+ result = node
141
222
  else:
142
223
  # Recursively process children
143
224
  for key, value in node.args.items():
@@ -161,51 +242,136 @@ class ParameterizeLiterals(ProcessorProtocol):
161
242
  def _update_context(self, node: exp.Expression, context: ParameterizationContext, entering: bool) -> None:
162
243
  """Update parameterization context based on current AST node."""
163
244
  if entering:
164
- context.parent_stack.append(node)
165
-
166
- if isinstance(node, Func):
167
- context.function_depth += 1
168
- # Get function name from class name or node.name
169
- func_name = node.__class__.__name__.upper()
170
- if func_name in self.preserve_in_functions or (
171
- node.name and node.name.upper() in self.preserve_in_functions
172
- ):
173
- context.in_function_args = True
174
- elif isinstance(node, exp.Case):
175
- context.in_case_when = True
176
- elif isinstance(node, Array):
177
- context.in_array = True
178
- elif isinstance(node, exp.In):
179
- context.in_in_clause = True
180
- elif isinstance(node, exp.CTE):
181
- context.cte_depth += 1
182
- # Check if this CTE is recursive:
183
- # 1. Parent WITH must be RECURSIVE
184
- # 2. CTE must contain UNION (characteristic of recursive CTEs)
185
- is_in_recursive_with = any(
186
- isinstance(parent, exp.With) and parent.args.get("recursive", False)
187
- for parent in reversed(context.parent_stack)
188
- )
189
- if is_in_recursive_with and self._contains_union(node):
190
- context.in_recursive_cte = True
245
+ self._update_context_entering(node, context)
191
246
  else:
192
- if context.parent_stack:
193
- context.parent_stack.pop()
194
-
195
- if isinstance(node, Func):
196
- context.function_depth -= 1
197
- if context.function_depth == 0:
198
- context.in_function_args = False
199
- elif isinstance(node, exp.Case):
200
- context.in_case_when = False
201
- elif isinstance(node, Array):
202
- context.in_array = False
203
- elif isinstance(node, exp.In):
204
- context.in_in_clause = False
205
- elif isinstance(node, exp.CTE):
206
- context.cte_depth -= 1
207
- if context.cte_depth == 0:
208
- context.in_recursive_cte = False
247
+ self._update_context_leaving(node, context)
248
+
249
+ def _update_context_entering(self, node: exp.Expression, context: ParameterizationContext) -> None:
250
+ """Update context when entering a node."""
251
+ context.parent_stack.append(node)
252
+
253
+ if isinstance(node, Func):
254
+ self._update_context_entering_func(node, context)
255
+ elif isinstance(node, exp.Case):
256
+ context.in_case_when = True
257
+ elif isinstance(node, Array):
258
+ context.in_array = True
259
+ elif isinstance(node, exp.In):
260
+ context.in_in_clause = True
261
+ elif isinstance(node, exp.CTE):
262
+ self._update_context_entering_cte(node, context)
263
+ elif isinstance(node, exp.Subquery):
264
+ context.subquery_depth += 1
265
+ context.in_subquery = True
266
+ elif isinstance(node, exp.Select):
267
+ self._update_context_entering_select(node, context)
268
+ elif isinstance(node, exp.Join):
269
+ context.in_join_condition = True
270
+
271
+ def _update_context_entering_func(self, node: Func, context: ParameterizationContext) -> None:
272
+ """Update context when entering a function node."""
273
+ context.function_depth += 1
274
+ # Get function name from class name or node.name
275
+ func_name = node.__class__.__name__.upper()
276
+ if func_name in self.preserve_in_functions or (node.name and node.name.upper() in self.preserve_in_functions):
277
+ context.in_function_args = True
278
+
279
+ def _update_context_entering_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
280
+ """Update context when entering a CTE node."""
281
+ context.cte_depth += 1
282
+ # Check if this CTE is recursive:
283
+ # 1. Parent WITH must be RECURSIVE
284
+ # 2. CTE must contain UNION (characteristic of recursive CTEs)
285
+ is_in_recursive_with = any(
286
+ isinstance(parent, exp.With) and parent.args.get("recursive", False)
287
+ for parent in reversed(context.parent_stack)
288
+ )
289
+ if is_in_recursive_with and self._contains_union(node):
290
+ context.in_recursive_cte = True
291
+
292
+ def _update_context_entering_select(self, node: exp.Select, context: ParameterizationContext) -> None:
293
+ """Update context when entering a SELECT node."""
294
+ # Only track nested SELECT statements as subqueries if they're not part of a recursive CTE
295
+ is_in_recursive_cte = any(
296
+ isinstance(parent, exp.CTE)
297
+ and any(
298
+ isinstance(grandparent, exp.With) and grandparent.args.get("recursive", False)
299
+ for grandparent in context.parent_stack
300
+ )
301
+ for parent in context.parent_stack[:-1]
302
+ )
303
+
304
+ if not is_in_recursive_cte and any(
305
+ isinstance(parent, (exp.Select, exp.Subquery, exp.CTE))
306
+ for parent in context.parent_stack[:-1] # Exclude the current node
307
+ ):
308
+ context.subquery_depth += 1
309
+ context.in_subquery = True
310
+ # Check if we're in a SELECT clause expressions list
311
+ if hasattr(node, "expressions"):
312
+ # We'll handle this specifically when processing individual expressions
313
+ context.in_select_list = False # Will be detected by _is_in_select_expressions
314
+
315
+ def _update_context_leaving(self, node: exp.Expression, context: ParameterizationContext) -> None:
316
+ """Update context when leaving a node."""
317
+ if context.parent_stack:
318
+ context.parent_stack.pop()
319
+
320
+ if isinstance(node, Func):
321
+ self._update_context_leaving_func(node, context)
322
+ elif isinstance(node, exp.Case):
323
+ context.in_case_when = False
324
+ elif isinstance(node, Array):
325
+ context.in_array = False
326
+ elif isinstance(node, exp.In):
327
+ context.in_in_clause = False
328
+ elif isinstance(node, exp.CTE):
329
+ self._update_context_leaving_cte(node, context)
330
+ elif isinstance(node, exp.Subquery):
331
+ self._update_context_leaving_subquery(node, context)
332
+ elif isinstance(node, exp.Select):
333
+ self._update_context_leaving_select(node, context)
334
+ elif isinstance(node, exp.Join):
335
+ context.in_join_condition = False
336
+
337
+ def _update_context_leaving_func(self, node: Func, context: ParameterizationContext) -> None:
338
+ """Update context when leaving a function node."""
339
+ context.function_depth -= 1
340
+ if context.function_depth == 0:
341
+ context.in_function_args = False
342
+
343
+ def _update_context_leaving_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
344
+ """Update context when leaving a CTE node."""
345
+ context.cte_depth -= 1
346
+ if context.cte_depth == 0:
347
+ context.in_recursive_cte = False
348
+
349
+ def _update_context_leaving_subquery(self, node: exp.Subquery, context: ParameterizationContext) -> None:
350
+ """Update context when leaving a subquery node."""
351
+ context.subquery_depth -= 1
352
+ if context.subquery_depth == 0:
353
+ context.in_subquery = False
354
+
355
+ def _update_context_leaving_select(self, node: exp.Select, context: ParameterizationContext) -> None:
356
+ """Update context when leaving a SELECT node."""
357
+ # Only decrement if this was a nested SELECT (not part of recursive CTE)
358
+ is_in_recursive_cte = any(
359
+ isinstance(parent, exp.CTE)
360
+ and any(
361
+ isinstance(grandparent, exp.With) and grandparent.args.get("recursive", False)
362
+ for grandparent in context.parent_stack
363
+ )
364
+ for parent in context.parent_stack[:-1]
365
+ )
366
+
367
+ if not is_in_recursive_cte and any(
368
+ isinstance(parent, (exp.Select, exp.Subquery, exp.CTE))
369
+ for parent in context.parent_stack[:-1] # Exclude current node
370
+ ):
371
+ context.subquery_depth -= 1
372
+ if context.subquery_depth == 0:
373
+ context.in_subquery = False
374
+ context.in_select_list = False
209
375
 
210
376
  def _process_literal_with_context(
211
377
  self, literal: exp.Expression, context: ParameterizationContext
@@ -228,11 +394,26 @@ class ParameterizeLiterals(ProcessorProtocol):
228
394
  semantic_name=semantic_name,
229
395
  )
230
396
 
231
- # Add to parameters list
397
+ # Always track extracted parameters for proper merging
232
398
  self.extracted_parameters.append(typed_param)
399
+
400
+ # If we're reordering, also add to final params directly
401
+ if self._is_reordering_needed:
402
+ if self._preserve_dict_format and isinstance(self._final_params, dict):
403
+ # For dict format, we need a key
404
+ param_key = semantic_name or f"param_{len(self._final_params)}"
405
+ self._final_params[param_key] = typed_param
406
+ elif isinstance(self._final_params, list):
407
+ self._final_params.append(typed_param)
408
+ else:
409
+ # Fallback - this shouldn't happen but handle gracefully
410
+ if not hasattr(self, "_fallback_params"):
411
+ self._fallback_params = []
412
+ self._fallback_params.append(typed_param)
413
+
233
414
  self._parameter_metadata.append(
234
415
  {
235
- "index": len(self.extracted_parameters) - 1,
416
+ "index": len(self._final_params if self._is_reordering_needed else self.extracted_parameters) - 1,
236
417
  "type": type_hint,
237
418
  "semantic_name": semantic_name,
238
419
  "context": self._get_context_description(context),
@@ -242,23 +423,343 @@ class ParameterizeLiterals(ProcessorProtocol):
242
423
  # Create appropriate placeholder
243
424
  return self._create_placeholder(hint=semantic_name)
244
425
 
426
+ def _process_existing_placeholder(self, node: exp.Placeholder, context: ParameterizationContext) -> exp.Expression:
427
+ """Process an existing placeholder when reordering parameters."""
428
+ if self._original_params is None:
429
+ return node
430
+
431
+ if isinstance(self._original_params, (list, tuple)):
432
+ self._handle_list_params_for_placeholder(node)
433
+ elif isinstance(self._original_params, dict):
434
+ self._handle_dict_params_for_placeholder(node)
435
+ else:
436
+ # Single value parameter
437
+ self._handle_single_value_param_for_placeholder(node)
438
+
439
+ return node
440
+
441
+ def _handle_list_params_for_placeholder(self, node: exp.Placeholder) -> None:
442
+ """Handle list/tuple parameters for placeholder."""
443
+ if isinstance(self._original_params, (list, tuple)) and self._user_param_index < len(self._original_params):
444
+ value = self._original_params[self._user_param_index]
445
+ self._add_to_final_params(value, node)
446
+ self._user_param_index += 1
447
+ else:
448
+ # More placeholders than user parameters
449
+ self._add_to_final_params(None, node)
450
+
451
+ def _handle_dict_params_for_placeholder(self, node: exp.Placeholder) -> None:
452
+ """Handle dict parameters for placeholder."""
453
+ if not isinstance(self._original_params, dict):
454
+ self._add_to_final_params(None, node)
455
+ return
456
+
457
+ raw_placeholder_name = node.this if hasattr(node, "this") else None
458
+ if not raw_placeholder_name:
459
+ # Unnamed placeholder '?' with dict params is ambiguous
460
+ self._add_to_final_params(None, node)
461
+ return
462
+
463
+ # FIX: Normalize the placeholder name by stripping leading sigils
464
+ placeholder_name = raw_placeholder_name.lstrip(":@")
465
+
466
+ # Debug logging
467
+
468
+ if placeholder_name in self._original_params:
469
+ # Direct match for placeholder name
470
+ self._add_to_final_params(self._original_params[placeholder_name], node)
471
+ elif placeholder_name.isdigit() and self._user_param_index == 0:
472
+ # Oracle-style numeric parameters
473
+ self._handle_oracle_numeric_params()
474
+ self._user_param_index += 1
475
+ elif placeholder_name.isdigit() and self._user_param_index > 0:
476
+ # Already handled Oracle params
477
+ pass
478
+ elif self._user_param_index == 0 and len(self._original_params) > 0:
479
+ # Single dict parameter case
480
+ self._handle_single_dict_param()
481
+ self._user_param_index += 1
482
+ else:
483
+ # No match found
484
+ self._add_to_final_params(None, node)
485
+
486
+ def _handle_single_value_param_for_placeholder(self, node: exp.Placeholder) -> None:
487
+ """Handle single value parameter for placeholder."""
488
+ if self._user_param_index == 0:
489
+ self._add_to_final_params(self._original_params, node)
490
+ self._user_param_index += 1
491
+ else:
492
+ self._add_to_final_params(None, node)
493
+
494
+ def _handle_oracle_numeric_params(self) -> None:
495
+ """Handle Oracle-style numeric parameters."""
496
+ if not isinstance(self._original_params, dict):
497
+ return
498
+
499
+ if self._preserve_dict_format and isinstance(self._final_params, dict):
500
+ for k, v in self._original_params.items():
501
+ if k.isdigit():
502
+ self._final_params[k] = v
503
+ else:
504
+ # Convert to positional list
505
+ numeric_keys = [k for k in self._original_params if k.isdigit()]
506
+ if numeric_keys:
507
+ max_index = max(int(k) for k in numeric_keys)
508
+ param_list = [None] * (max_index + 1)
509
+ for k, v in self._original_params.items():
510
+ if k.isdigit():
511
+ param_list[int(k)] = v
512
+ if isinstance(self._final_params, list):
513
+ self._final_params.extend(param_list)
514
+ elif isinstance(self._final_params, dict):
515
+ for i, val in enumerate(param_list):
516
+ self._final_params[str(i)] = val
517
+
518
+ def _handle_single_dict_param(self) -> None:
519
+ """Handle single dict parameter case."""
520
+ if not isinstance(self._original_params, dict):
521
+ return
522
+
523
+ if self._preserve_dict_format and isinstance(self._final_params, dict):
524
+ for k, v in self._original_params.items():
525
+ self._final_params[k] = v
526
+ elif isinstance(self._final_params, list):
527
+ self._final_params.append(self._original_params)
528
+ elif isinstance(self._final_params, dict):
529
+ param_name = f"param_{len(self._final_params)}"
530
+ self._final_params[param_name] = self._original_params
531
+
532
+ def _add_to_final_params(self, value: Any, node: exp.Placeholder) -> None:
533
+ """Add a value to final params with proper type handling."""
534
+ if self._preserve_dict_format and isinstance(self._final_params, dict):
535
+ placeholder_name = node.this if hasattr(node, "this") else f"param_{self._user_param_index}"
536
+ self._final_params[placeholder_name] = value
537
+ elif isinstance(self._final_params, list):
538
+ self._final_params.append(value)
539
+ elif isinstance(self._final_params, dict):
540
+ param_name = f"param_{len(self._final_params)}"
541
+ self._final_params[param_name] = value
542
+
543
+ def _process_existing_parameter(self, node: exp.Parameter, context: ParameterizationContext) -> exp.Expression:
544
+ """Process existing parameters (both numeric and named) when reordering parameters."""
545
+ # First try to get parameter name for named parameters (like BigQuery @param_name)
546
+ param_name = self._extract_parameter_name(node)
547
+
548
+ if param_name and isinstance(self._original_params, dict) and param_name in self._original_params:
549
+ value = self._original_params[param_name]
550
+ self._add_param_value_to_finals(value)
551
+ return node
552
+
553
+ # Fall back to numeric parameter handling for PostgreSQL-style parameters ($1, $2)
554
+ param_index = self._extract_parameter_index(node)
555
+
556
+ if self._original_params is None:
557
+ self._add_none_to_final_params()
558
+ elif isinstance(self._original_params, (list, tuple)):
559
+ self._handle_list_params_for_parameter_node(param_index)
560
+ elif isinstance(self._original_params, dict):
561
+ self._handle_dict_params_for_parameter_node(param_index)
562
+ elif param_index == 0:
563
+ # Single parameter case
564
+ self._add_param_value_to_finals(self._original_params)
565
+ else:
566
+ self._add_none_to_final_params()
567
+
568
+ # Return the parameter unchanged
569
+ return node
570
+
571
+ @staticmethod
572
+ def _extract_parameter_name(node: exp.Parameter) -> Optional[str]:
573
+ """Extract parameter name from a Parameter node for named parameters."""
574
+ if hasattr(node, "this"):
575
+ if isinstance(node.this, exp.Var):
576
+ # Named parameter like @min_value -> min_value
577
+ return str(node.this.this)
578
+ if hasattr(node.this, "this"):
579
+ # Handle other node types that might contain the name
580
+ return str(node.this.this)
581
+ return None
582
+
583
+ @staticmethod
584
+ def _extract_parameter_index(node: exp.Parameter) -> Optional[int]:
585
+ """Extract parameter index from a Parameter node."""
586
+ if hasattr(node, "this") and isinstance(node.this, Literal):
587
+ import contextlib
588
+
589
+ with contextlib.suppress(ValueError, TypeError):
590
+ return int(node.this.this) - 1 # Convert to 0-based index
591
+ return None
592
+
593
+ def _handle_list_params_for_parameter_node(self, param_index: Optional[int]) -> None:
594
+ """Handle list/tuple parameters for Parameter node."""
595
+ if (
596
+ isinstance(self._original_params, (list, tuple))
597
+ and param_index is not None
598
+ and 0 <= param_index < len(self._original_params)
599
+ ):
600
+ # Use the parameter at the specified index
601
+ self._add_param_value_to_finals(self._original_params[param_index])
602
+ else:
603
+ # More parameters than user provided
604
+ self._add_none_to_final_params()
605
+
606
+ def _handle_dict_params_for_parameter_node(self, param_index: Optional[int]) -> None:
607
+ """Handle dict parameters for Parameter node."""
608
+ if param_index is not None:
609
+ self._handle_dict_param_with_index(param_index)
610
+ else:
611
+ self._add_none_to_final_params()
612
+
613
+ def _handle_dict_param_with_index(self, param_index: int) -> None:
614
+ """Handle dict parameter when we have an index."""
615
+ if not isinstance(self._original_params, dict):
616
+ self._add_none_to_final_params()
617
+ return
618
+
619
+ # Try param_N key first
620
+ param_key = f"param_{param_index}"
621
+ if param_key in self._original_params:
622
+ self._add_dict_value_to_finals(param_key)
623
+ return
624
+
625
+ # Try direct numeric key (1-based)
626
+ numeric_key = str(param_index + 1)
627
+ if numeric_key in self._original_params:
628
+ self._add_dict_value_to_finals(numeric_key)
629
+ else:
630
+ self._add_none_to_final_params()
631
+
632
+ def _add_dict_value_to_finals(self, key: str) -> None:
633
+ """Add a value from dict params to final params."""
634
+ if isinstance(self._original_params, dict) and key in self._original_params:
635
+ value = self._original_params[key]
636
+ if isinstance(self._final_params, list):
637
+ self._final_params.append(value)
638
+ elif isinstance(self._final_params, dict):
639
+ self._final_params[key] = value
640
+
641
def _add_param_value_to_finals(self, value: Any) -> None:
    """Record ``value`` as the next final parameter.

    * List finals get ``value`` appended.
    * Dict finals get it under the positional key ``"param_<n>"``, where ``n``
      is the dict's size before insertion.
    * Any other finals container is left untouched.

    Args:
        value: Parameter value to record (may be ``None``).
    """
    target = self._final_params
    if isinstance(target, dict):
        # NOTE(review): _add_dict_value_to_finals stores user-chosen keys in the
        # same dict. This positional "param_<n>" key could therefore coincide
        # with one of them and overwrite it. Confirm callers never mix the two.
        target[f"param_{len(target)}"] = value
    elif isinstance(target, list):
        target.append(value)
648
+
649
def _add_none_to_final_params(self) -> None:
    """Record a ``None`` placeholder value as the next final parameter.

    Callers use this when a placeholder has no matching user-supplied value.

    * List finals get ``None`` appended.
    * Dict finals get it under ``"param_<n>"``, where ``n`` is the dict's size
      before insertion.
    * Any other container is left untouched.
    """
    target = self._final_params
    if isinstance(target, dict):
        target[f"param_{len(target)}"] = None
    elif isinstance(target, list):
        target.append(None)
656
+
657
def _process_postgresql_column_parameter(
    self, node: exp.Column, context: ParameterizationContext
) -> exp.Expression:
    """Process PostgreSQL-style parameters that were parsed as columns ($1, $2).

    The placeholder's value is looked up in ``self._original_params`` and
    recorded in ``self._final_params``. The node is returned unchanged because
    it already represents the placeholder.

    Lookup rules, by the type of the original parameters:

    * ``None``: nothing is recorded.
    * list/tuple: when ``self._is_reordering_needed`` is set (mixed placeholder
      styles), values are consumed sequentially via ``self._user_param_index``.
      Otherwise the number in ``$N`` selects ``original[N - 1]``. Placeholders
      with no corresponding user value record nothing. An explicit ``None``
      supplied by the user is recorded like any other value.
    * dict: ``"param_<N-1>"`` is tried first, then ``"<N>"``. If neither key is
      present, nothing is recorded.
    * anything else (a single scalar): recorded only for ``$1``.

    Args:
        node: Column node whose name looks like ``$N``.
        context: Current parameterization context (not used here).

    Returns:
        ``node``, unchanged.
    """
    column_name = str(node.this) if hasattr(node, "this") else ""
    param_index = None  # zero-based index derived from "$N"
    if column_name.startswith("$") and column_name[1:].isdigit():
        try:
            param_index = int(column_name[1:]) - 1
        except (ValueError, TypeError):
            # str.isdigit() also accepts characters such as "²" that int() rejects.
            param_index = None

    original = self._original_params
    if original is None:
        # No user parameters provided - don't add None
        return node

    final_params = self._final_params

    def record(key, value):
        # List finals are positional. Dict finals use ``key`` or, when no key
        # is given, the next positional "param_<n>" name.
        if isinstance(final_params, list):
            final_params.append(value)
        elif isinstance(final_params, dict):
            final_params[key if key is not None else f"param_{len(final_params)}"] = value

    if isinstance(original, (list, tuple)):
        if self._is_reordering_needed:
            # Mixed placeholder styles: assign values sequentially regardless
            # of the number written in the placeholder.
            position = self._user_param_index
            if position < len(original):
                self._user_param_index += 1
            else:
                position = None
        elif param_index is not None and 0 <= param_index < len(original):
            # Non-mixed styles: use the numeric value from the placeholder.
            position = param_index
        else:
            position = None
        # Only skip placeholders that have no user value. An explicit None
        # (SQL NULL) must still be recorded, otherwise every later list value
        # would shift one position earlier.
        if position is not None:
            record(None, original[position])
    elif isinstance(original, dict):
        # Dict parameters with numeric placeholders: map by index.
        if param_index is not None:
            for key in (f"param_{param_index}", str(param_index + 1)):
                if key in original:
                    record(key, original[key])
                    break
    elif param_index == 0:
        # Single scalar parameter bound to $1.
        record(None, original)

    # Return the column unchanged - it represents the parameter placeholder
    return node
738
+
245
739
  def _should_preserve_literal_in_context(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
246
- """Context-aware decision on literal preservation."""
247
- # Check for NULL values
740
+ """Enhanced context-aware decision on literal preservation."""
741
+ # Existing preservation rules (maintain compatibility)
248
742
  if self.preserve_null and isinstance(literal, Null):
249
743
  return True
250
744
 
251
- # Check for boolean values
252
745
  if self.preserve_boolean and isinstance(literal, Boolean):
253
746
  return True
254
747
 
748
+ # NEW: Context-based preservation rules
749
+
750
+ # Rule 4: Preserve enum-like literals in subquery lookups (the main fix we need)
751
+ if context.in_subquery and self._is_scalar_lookup_pattern(literal, context):
752
+ return self._is_enum_like_literal(literal)
753
+
754
+ # Existing preservation rules continue...
755
+
255
756
  # Check if in preserved function arguments
256
757
  if context.in_function_args:
257
758
  return True
258
759
 
259
- # Preserve literals in recursive CTEs to avoid type inference issues
760
+ # ENHANCED: Intelligent recursive CTE literal preservation
260
761
  if self.preserve_in_recursive_cte and context.in_recursive_cte:
261
- return True
762
+ return self._should_preserve_literal_in_recursive_cte(literal, context)
262
763
 
263
764
  # Check if this literal is being used as an alias value in SELECT
264
765
  # e.g., 'computed' as process_status should be preserved
@@ -299,6 +800,76 @@ class ParameterizeLiterals(ProcessorProtocol):
299
800
 
300
801
  return False
301
802
 
803
def _is_in_select_expressions(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Check if literal is in SELECT clause expressions (critical for type inference).

    Ancestors in ``context.parent_stack`` are scanned in reverse stack order
    (presumably innermost first; confirm against how the stack is built):

    * A WHERE, HAVING or JOIN reached first means the answer is ``False``.
    * The first SELECT with non-empty ``expressions`` decides the answer via
      ``_literal_is_in_expression_tree``.
    * A SELECT without expressions is skipped.
    """
    for ancestor in reversed(context.parent_stack):
        if isinstance(ancestor, (exp.Where, exp.Having, exp.Join)):
            return False
        if not isinstance(ancestor, exp.Select):
            continue
        projections = getattr(ancestor, "expressions", None)
        if projections:
            return any(self._literal_is_in_expression_tree(literal, projection) for projection in projections)
    return False
812
+
813
def _is_recursive_computation(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Check if literal is part of recursive computation logic.

    A literal qualifies when both of these hold:

    * some ancestor in ``context.parent_stack`` is an arithmetic operation
      (``+``, ``-``, ``*``, ``/``);
    * the literal also sits in a SELECT list (see ``_is_in_select_expressions``).
    """
    # sqlglot's Expression.key is the lower-cased class name ("add", "sub", ...),
    # so comparing it with upper-case strings can never match. Checking the
    # node classes directly does not depend on that spelling.
    arithmetic_types = (exp.Add, exp.Sub, exp.Mul, exp.Div)
    if any(isinstance(parent, arithmetic_types) for parent in context.parent_stack):
        # Check if this arithmetic is in a SELECT clause of a recursive part
        return self._is_in_select_expressions(literal, context)
    return False
821
+
822
def _should_preserve_literal_in_recursive_cte(
    self, literal: exp.Expression, context: ParameterizationContext
) -> bool:
    """Decide, by semantic role, whether a recursive-CTE literal stays inline.

    Two kinds of literal are preserved; all others may be parameterized:

    * literals in a SELECT list, where type inference depends on them;
    * literals flagged by ``_is_recursive_computation``.
    """
    return self._is_in_select_expressions(literal, context) or self._is_recursive_computation(literal, context)
832
+
833
def _literal_is_in_expression_tree(self, target_literal: exp.Expression, expr: exp.Expression) -> bool:
    """Check if target literal is anywhere within the given expression tree.

    ``expr`` itself and all of its descendants are compared with sqlglot's
    structural ``==``. sqlglot equality requires identical node classes, so
    only nodes of ``target_literal``'s class are visited.
    """
    # iter_expressions() yields only direct children, which misses literals
    # nested deeper (e.g. the ``1`` in ``SELECT n + 1 AS n``). find_all() walks
    # the whole subtree, starting with ``expr`` itself.
    return any(node == target_literal for node in expr.find_all(type(target_literal)))
839
+
840
def _is_scalar_lookup_pattern(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Detect if literal is part of a scalar subquery lookup pattern.

    Outside subqueries (``context.subquery_depth == 0``) the answer is always
    ``False``. Inside a subquery, WHERE ancestors in ``context.parent_stack``
    are examined in reverse stack order:

    * If the WHERE condition is a binary expression with ``literal`` as one
      operand, the answer is whether the other operand is a column
      (``WHERE col = 'x'`` or ``WHERE 'x' = col``).
    * WHERE clauses whose condition does not have the literal as a direct
      operand are skipped.
    """
    # NOTE(review): only a WHERE whose whole condition is the comparison is
    # recognised. In ``WHERE a = 'x' AND b = 1`` the literal is not a direct
    # operand of the AND, so it does not match. Confirm this is intended.
    if context.subquery_depth == 0:
        return False
    for ancestor in reversed(context.parent_stack):
        if not isinstance(ancestor, exp.Where):
            continue
        condition = ancestor.this
        if not isinstance(condition, exp.Binary):
            continue
        if condition.right == literal:
            return isinstance(condition.left, exp.Column)
        if condition.left == literal:
            return isinstance(condition.right, exp.Column)
    return False
857
+
858
def _is_enum_like_literal(self, literal: exp.Expression) -> bool:
    """Detect if literal looks like an enum/identifier constant.

    Only string literals qualify, and their text must satisfy all of these
    conservative heuristics:

    * longer than ``MIN_ENUM_LENGTH`` and at most ``MAX_ENUM_LENGTH``
      characters;
    * only alphanumeric characters (per ``str.isalnum``) and underscores;
    * not made up purely of digits.
    """
    if not isinstance(literal, exp.Literal):
        return False
    if not self._is_string_literal(literal):
        return False
    text = str(literal.this)
    if not MIN_ENUM_LENGTH < len(text) <= MAX_ENUM_LENGTH:
        return False
    return text.replace("_", "").isalnum() and not text.isdigit()
872
+
302
873
  def _extract_literal_value_and_type(self, literal: exp.Expression) -> tuple[Any, str]:
303
874
  """Extract the Python value and type info from a SQLGlot literal."""
304
875
  if isinstance(literal, Null) or literal.this is None:
@@ -551,7 +1122,6 @@ class ParameterizeLiterals(ProcessorProtocol):
551
1122
  array_sqlglot_type = exp.DataType.build("ARRAY", expressions=[element_sqlglot_type])
552
1123
 
553
1124
  # Create TypedParameter for the entire array
554
- from sqlspec.statement.parameters import TypedParameter
555
1125
 
556
1126
  typed_param = TypedParameter(
557
1127
  value=array_values,