sqlspec 0.14.1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (158) hide show
  1. sqlspec/__init__.py +50 -25
  2. sqlspec/__main__.py +1 -1
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +256 -120
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +115 -260
  10. sqlspec/adapters/adbc/driver.py +462 -367
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +199 -129
  14. sqlspec/adapters/aiosqlite/driver.py +230 -269
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -168
  18. sqlspec/adapters/asyncmy/driver.py +260 -225
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +82 -181
  22. sqlspec/adapters/asyncpg/driver.py +285 -383
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -258
  26. sqlspec/adapters/bigquery/driver.py +474 -646
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +415 -351
  30. sqlspec/adapters/duckdb/driver.py +343 -413
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -379
  34. sqlspec/adapters/oracledb/driver.py +507 -560
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -254
  38. sqlspec/adapters/psqlpy/driver.py +505 -234
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -403
  42. sqlspec/adapters/psycopg/driver.py +706 -872
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +202 -118
  46. sqlspec/adapters/sqlite/driver.py +264 -303
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder → builder}/_base.py +120 -55
  50. sqlspec/{statement/builder → builder}/_column.py +17 -6
  51. sqlspec/{statement/builder → builder}/_ddl.py +46 -79
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
  53. sqlspec/{statement/builder → builder}/_delete.py +6 -25
  54. sqlspec/{statement/builder → builder}/_insert.py +6 -64
  55. sqlspec/builder/_merge.py +56 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +3 -10
  57. sqlspec/{statement/builder → builder}/_select.py +11 -56
  58. sqlspec/{statement/builder → builder}/_update.py +12 -18
  59. sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
  60. sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
  61. sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +22 -16
  62. sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
  63. sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +3 -5
  64. sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
  65. sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
  66. sqlspec/{statement/builder → builder}/mixins/_select_operations.py +21 -36
  67. sqlspec/{statement/builder → builder}/mixins/_update_operations.py +3 -14
  68. sqlspec/{statement/builder → builder}/mixins/_where_clause.py +52 -79
  69. sqlspec/cli.py +4 -5
  70. sqlspec/config.py +180 -133
  71. sqlspec/core/__init__.py +63 -0
  72. sqlspec/core/cache.py +873 -0
  73. sqlspec/core/compiler.py +396 -0
  74. sqlspec/core/filters.py +828 -0
  75. sqlspec/core/hashing.py +310 -0
  76. sqlspec/core/parameters.py +1209 -0
  77. sqlspec/core/result.py +664 -0
  78. sqlspec/{statement → core}/splitter.py +321 -191
  79. sqlspec/core/statement.py +651 -0
  80. sqlspec/driver/__init__.py +7 -10
  81. sqlspec/driver/_async.py +387 -176
  82. sqlspec/driver/_common.py +527 -289
  83. sqlspec/driver/_sync.py +390 -172
  84. sqlspec/driver/mixins/__init__.py +2 -19
  85. sqlspec/driver/mixins/_result_tools.py +168 -0
  86. sqlspec/driver/mixins/_sql_translator.py +6 -3
  87. sqlspec/exceptions.py +5 -252
  88. sqlspec/extensions/aiosql/adapter.py +93 -96
  89. sqlspec/extensions/litestar/config.py +0 -1
  90. sqlspec/extensions/litestar/handlers.py +15 -26
  91. sqlspec/extensions/litestar/plugin.py +16 -14
  92. sqlspec/extensions/litestar/providers.py +17 -52
  93. sqlspec/loader.py +424 -105
  94. sqlspec/migrations/__init__.py +12 -0
  95. sqlspec/migrations/base.py +92 -68
  96. sqlspec/migrations/commands.py +24 -106
  97. sqlspec/migrations/loaders.py +402 -0
  98. sqlspec/migrations/runner.py +49 -51
  99. sqlspec/migrations/tracker.py +31 -44
  100. sqlspec/migrations/utils.py +64 -24
  101. sqlspec/protocols.py +7 -183
  102. sqlspec/storage/__init__.py +1 -1
  103. sqlspec/storage/backends/base.py +37 -40
  104. sqlspec/storage/backends/fsspec.py +136 -112
  105. sqlspec/storage/backends/obstore.py +138 -160
  106. sqlspec/storage/capabilities.py +5 -4
  107. sqlspec/storage/registry.py +57 -106
  108. sqlspec/typing.py +136 -115
  109. sqlspec/utils/__init__.py +2 -3
  110. sqlspec/utils/correlation.py +0 -3
  111. sqlspec/utils/deprecation.py +6 -6
  112. sqlspec/utils/fixtures.py +6 -6
  113. sqlspec/utils/logging.py +0 -2
  114. sqlspec/utils/module_loader.py +7 -12
  115. sqlspec/utils/singleton.py +0 -1
  116. sqlspec/utils/sync_tools.py +16 -37
  117. sqlspec/utils/text.py +12 -51
  118. sqlspec/utils/type_guards.py +443 -232
  119. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/METADATA +7 -2
  120. sqlspec-0.15.0.dist-info/RECORD +134 -0
  121. sqlspec/adapters/adbc/transformers.py +0 -108
  122. sqlspec/driver/connection.py +0 -207
  123. sqlspec/driver/mixins/_cache.py +0 -114
  124. sqlspec/driver/mixins/_csv_writer.py +0 -91
  125. sqlspec/driver/mixins/_pipeline.py +0 -508
  126. sqlspec/driver/mixins/_query_tools.py +0 -796
  127. sqlspec/driver/mixins/_result_utils.py +0 -138
  128. sqlspec/driver/mixins/_storage.py +0 -912
  129. sqlspec/driver/mixins/_type_coercion.py +0 -128
  130. sqlspec/driver/parameters.py +0 -138
  131. sqlspec/statement/__init__.py +0 -21
  132. sqlspec/statement/builder/_merge.py +0 -95
  133. sqlspec/statement/cache.py +0 -50
  134. sqlspec/statement/filters.py +0 -625
  135. sqlspec/statement/parameters.py +0 -956
  136. sqlspec/statement/pipelines/__init__.py +0 -210
  137. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  138. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  139. sqlspec/statement/pipelines/context.py +0 -109
  140. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  141. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  142. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  143. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  144. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  145. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  146. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  147. sqlspec/statement/pipelines/validators/_performance.py +0 -714
  148. sqlspec/statement/pipelines/validators/_security.py +0 -967
  149. sqlspec/statement/result.py +0 -435
  150. sqlspec/statement/sql.py +0 -1774
  151. sqlspec/utils/cached_property.py +0 -25
  152. sqlspec/utils/statement_hashing.py +0 -203
  153. sqlspec-0.14.1.dist-info/RECORD +0 -145
  154. /sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
  155. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/WHEEL +0 -0
  156. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/entry_points.txt +0 -0
  157. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/licenses/LICENSE +0 -0
  158. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,1247 +0,0 @@
1
- """Replaces literals in SQL with placeholders and extracts them using SQLGlot AST."""
2
-
3
- from dataclasses import dataclass
4
- from typing import Any, Optional, Union
5
-
6
- from sqlglot import exp
7
- from sqlglot.expressions import Array, Binary, Boolean, DataType, Func, Literal, Null
8
-
9
- from sqlspec.protocols import ProcessorProtocol
10
- from sqlspec.statement.parameters import ParameterStyle, TypedParameter
11
- from sqlspec.statement.pipelines.context import SQLProcessingContext
12
-
13
- __all__ = ("ParameterizationContext", "ParameterizeLiterals")
14
-
15
- # Constants for magic values and literal parameterization
16
- MAX_DECIMAL_PRECISION = 6
17
- MAX_INT32_VALUE = 2147483647
18
- DEFAULT_MAX_STRING_LENGTH = 1000
19
- """Default maximum string length for literal parameterization."""
20
-
21
- DEFAULT_MAX_ARRAY_LENGTH = 100
22
- """Default maximum array length for literal parameterization."""
23
-
24
- DEFAULT_MAX_IN_LIST_SIZE = 50
25
- """Default maximum IN clause list size before parameterization."""
26
-
27
- MAX_ENUM_LENGTH = 50
28
- """Maximum length for enum-like string values."""
29
-
30
- MIN_ENUM_LENGTH = 2
31
- """Minimum length for enum-like string values to be meaningful."""
32
-
33
-
34
- @dataclass
35
- class ParameterizationContext:
36
- """Context for tracking parameterization state during AST traversal."""
37
-
38
- parent_stack: list[exp.Expression]
39
- in_function_args: bool = False
40
- in_case_when: bool = False
41
- in_array: bool = False
42
- in_in_clause: bool = False
43
- in_recursive_cte: bool = False
44
- in_subquery: bool = False
45
- in_select_list: bool = False
46
- in_join_condition: bool = False
47
- function_depth: int = 0
48
- cte_depth: int = 0
49
- subquery_depth: int = 0
50
-
51
-
52
- class ParameterizeLiterals(ProcessorProtocol):
53
- """Advanced literal parameterization using SQLGlot AST analysis.
54
-
55
- This enhanced version provides:
56
- - Context-aware parameterization based on AST position
57
- - Smart handling of arrays, IN clauses, and function arguments
58
- - Type-preserving parameter extraction
59
- - Configurable parameterization strategies
60
- - Performance optimization for query plan caching
61
-
62
- Args:
63
- placeholder_style: Style of placeholder to use ("?", ":name", "$1", etc.).
64
- preserve_null: Whether to preserve NULL literals as-is.
65
- preserve_boolean: Whether to preserve boolean literals as-is.
66
- preserve_numbers_in_limit: Whether to preserve numbers in LIMIT/OFFSET clauses.
67
- preserve_in_functions: List of function names where literals should be preserved.
68
- preserve_in_recursive_cte: Whether to preserve literals in recursive CTEs (default True to avoid type inference issues).
69
- parameterize_arrays: Whether to parameterize array literals.
70
- parameterize_in_lists: Whether to parameterize IN clause lists.
71
- max_string_length: Maximum string length to parameterize.
72
- max_array_length: Maximum array length to parameterize.
73
- max_in_list_size: Maximum IN list size to parameterize.
74
- type_preservation: Whether to preserve exact literal types.
75
- """
76
-
77
- def __init__(
78
- self,
79
- placeholder_style: str = "?",
80
- preserve_null: bool = True,
81
- preserve_boolean: bool = True,
82
- preserve_numbers_in_limit: bool = True,
83
- preserve_in_functions: Optional[list[str]] = None,
84
- preserve_in_recursive_cte: bool = True,
85
- parameterize_arrays: bool = True,
86
- parameterize_in_lists: bool = True,
87
- max_string_length: int = DEFAULT_MAX_STRING_LENGTH,
88
- max_array_length: int = DEFAULT_MAX_ARRAY_LENGTH,
89
- max_in_list_size: int = DEFAULT_MAX_IN_LIST_SIZE,
90
- type_preservation: bool = True,
91
- ) -> None:
92
- self.placeholder_style = placeholder_style
93
- self.preserve_null = preserve_null
94
- self.preserve_boolean = preserve_boolean
95
- self.preserve_numbers_in_limit = preserve_numbers_in_limit
96
- self.preserve_in_recursive_cte = preserve_in_recursive_cte
97
- self.preserve_in_functions = preserve_in_functions or [
98
- "COALESCE",
99
- "IFNULL",
100
- "NVL",
101
- "ISNULL",
102
- # Array functions that take dimension arguments
103
- "ARRAYSIZE", # SQLglot converts array_length to ArraySize
104
- "ARRAY_UPPER",
105
- "ARRAY_LOWER",
106
- "ARRAY_NDIMS",
107
- "ROUND",
108
- ]
109
- self.parameterize_arrays = parameterize_arrays
110
- self.parameterize_in_lists = parameterize_in_lists
111
- self.max_string_length = max_string_length
112
- self.max_array_length = max_array_length
113
- self.max_in_list_size = max_in_list_size
114
- self.type_preservation = type_preservation
115
- self.extracted_parameters: list[Any] = []
116
- self._parameter_counter = 0
117
- self._parameter_metadata: list[dict[str, Any]] = [] # Track parameter types and context
118
- self._preserve_dict_format = False # Track whether to preserve dict format
119
-
120
- def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
121
- """Advanced literal parameterization with context-aware AST analysis."""
122
- if expression is None or context.current_expression is None:
123
- return expression
124
-
125
- # For named parameters (like BigQuery @param), don't reorder to avoid breaking name mapping
126
- if (
127
- context.config.input_sql_had_placeholders
128
- and context.parameter_info
129
- and any(p.name for p in context.parameter_info)
130
- ):
131
- return expression
132
-
133
- self.extracted_parameters = []
134
- self._parameter_metadata = []
135
-
136
- # When reordering is needed (SQL already has placeholders), we need to start
137
- # our counter at the number of existing parameters to avoid conflicts
138
- if context.config.input_sql_had_placeholders and context.parameter_info:
139
- # Find the highest ordinal among existing parameters
140
- max_ordinal = max(p.ordinal for p in context.parameter_info)
141
- self._parameter_counter = max_ordinal + 1
142
- else:
143
- self._parameter_counter = 0
144
-
145
- # Track original user parameters for proper merging
146
- self._original_params = context.merged_parameters
147
- self._user_param_index = 0
148
- # If original params are dict and we have named placeholders, preserve dict format
149
- if isinstance(context.merged_parameters, dict) and context.parameter_info:
150
- # Check if we have named placeholders
151
- has_named = any(p.name for p in context.parameter_info)
152
- if has_named:
153
- self._final_params: Union[dict[str, Any], list[Any]] = {}
154
- self._preserve_dict_format = True
155
- else:
156
- self._final_params = []
157
- self._preserve_dict_format = False
158
- else:
159
- self._final_params = []
160
- self._preserve_dict_format = False
161
- self._is_reordering_needed = context.config.input_sql_had_placeholders
162
-
163
- param_context = ParameterizationContext(parent_stack=[])
164
- transformed_expression = self._transform_with_context(context.current_expression.copy(), param_context)
165
- context.current_expression = transformed_expression
166
-
167
- # If we're reordering, update the merged parameters with the reordered result
168
- # In this case, we don't need to add to extracted_parameters_from_pipeline
169
- # because the parameters are already in _final_params
170
- if self._is_reordering_needed and self._final_params:
171
- context.merged_parameters = self._final_params
172
- else:
173
- # Only add extracted parameters to the pipeline if we're not reordering
174
- # This prevents duplication when parameters are already in merged_parameters
175
- context.extracted_parameters_from_pipeline.extend(self.extracted_parameters)
176
-
177
- context.metadata["parameter_metadata"] = self._parameter_metadata
178
-
179
- return transformed_expression
180
-
181
- def _transform_with_context(self, node: exp.Expression, context: ParameterizationContext) -> exp.Expression:
182
- """Transform expression tree with context tracking."""
183
- # Update context based on node type
184
- self._update_context(node, context, entering=True)
185
-
186
- # Process the node
187
- if isinstance(node, Literal):
188
- result = self._process_literal_with_context(node, context)
189
- elif isinstance(node, (Boolean, Null)):
190
- # Boolean and Null are not Literal subclasses, handle them separately
191
- result = self._process_literal_with_context(node, context)
192
- elif isinstance(node, Array) and self.parameterize_arrays:
193
- result = self._process_array(node, context)
194
- elif isinstance(node, exp.In) and self.parameterize_in_lists:
195
- result = self._process_in_clause(node, context)
196
- elif isinstance(node, exp.Placeholder) and self._is_reordering_needed:
197
- # Handle existing placeholders when reordering is needed
198
- result = self._process_existing_placeholder(node, context)
199
- elif isinstance(node, exp.Parameter) and self._is_reordering_needed:
200
- # Handle PostgreSQL-style parameters ($1, $2) when reordering is needed
201
- result = self._process_existing_parameter(node, context)
202
- elif isinstance(node, exp.Column) and self._is_reordering_needed:
203
- # Check if this column looks like a PostgreSQL parameter ($1, $2, etc.)
204
- column_name = str(node.this) if hasattr(node, "this") else ""
205
- if column_name.startswith("$") and column_name[1:].isdigit():
206
- # This is a PostgreSQL-style parameter parsed as a column
207
- result = self._process_postgresql_column_parameter(node, context)
208
- else:
209
- # Regular column - process children
210
- for key, value in node.args.items():
211
- if isinstance(value, exp.Expression):
212
- node.set(key, self._transform_with_context(value, context))
213
- elif isinstance(value, list):
214
- node.set(
215
- key,
216
- [
217
- self._transform_with_context(v, context) if isinstance(v, exp.Expression) else v
218
- for v in value
219
- ],
220
- )
221
- result = node
222
- else:
223
- # Recursively process children
224
- for key, value in node.args.items():
225
- if isinstance(value, exp.Expression):
226
- node.set(key, self._transform_with_context(value, context))
227
- elif isinstance(value, list):
228
- node.set(
229
- key,
230
- [
231
- self._transform_with_context(v, context) if isinstance(v, exp.Expression) else v
232
- for v in value
233
- ],
234
- )
235
- result = node
236
-
237
- # Update context when leaving
238
- self._update_context(node, context, entering=False)
239
-
240
- return result
241
-
242
- def _update_context(self, node: exp.Expression, context: ParameterizationContext, entering: bool) -> None:
243
- """Update parameterization context based on current AST node."""
244
- if entering:
245
- self._update_context_entering(node, context)
246
- else:
247
- self._update_context_leaving(node, context)
248
-
249
- def _update_context_entering(self, node: exp.Expression, context: ParameterizationContext) -> None:
250
- """Update context when entering a node."""
251
- context.parent_stack.append(node)
252
-
253
- if isinstance(node, Func):
254
- self._update_context_entering_func(node, context)
255
- elif isinstance(node, exp.Case):
256
- context.in_case_when = True
257
- elif isinstance(node, Array):
258
- context.in_array = True
259
- elif isinstance(node, exp.In):
260
- context.in_in_clause = True
261
- elif isinstance(node, exp.CTE):
262
- self._update_context_entering_cte(node, context)
263
- elif isinstance(node, exp.Subquery):
264
- context.subquery_depth += 1
265
- context.in_subquery = True
266
- elif isinstance(node, exp.Select):
267
- self._update_context_entering_select(node, context)
268
- elif isinstance(node, exp.Join):
269
- context.in_join_condition = True
270
-
271
- def _update_context_entering_func(self, node: Func, context: ParameterizationContext) -> None:
272
- """Update context when entering a function node."""
273
- context.function_depth += 1
274
- # Get function name from class name or node.name
275
- func_name = node.__class__.__name__.upper()
276
- if func_name in self.preserve_in_functions or (node.name and node.name.upper() in self.preserve_in_functions):
277
- context.in_function_args = True
278
-
279
- def _update_context_entering_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
280
- """Update context when entering a CTE node."""
281
- context.cte_depth += 1
282
- # Check if this CTE is recursive:
283
- # 1. Parent WITH must be RECURSIVE
284
- # 2. CTE must contain UNION (characteristic of recursive CTEs)
285
- is_in_recursive_with = any(
286
- isinstance(parent, exp.With) and parent.args.get("recursive", False)
287
- for parent in reversed(context.parent_stack)
288
- )
289
- if is_in_recursive_with and self._contains_union(node):
290
- context.in_recursive_cte = True
291
-
292
- def _update_context_entering_select(self, node: exp.Select, context: ParameterizationContext) -> None:
293
- """Update context when entering a SELECT node."""
294
- # Only track nested SELECT statements as subqueries if they're not part of a recursive CTE
295
- is_in_recursive_cte = any(
296
- isinstance(parent, exp.CTE)
297
- and any(
298
- isinstance(grandparent, exp.With) and grandparent.args.get("recursive", False)
299
- for grandparent in context.parent_stack
300
- )
301
- for parent in context.parent_stack[:-1]
302
- )
303
-
304
- if not is_in_recursive_cte and any(
305
- isinstance(parent, (exp.Select, exp.Subquery, exp.CTE))
306
- for parent in context.parent_stack[:-1] # Exclude the current node
307
- ):
308
- context.subquery_depth += 1
309
- context.in_subquery = True
310
- # Check if we're in a SELECT clause expressions list
311
- if hasattr(node, "expressions"):
312
- # We'll handle this specifically when processing individual expressions
313
- context.in_select_list = False # Will be detected by _is_in_select_expressions
314
-
315
- def _update_context_leaving(self, node: exp.Expression, context: ParameterizationContext) -> None:
316
- """Update context when leaving a node."""
317
- if context.parent_stack:
318
- context.parent_stack.pop()
319
-
320
- if isinstance(node, Func):
321
- self._update_context_leaving_func(node, context)
322
- elif isinstance(node, exp.Case):
323
- context.in_case_when = False
324
- elif isinstance(node, Array):
325
- context.in_array = False
326
- elif isinstance(node, exp.In):
327
- context.in_in_clause = False
328
- elif isinstance(node, exp.CTE):
329
- self._update_context_leaving_cte(node, context)
330
- elif isinstance(node, exp.Subquery):
331
- self._update_context_leaving_subquery(node, context)
332
- elif isinstance(node, exp.Select):
333
- self._update_context_leaving_select(node, context)
334
- elif isinstance(node, exp.Join):
335
- context.in_join_condition = False
336
-
337
- def _update_context_leaving_func(self, node: Func, context: ParameterizationContext) -> None:
338
- """Update context when leaving a function node."""
339
- context.function_depth -= 1
340
- if context.function_depth == 0:
341
- context.in_function_args = False
342
-
343
- def _update_context_leaving_cte(self, node: exp.CTE, context: ParameterizationContext) -> None:
344
- """Update context when leaving a CTE node."""
345
- context.cte_depth -= 1
346
- if context.cte_depth == 0:
347
- context.in_recursive_cte = False
348
-
349
- def _update_context_leaving_subquery(self, node: exp.Subquery, context: ParameterizationContext) -> None:
350
- """Update context when leaving a subquery node."""
351
- context.subquery_depth -= 1
352
- if context.subquery_depth == 0:
353
- context.in_subquery = False
354
-
355
- def _update_context_leaving_select(self, node: exp.Select, context: ParameterizationContext) -> None:
356
- """Update context when leaving a SELECT node."""
357
- # Only decrement if this was a nested SELECT (not part of recursive CTE)
358
- is_in_recursive_cte = any(
359
- isinstance(parent, exp.CTE)
360
- and any(
361
- isinstance(grandparent, exp.With) and grandparent.args.get("recursive", False)
362
- for grandparent in context.parent_stack
363
- )
364
- for parent in context.parent_stack[:-1]
365
- )
366
-
367
- if not is_in_recursive_cte and any(
368
- isinstance(parent, (exp.Select, exp.Subquery, exp.CTE))
369
- for parent in context.parent_stack[:-1] # Exclude current node
370
- ):
371
- context.subquery_depth -= 1
372
- if context.subquery_depth == 0:
373
- context.in_subquery = False
374
- context.in_select_list = False
375
-
376
- def _process_literal_with_context(
377
- self, literal: exp.Expression, context: ParameterizationContext
378
- ) -> exp.Expression:
379
- """Process a literal with awareness of its AST context."""
380
- # Check if this literal should be preserved based on context
381
- if self._should_preserve_literal_in_context(literal, context):
382
- return literal
383
-
384
- # Use optimized extraction for single-pass processing
385
- value, type_hint, sqlglot_type, semantic_name = self._extract_literal_value_and_type_optimized(literal, context)
386
-
387
- # Create TypedParameter object
388
- from sqlspec.statement.parameters import TypedParameter
389
-
390
- typed_param = TypedParameter(
391
- value=value,
392
- sqlglot_type=sqlglot_type or exp.DataType.build("VARCHAR"), # Fallback type
393
- type_hint=type_hint,
394
- semantic_name=semantic_name,
395
- )
396
-
397
- # Always track extracted parameters for proper merging
398
- self.extracted_parameters.append(typed_param)
399
-
400
- # If we're reordering, also add to final params directly
401
- if self._is_reordering_needed:
402
- if self._preserve_dict_format and isinstance(self._final_params, dict):
403
- # For dict format, we need a key
404
- param_key = semantic_name or f"param_{len(self._final_params)}"
405
- self._final_params[param_key] = typed_param
406
- elif isinstance(self._final_params, list):
407
- self._final_params.append(typed_param)
408
- else:
409
- # Fallback - this shouldn't happen but handle gracefully
410
- if not hasattr(self, "_fallback_params"):
411
- self._fallback_params = []
412
- self._fallback_params.append(typed_param)
413
-
414
- self._parameter_metadata.append(
415
- {
416
- "index": len(self._final_params if self._is_reordering_needed else self.extracted_parameters) - 1,
417
- "type": type_hint,
418
- "semantic_name": semantic_name,
419
- "context": self._get_context_description(context),
420
- }
421
- )
422
-
423
- # Create appropriate placeholder
424
- return self._create_placeholder(hint=semantic_name)
425
-
426
- def _process_existing_placeholder(self, node: exp.Placeholder, context: ParameterizationContext) -> exp.Expression:
427
- """Process an existing placeholder when reordering parameters."""
428
- if self._original_params is None:
429
- return node
430
-
431
- if isinstance(self._original_params, (list, tuple)):
432
- self._handle_list_params_for_placeholder(node)
433
- elif isinstance(self._original_params, dict):
434
- self._handle_dict_params_for_placeholder(node)
435
- else:
436
- # Single value parameter
437
- self._handle_single_value_param_for_placeholder(node)
438
-
439
- return node
440
-
441
- def _handle_list_params_for_placeholder(self, node: exp.Placeholder) -> None:
442
- """Handle list/tuple parameters for placeholder."""
443
- if isinstance(self._original_params, (list, tuple)) and self._user_param_index < len(self._original_params):
444
- value = self._original_params[self._user_param_index]
445
- self._add_to_final_params(value, node)
446
- self._user_param_index += 1
447
- else:
448
- # More placeholders than user parameters
449
- self._add_to_final_params(None, node)
450
-
451
- def _handle_dict_params_for_placeholder(self, node: exp.Placeholder) -> None:
452
- """Handle dict parameters for placeholder."""
453
- if not isinstance(self._original_params, dict):
454
- self._add_to_final_params(None, node)
455
- return
456
-
457
- raw_placeholder_name = node.this if hasattr(node, "this") else None
458
- if not raw_placeholder_name:
459
- # Unnamed placeholder '?' with dict params is ambiguous
460
- self._add_to_final_params(None, node)
461
- return
462
-
463
- # FIX: Normalize the placeholder name by stripping leading sigils
464
- placeholder_name = raw_placeholder_name.lstrip(":@")
465
-
466
- # Debug logging
467
-
468
- if placeholder_name in self._original_params:
469
- # Direct match for placeholder name
470
- self._add_to_final_params(self._original_params[placeholder_name], node)
471
- elif placeholder_name.isdigit() and self._user_param_index == 0:
472
- # Oracle-style numeric parameters
473
- self._handle_oracle_numeric_params()
474
- self._user_param_index += 1
475
- elif placeholder_name.isdigit() and self._user_param_index > 0:
476
- # Already handled Oracle params
477
- pass
478
- elif self._user_param_index == 0 and len(self._original_params) > 0:
479
- # Single dict parameter case
480
- self._handle_single_dict_param()
481
- self._user_param_index += 1
482
- else:
483
- # No match found
484
- self._add_to_final_params(None, node)
485
-
486
- def _handle_single_value_param_for_placeholder(self, node: exp.Placeholder) -> None:
487
- """Handle single value parameter for placeholder."""
488
- if self._user_param_index == 0:
489
- self._add_to_final_params(self._original_params, node)
490
- self._user_param_index += 1
491
- else:
492
- self._add_to_final_params(None, node)
493
-
494
- def _handle_oracle_numeric_params(self) -> None:
495
- """Handle Oracle-style numeric parameters."""
496
- if not isinstance(self._original_params, dict):
497
- return
498
-
499
- if self._preserve_dict_format and isinstance(self._final_params, dict):
500
- for k, v in self._original_params.items():
501
- if k.isdigit():
502
- self._final_params[k] = v
503
- else:
504
- # Convert to positional list
505
- numeric_keys = [k for k in self._original_params if k.isdigit()]
506
- if numeric_keys:
507
- max_index = max(int(k) for k in numeric_keys)
508
- param_list = [None] * (max_index + 1)
509
- for k, v in self._original_params.items():
510
- if k.isdigit():
511
- param_list[int(k)] = v
512
- if isinstance(self._final_params, list):
513
- self._final_params.extend(param_list)
514
- elif isinstance(self._final_params, dict):
515
- for i, val in enumerate(param_list):
516
- self._final_params[str(i)] = val
517
-
518
- def _handle_single_dict_param(self) -> None:
519
- """Handle single dict parameter case."""
520
- if not isinstance(self._original_params, dict):
521
- return
522
-
523
- if self._preserve_dict_format and isinstance(self._final_params, dict):
524
- for k, v in self._original_params.items():
525
- self._final_params[k] = v
526
- elif isinstance(self._final_params, list):
527
- self._final_params.append(self._original_params)
528
- elif isinstance(self._final_params, dict):
529
- param_name = f"param_{len(self._final_params)}"
530
- self._final_params[param_name] = self._original_params
531
-
532
def _add_to_final_params(self, value: Any, node: exp.Placeholder) -> None:
    """Record ``value`` in the final parameters, honouring their container type.

    Args:
        value: The parameter value to record.
        node: The placeholder the value binds to. When dict format is preserved,
            its name is used as the key.
    """
    final = self._final_params
    if self._preserve_dict_format and isinstance(final, dict):
        # Every SQLGlot node exposes ``this``, so a ``hasattr`` check is always
        # true. For an anonymous placeholder (``?``) ``this`` is None, which would
        # make every anonymous value share the key ``None`` and overwrite the
        # previous one. Fall back to a generated name instead.
        placeholder_name = getattr(node, "this", None) or f"param_{self._user_param_index}"
        final[placeholder_name] = value
    elif isinstance(final, list):
        final.append(value)
    elif isinstance(final, dict):
        final[f"param_{len(final)}"] = value
542
-
543
def _process_existing_parameter(self, node: exp.Parameter, context: ParameterizationContext) -> exp.Expression:
    """Record the user value bound to an existing parameter node and return the node.

    A named parameter (e.g. BigQuery ``@name``) is first resolved by name against
    dict parameters. Otherwise the node is treated positionally (``$1``, ``$2``)
    and the lookup depends on the shape of ``self._original_params``:

    * ``None``: a None value is recorded;
    * list/tuple: delegated to ``_handle_list_params_for_parameter_node``;
    * dict: delegated to ``_handle_dict_params_for_parameter_node``;
    * any other single value: bound to the first position only; other positions
      record None.

    The node itself is never rewritten.
    """
    original = self._original_params

    name = self._extract_parameter_name(node)
    if name and isinstance(original, dict) and name in original:
        self._add_param_value_to_finals(original[name])
        return node

    index = self._extract_parameter_index(node)
    if original is None:
        self._add_none_to_final_params()
    elif isinstance(original, (list, tuple)):
        self._handle_list_params_for_parameter_node(index)
    elif isinstance(original, dict):
        self._handle_dict_params_for_parameter_node(index)
    elif index == 0:
        # A lone scalar parameter only binds to the first placeholder.
        self._add_param_value_to_finals(original)
    else:
        self._add_none_to_final_params()
    return node
570
-
571
@staticmethod
def _extract_parameter_name(node: exp.Parameter) -> Optional[str]:
    """Return the name carried by a parameter node, or None if there is none.

    For a ``Var`` payload (e.g. ``@min_value`` -> ``min_value``), and for any other
    payload that exposes a ``this`` attribute, ``str(payload.this)`` is returned.
    """
    if not hasattr(node, "this"):
        return None
    payload = node.this
    # Var and other named payloads both keep the name on their own ``this``.
    if isinstance(payload, exp.Var) or hasattr(payload, "this"):
        return str(payload.this)
    return None
582
-
583
@staticmethod
def _extract_parameter_index(node: exp.Parameter) -> Optional[int]:
    """Return the zero-based position of a numbered parameter (``$1`` -> ``0``).

    Returns None when the node does not wrap a ``Literal`` or when the literal's
    text cannot be converted with ``int``.
    """
    if not (hasattr(node, "this") and isinstance(node.this, Literal)):
        return None
    try:
        return int(node.this.this) - 1
    except (ValueError, TypeError):
        return None
592
-
593
- def _handle_list_params_for_parameter_node(self, param_index: Optional[int]) -> None:
594
- """Handle list/tuple parameters for Parameter node."""
595
- if (
596
- isinstance(self._original_params, (list, tuple))
597
- and param_index is not None
598
- and 0 <= param_index < len(self._original_params)
599
- ):
600
- # Use the parameter at the specified index
601
- self._add_param_value_to_finals(self._original_params[param_index])
602
- else:
603
- # More parameters than user provided
604
- self._add_none_to_final_params()
605
-
606
- def _handle_dict_params_for_parameter_node(self, param_index: Optional[int]) -> None:
607
- """Handle dict parameters for Parameter node."""
608
- if param_index is not None:
609
- self._handle_dict_param_with_index(param_index)
610
- else:
611
- self._add_none_to_final_params()
612
-
613
- def _handle_dict_param_with_index(self, param_index: int) -> None:
614
- """Handle dict parameter when we have an index."""
615
- if not isinstance(self._original_params, dict):
616
- self._add_none_to_final_params()
617
- return
618
-
619
- # Try param_N key first
620
- param_key = f"param_{param_index}"
621
- if param_key in self._original_params:
622
- self._add_dict_value_to_finals(param_key)
623
- return
624
-
625
- # Try direct numeric key (1-based)
626
- numeric_key = str(param_index + 1)
627
- if numeric_key in self._original_params:
628
- self._add_dict_value_to_finals(numeric_key)
629
- else:
630
- self._add_none_to_final_params()
631
-
632
- def _add_dict_value_to_finals(self, key: str) -> None:
633
- """Add a value from dict params to final params."""
634
- if isinstance(self._original_params, dict) and key in self._original_params:
635
- value = self._original_params[key]
636
- if isinstance(self._final_params, list):
637
- self._final_params.append(value)
638
- elif isinstance(self._final_params, dict):
639
- self._final_params[key] = value
640
-
641
- def _add_param_value_to_finals(self, value: Any) -> None:
642
- """Add a parameter value to final params."""
643
- if isinstance(self._final_params, list):
644
- self._final_params.append(value)
645
- elif isinstance(self._final_params, dict):
646
- param_name = f"param_{len(self._final_params)}"
647
- self._final_params[param_name] = value
648
-
649
- def _add_none_to_final_params(self) -> None:
650
- """Add None to final params."""
651
- if isinstance(self._final_params, list):
652
- self._final_params.append(None)
653
- elif isinstance(self._final_params, dict):
654
- param_name = f"param_{len(self._final_params)}"
655
- self._final_params[param_name] = None
656
-
657
def _process_postgresql_column_parameter(
    self, node: exp.Column, context: ParameterizationContext
) -> exp.Expression:
    """Record the user value for a PostgreSQL ``$N`` placeholder parsed as a column.

    The column node is returned unchanged; only ``self._final_params`` is updated.

    Binding rules:

    * no user parameters (``None``): nothing is recorded;
    * list/tuple parameters: when reordering is needed, values are consumed
      sequentially via ``self._user_param_index``; otherwise ``$N`` selects
      element ``N - 1``. A placeholder with no corresponding element records
      nothing, but a user-supplied ``None`` is still recorded;
    * dict parameters: the key ``param_<N-1>`` is tried first, then ``"N"``;
    * any other single value binds to ``$1`` only.

    Args:
        node: Column node whose name is the placeholder text (e.g. ``$2``).
        context: Current parameterization context (not consulted here).

    Returns:
        The unchanged ``node``.
    """
    column_name = str(node.this) if hasattr(node, "this") else ""
    param_index: Optional[int] = None
    if column_name.startswith("$") and column_name[1:].isdigit():
        try:
            param_index = int(column_name[1:]) - 1  # convert to 0-based index
        except (ValueError, TypeError):
            # str.isdigit() accepts characters (e.g. superscripts) that int() rejects.
            param_index = None

    original = self._original_params
    if original is None:
        return node

    if isinstance(original, (list, tuple)):
        # Sentinel distinguishing "no value for this placeholder" from a value
        # the user explicitly passed as None. Using None for both dropped the
        # user's None and shifted every later parameter one slot to the left.
        missing = object()
        param_value: Any = missing
        if self._is_reordering_needed:
            # Mixed placeholder styles: assign sequentially, ignoring the $N number.
            if self._user_param_index < len(original):
                param_value = original[self._user_param_index]
                self._user_param_index += 1
        elif param_index is not None and 0 <= param_index < len(original):
            param_value = original[param_index]
        if param_value is not missing:
            self._add_param_value_to_finals(param_value)
    elif isinstance(original, dict):
        if param_index is not None:
            for key in (f"param_{param_index}", str(param_index + 1)):
                if key in original:
                    self._add_dict_value_to_finals(key)
                    break
    elif param_index == 0:
        # A lone scalar parameter binds to $1 only.
        self._add_param_value_to_finals(original)

    return node
738
-
739
def _should_preserve_literal_in_context(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Decide whether ``literal`` stays inline instead of becoming a parameter.

    Rules are evaluated in order and the first match wins:

    1. NULL / boolean literals when ``preserve_null`` / ``preserve_boolean`` are set.
    2. Enum-like string constants compared against a column inside a subquery
       (see ``_is_scalar_lookup_pattern`` and ``_is_enum_like_literal``).
    3. Anything inside preserved function arguments.
    4. Literals in recursive CTEs, delegated to
       ``_should_preserve_literal_in_recursive_cte`` when
       ``preserve_in_recursive_cte`` is set.
    5. A literal that is the expression of an ``Alias`` with a SELECT in the
       parent stack (e.g. ``'computed' AS process_status``).
    6. Per-ancestor rules: schema/DDL nodes, numbers under LIMIT/OFFSET (when
       ``preserve_numbers_in_limit``), and CASE conditions.
    7. String literals longer than ``max_string_length``.

    Args:
        literal: The literal node under consideration.
        context: Current parameterization context.

    Returns:
        True to keep the literal inline, False to parameterize it.
    """
    if self.preserve_null and isinstance(literal, Null):
        return True
    if self.preserve_boolean and isinstance(literal, Boolean):
        return True

    # Enum-like constants in subquery lookups are kept inline. A non-enum literal
    # in the same position falls through to the remaining rules. Returning False
    # here would skip them, so e.g. an over-long string or a literal inside
    # preserved function arguments would be parameterized in a subquery but
    # preserved everywhere else.
    if (
        context.in_subquery
        and self._is_scalar_lookup_pattern(literal, context)
        and self._is_enum_like_literal(literal)
    ):
        return True

    if context.in_function_args:
        return True

    if self.preserve_in_recursive_cte and context.in_recursive_cte:
        return self._should_preserve_literal_in_recursive_cte(literal, context)

    parent = literal.parent if hasattr(literal, "parent") else None
    if parent and isinstance(parent, exp.Alias) and parent.this == literal:
        if any(isinstance(ancestor, exp.Select) for ancestor in context.parent_stack):
            return True

    for ancestor in context.parent_stack:
        if isinstance(ancestor, (DataType, exp.ColumnDef, exp.Create, exp.Schema)):
            return True
        if (
            self.preserve_numbers_in_limit
            and isinstance(ancestor, (exp.Limit, exp.Offset))
            and isinstance(literal, exp.Literal)
            and self._is_number_literal(literal)
        ):
            return True
        if isinstance(ancestor, exp.Case) and context.in_case_when:
            # Inside CASE ... WHEN: keep the literal only if it is not a direct
            # operand of a binary expression.
            return not isinstance(literal.parent, Binary)

    if isinstance(literal, exp.Literal) and self._is_string_literal(literal):
        if len(str(literal.this)) > self.max_string_length:
            return True

    return False
802
-
803
def _is_in_select_expressions(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Return True if ``literal`` lies in the projection list of the nearest SELECT.

    The parent stack is walked from innermost to outermost. Meeting a WHERE,
    HAVING or JOIN first means the literal is not a projection. A SELECT without
    expressions is skipped and the walk continues outward.
    """
    for ancestor in reversed(context.parent_stack):
        if isinstance(ancestor, (exp.Where, exp.Having, exp.Join)):
            return False
        if isinstance(ancestor, exp.Select) and getattr(ancestor, "expressions", None):
            return any(
                self._literal_is_in_expression_tree(literal, projection)
                for projection in ancestor.expressions
            )
    return False
812
-
813
def _is_recursive_computation(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Return True if ``literal`` feeds arithmetic inside a SELECT projection.

    If the parent stack contains an arithmetic node (``+``, ``-``, ``*``, ``/``),
    the result is whether the literal sits in the SELECT projection list (see
    ``_is_in_select_expressions``). Without an arithmetic ancestor it is False.
    """
    # SQLGlot's ``Expression.key`` is the lower-case class name ("add", "sub", ...),
    # so comparing it with upper-case strings never matched and this check was
    # always False. Test the node types directly instead.
    arithmetic_nodes = (exp.Add, exp.Sub, exp.Mul, exp.Div)
    if any(isinstance(ancestor, arithmetic_nodes) for ancestor in context.parent_stack):
        return self._is_in_select_expressions(literal, context)
    return False
821
-
822
def _should_preserve_literal_in_recursive_cte(
    self, literal: exp.Expression, context: ParameterizationContext
) -> bool:
    """Decide preservation of a literal inside a recursive CTE by its role.

    The literal is kept inline if it appears in a SELECT projection (so the
    column's type is inferred from it), or if it takes part in arithmetic in a
    projection (see ``_is_recursive_computation``).
    """
    return self._is_in_select_expressions(literal, context) or self._is_recursive_computation(
        literal, context
    )
832
-
833
def _literal_is_in_expression_tree(self, target_literal: exp.Expression, expr: exp.Expression) -> bool:
    """Return True if ``target_literal`` is ``expr`` itself or any node beneath it.

    Nodes are compared with ``==``, consistent with the other node comparisons in
    this class.
    """
    # Walk the whole subtree, not only the direct children. Checking one level
    # missed nested literals such as the ``1`` in ``SELECT n + 1 AS next_n``
    # (Alias -> Add -> Literal). An explicit stack avoids recursion-depth limits.
    pending = [expr]
    while pending:
        node = pending.pop()
        if node == target_literal:
            return True
        pending.extend(node.iter_expressions())
    return False
839
-
840
def _is_scalar_lookup_pattern(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Return True if ``literal`` is compared directly against a column in a subquery WHERE.

    Matches ``WHERE column = 'literal'`` and ``WHERE 'literal' = column``, where the
    WHERE condition itself is the binary node. The parent stack is scanned from
    innermost outward. A WHERE whose condition is not such a binary node is
    skipped. Outside subqueries (``subquery_depth == 0``) the result is False.
    """
    if context.subquery_depth == 0:
        return False

    for ancestor in reversed(context.parent_stack):
        if not isinstance(ancestor, exp.Where):
            continue
        condition = ancestor.this
        if not isinstance(condition, exp.Binary):
            continue
        if condition.right == literal:
            return isinstance(condition.left, exp.Column)
        if condition.left == literal:
            return isinstance(condition.right, exp.Column)
    return False
857
-
858
def _is_enum_like_literal(self, literal: exp.Expression) -> bool:
    """Return True if ``literal`` looks like an enum/identifier constant.

    A string literal qualifies when its length is greater than ``MIN_ENUM_LENGTH``
    and at most ``MAX_ENUM_LENGTH``, it contains only alphanumerics and
    underscores, and it is not purely digits.
    """
    if not (isinstance(literal, exp.Literal) and self._is_string_literal(literal)):
        return False

    text = str(literal.this)
    if not (MIN_ENUM_LENGTH < len(text) <= MAX_ENUM_LENGTH):
        return False
    return text.replace("_", "").isalnum() and not text.isdigit()
872
-
873
def _extract_literal_value_and_type(self, literal: exp.Expression) -> tuple[Any, str]:
    """Extract the Python value and a type hint from a SQLGlot literal.

    Args:
        literal: The node being parameterized (normally ``Literal``, ``Boolean``
            or ``Null``).

    Returns:
        ``(value, type_hint)``. ``type_hint`` is one of:

        * ``"null"``, ``"boolean"``, ``"string"``;
        * ``"integer"``, ``"bigint"``, ``"float"``;
        * ``"decimal"``: the value is kept as its source text;
        * ``"numeric_string"``: number-like text that failed conversion;
        * ``"unknown"``.

        Other non-``Literal`` nodes are returned as ``(str(node), "string")``.
    """
    if isinstance(literal, Null) or literal.this is None:
        return None, "null"

    # ``Boolean`` is not a ``Literal`` subclass, so it must be recognized before the
    # non-Literal fallback. Otherwise TRUE/FALSE were returned as their SQL text
    # with a "string" hint instead of a Python bool.
    if isinstance(literal, Boolean):
        return literal.this, "boolean"

    if not isinstance(literal, exp.Literal):
        return str(literal), "string"

    if isinstance(literal.this, bool):
        return literal.this, "boolean"

    if self._is_string_literal(literal):
        return str(literal.this), "string"

    if self._is_number_literal(literal):
        value_str = str(literal.this)

        if not self.type_preservation:
            # Simple conversion: anything with a decimal point is a float.
            try:
                if "." in value_str:
                    return float(literal.this), "float"
                return int(literal.this), "integer"
            except ValueError:
                return value_str, "numeric_string"

        if "." in value_str or "e" in value_str.lower():
            try:
                decimal_places = len(value_str.split(".")[1]) if "." in value_str else 0
                if decimal_places > MAX_DECIMAL_PRECISION:
                    # More fractional digits than MAX_DECIMAL_PRECISION: keep the
                    # exact source text rather than converting to float.
                    return value_str, "decimal"
                return float(literal.this), "float"
            except (ValueError, IndexError):
                return value_str, "numeric_string"

        try:
            value = int(literal.this)
        except ValueError:
            return value_str, "numeric_string"
        if abs(value) > MAX_INT32_VALUE:
            return value, "bigint"
        return value, "integer"

    # Date/time values arrive as string literals and are handled above. This
    # fallback covers anything else; the database performs type conversion.
    return str(literal.this), "unknown"
926
-
927
def _extract_literal_value_and_type_optimized(
    self, literal: exp.Expression, context: ParameterizationContext
) -> "tuple[Any, str, Optional[exp.DataType], Optional[str]]":
    """Gather everything needed to parameterize ``literal`` in a single call.

    Combines ``_extract_literal_value_and_type``, ``_infer_sqlglot_type`` and
    ``_generate_semantic_name_from_context``. The SQL type comes from the type
    hint, so the literal is never rendered with ``literal.sql()``.

    Args:
        literal: The literal expression being parameterized.
        context: Current parameterization context, including its parent stack.

    Returns:
        ``(value, type_hint, sqlglot_type, semantic_name)``.
    """
    value, type_hint = self._extract_literal_value_and_type(literal)
    return (
        value,
        type_hint,
        self._infer_sqlglot_type(type_hint, value),
        self._generate_semantic_name_from_context(literal, context),
    )
952
-
953
@staticmethod
def _infer_sqlglot_type(type_hint: str, value: Any) -> "Optional[exp.DataType]":
    """Map a simple type hint to a SQLGlot ``DataType`` without parsing SQL.

    Sizing rules:

    * ``decimal`` hints with a string value get ``DECIMAL(precision, scale)``.
      ``scale`` is the character count after the first ``.``; ``precision`` is
      the character count before it plus ``scale``.
    * Non-empty ``string`` values get ``VARCHAR(len(value))``.

    Unrecognized hints map to ``VARCHAR``.

    Args:
        type_hint: Hint produced by ``_extract_literal_value_and_type``.
        value: The extracted Python value, used to size DECIMAL/VARCHAR.

    Returns:
        The built ``DataType``.
    """
    sql_type_by_hint = {
        "null": "NULL",
        "boolean": "BOOLEAN",
        "integer": "INT",
        "bigint": "BIGINT",
        "float": "FLOAT",
        "decimal": "DECIMAL",
        "string": "VARCHAR",
        "numeric_string": "VARCHAR",
        "unknown": "VARCHAR",
    }
    type_name = sql_type_by_hint.get(type_hint, "VARCHAR")

    if type_hint == "decimal" and isinstance(value, str):
        pieces = value.split(".")
        scale = len(pieces[1]) if len(pieces) > 1 else 0
        precision = len(pieces[0]) + scale
        return exp.DataType.build(
            type_name, expressions=[exp.Literal.number(precision), exp.Literal.number(scale)]
        )

    if type_hint == "string" and isinstance(value, str) and value:
        return exp.DataType.build(type_name, expressions=[exp.Literal.number(len(value))])

    return exp.DataType.build(type_name)
993
-
994
@staticmethod
def _generate_semantic_name_from_context(
    literal: exp.Expression, context: ParameterizationContext
) -> "Optional[str]":
    """Derive a readable parameter name from the literal's position in the AST.

    The first pass walks the parent stack from innermost outward:

    * a binary node comparing the literal with a column yields the column name;
    * an ``IN`` whose left side is a column yields ``<column>_value``.

    If nothing matches, a second pass returns the nearest clause label:
    ``where_value``, ``having_value``, ``join_value`` or ``select_value``.

    Args:
        literal: The literal being parameterized.
        context: Current context with its parent stack.

    Returns:
        The derived name, or None when no rule applies.
    """
    for ancestor in reversed(context.parent_stack):
        if isinstance(ancestor, Binary):
            if ancestor.left == literal and isinstance(ancestor.right, exp.Column):
                return ancestor.right.name
            if ancestor.right == literal and isinstance(ancestor.left, exp.Column):
                return ancestor.left.name
        elif isinstance(ancestor, exp.In):
            column = ancestor.this
            if column and isinstance(column, exp.Column):
                return f"{column.name}_value"

    clause_labels = (
        (exp.Where, "where_value"),
        (exp.Having, "having_value"),
        (exp.Join, "join_value"),
        (exp.Select, "select_value"),
    )
    for ancestor in reversed(context.parent_stack):
        for node_type, label in clause_labels:
            if isinstance(ancestor, node_type):
                return label

    return None
1032
-
1033
def _is_string_literal(self, literal: exp.Literal) -> bool:
    """Return True if ``literal`` should be treated as a string.

    A literal whose ``is_string`` attribute is truthy qualifies directly.
    Otherwise a ``str`` payload qualifies only if ``_is_number_literal`` rejects it.
    """
    if getattr(literal, "is_string", False):
        return True
    return isinstance(literal.this, str) and not self._is_number_literal(literal)
1039
-
1040
@staticmethod
def _is_number_literal(literal: exp.Literal) -> bool:
    """Return True if ``literal`` represents a number.

    A literal whose ``is_number`` attribute is truthy qualifies directly.
    Otherwise the payload qualifies if ``float(str(payload))`` succeeds. A None
    payload never qualifies.
    """
    if getattr(literal, "is_number", False):
        return True

    raw = literal.this
    if raw is None:
        return False
    try:
        float(str(raw))
    except (ValueError, TypeError):
        return False
    return True
1054
-
1055
def _create_placeholder(self, hint: Optional[str] = None) -> exp.Expression:
    """Build the next placeholder node for ``self.placeholder_style``.

    The counter value is read before it is incremented:

    * named styles use ``param_<counter>`` (0-based);
    * numeric ``$N`` styles use ``counter + 1``, because PostgreSQL numbering is
      1-based;
    * only the ``":name"`` style uses ``hint`` as the name prefix.

    Pyformat, and any unrecognized style, produce a plain ``?`` placeholder.
    ``self._parameter_counter`` is incremented by one on every call.

    Args:
        hint: Optional name prefix (used only by the ``":name"`` style).

    Returns:
        The placeholder expression.
    """
    style = self.placeholder_style
    counter = self._parameter_counter

    placeholder: exp.Expression
    if style in {"?", ParameterStyle.QMARK, "qmark"}:
        placeholder = exp.Placeholder()
    elif style == ":name":
        placeholder = exp.Placeholder(this=f"{hint or 'param'}_{counter}")
    elif style in {ParameterStyle.NAMED_COLON, "named_colon"} or style.startswith(":"):
        placeholder = exp.Placeholder(this=f"param_{counter}")
    elif style in {ParameterStyle.NUMERIC, "numeric"} or style.startswith("$"):
        # A Var renders verbatim, giving a consistent ``$N`` format.
        placeholder = exp.Var(this=f"${counter + 1}")
    elif style in {ParameterStyle.NAMED_AT, "named_at"}:
        # The ``@`` prefix is added at SQL generation time, not stored in the name.
        placeholder = exp.Placeholder(this=f"param_{counter}")
    else:
        # Includes pyformat: emit ``?`` and let compilation convert it later.
        placeholder = exp.Placeholder()

    self._parameter_counter = counter + 1
    return placeholder
1091
-
1092
def _process_array(self, array_node: Array, context: ParameterizationContext) -> exp.Expression:
    """Parameterize an array literal.

    * Empty arrays, and arrays longer than ``max_array_length``, are returned
      unchanged.
    * If every element is a ``Literal``, the whole array becomes ONE
      ``TypedParameter``. Its type is ``ARRAY`` of the first element's inferred
      type. A metadata entry is recorded and a placeholder replaces the array.
    * Otherwise each element is processed individually and the array node is
      updated in place.

    Args:
        array_node: The array expression.
        context: Current parameterization context.

    Returns:
        A placeholder, or the (possibly updated) array node.
    """
    elements = array_node.expressions
    if not elements or len(elements) > self.max_array_length:
        return array_node

    if not all(isinstance(element, Literal) for element in elements):
        array_node.set(
            "expressions",
            [
                self._process_literal_with_context(element, context)
                if isinstance(element, Literal)
                else self._transform_with_context(element, context)
                for element in elements
            ],
        )
        return array_node

    extracted = [self._extract_literal_value_and_type(element) for element in elements]
    array_values = [value for value, _ in extracted]
    element_type = extracted[0][1]
    element_sqlglot_type = self._infer_sqlglot_type(element_type, array_values[0])
    array_type_hint = f"array<{element_type}>"

    self.extracted_parameters.append(
        TypedParameter(
            value=array_values,
            sqlglot_type=exp.DataType.build("ARRAY", expressions=[element_sqlglot_type]),
            type_hint=array_type_hint,
            semantic_name="array_values",
        )
    )
    self._parameter_metadata.append(
        {
            "index": len(self.extracted_parameters) - 1,
            "type": array_type_hint,
            "length": len(array_values),
            "context": "array_literal",
        }
    )
    return self._create_placeholder("array")
1153
-
1154
def _process_in_clause(self, in_node: exp.In, context: ParameterizationContext) -> exp.Expression:
    """Parameterize the value list of an IN predicate.

    * ``x IN (SELECT ...)``: the subquery is transformed recursively; nothing is
      parameterized at this level.
    * Empty lists, lists longer than ``max_in_list_size``, and lists without any
      bare ``Literal`` item are left untouched.
    * Otherwise literal items go through ``_process_literal_with_context`` and
      every other item through ``_transform_with_context``.

    The node is updated in place and returned.
    """
    query = in_node.args.get("query")
    if query:
        in_node.set("query", self._transform_with_context(query, context))
        return in_node

    items = in_node.args.get("expressions")
    if not items or len(items) > self.max_in_list_size:
        return in_node

    if not any(isinstance(item, Literal) for item in items):
        return in_node

    in_node.set(
        "expressions",
        [
            self._process_literal_with_context(item, context)
            if isinstance(item, Literal)
            else self._transform_with_context(item, context)
            for item in items
        ],
    )
    return in_node
1188
-
1189
def _get_context_description(self, context: ParameterizationContext) -> str:
    """Describe the current parameterization context as an underscore-joined label.

    Active context flags contribute, in order: ``function_args``, ``case_when``,
    ``array`` and ``in_clause``. If none are set, the nearest SELECT / WHERE / JOIN
    in the parent stack supplies the label. Returns ``"general"`` when nothing
    applies.
    """
    flags = (
        (context.in_function_args, "function_args"),
        (context.in_case_when, "case_when"),
        (context.in_array, "array"),
        (context.in_in_clause, "in_clause"),
    )
    parts = [label for active, label in flags if active]

    if not parts:
        clause_labels = ((exp.Select, "select"), (exp.Where, "where"), (exp.Join, "join"))
        for ancestor in reversed(context.parent_stack):
            label = next(
                (name for node_type, name in clause_labels if isinstance(ancestor, node_type)),
                None,
            )
            if label:
                parts.append(label)
                break

    return "_".join(parts) or "general"
1216
-
1217
def get_parameters(self) -> list[Any]:
    """Return the parameters extracted by the most recent ``process()`` call.

    Returns:
        A shallow copy, so callers may modify it without affecting internal state.
    """
    return list(self.extracted_parameters)
1224
-
1225
def get_parameter_metadata(self) -> list[dict[str, Any]]:
    """Return metadata recorded for the extracted parameters.

    Returns:
        A shallow copy of the metadata list. The dicts inside are shared with
        internal state.
    """
    return list(self._parameter_metadata)
1232
-
1233
def _contains_union(self, cte_node: exp.CTE) -> bool:
    """Return True if the CTE's body contains a UNION node anywhere in its tree.

    A UNION is characteristic of recursive CTEs, which combine an anchor member
    with a recursive member.

    Returns:
        Always a real ``bool``. False when the CTE has no body; previously the
        falsy body itself (e.g. None) was returned.
    """
    body = cte_node.this
    if not body:
        return False

    # Iterative depth-first search: deep expression trees cannot exhaust the
    # Python recursion limit.
    pending = [body]
    while pending:
        node = pending.pop()
        if isinstance(node, exp.Union):
            return True
        pending.extend(node.iter_expressions())
    return False
1242
-
1243
- def clear_parameters(self) -> None:
1244
- """Clear the extracted parameters list."""
1245
- self.extracted_parameters = []
1246
- self._parameter_counter = 0
1247
- self._parameter_metadata = []