sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -644
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -462
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +217 -451
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +418 -498
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +592 -634
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +393 -436
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +549 -942
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -550
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +741 -0
  31. sqlspec/adapters/psycopg/driver.py +732 -733
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +243 -426
  35. sqlspec/base.py +220 -825
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
  137. sqlspec-0.12.0.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -330
  150. sqlspec/mixins.py +0 -306
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.0.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,623 @@
1
+ """Replaces literals in SQL with placeholders and extracts them using SQLGlot AST."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Optional
5
+
6
+ from sqlglot import exp
7
+ from sqlglot.expressions import Array, Binary, Boolean, DataType, Func, Literal, Null
8
+
9
+ from sqlspec.statement.parameters import ParameterStyle
10
+ from sqlspec.statement.pipelines.base import ProcessorProtocol
11
+ from sqlspec.statement.pipelines.context import SQLProcessingContext
12
+
13
+ __all__ = ("ParameterizationContext", "ParameterizeLiterals")  # Public names exported by this module.
14
+
15
# Tunables for literal parameterization.
#
# Numeric literals with more fractional digits than this are kept as exact
# strings and tagged "decimal" instead of being converted to float.
MAX_DECIMAL_PRECISION = 6

# Upper bound of a signed 32-bit integer; integer literals beyond the 32-bit
# range are tagged "bigint" instead of "integer".
MAX_INT32_VALUE = 2**31 - 1

# Default for ``max_string_length``: string literals longer than this are left
# inline in the SQL rather than being turned into parameters.
DEFAULT_MAX_STRING_LENGTH = 1000

# Default for ``max_array_length``: arrays with more elements than this are left
# inline in the SQL.
DEFAULT_MAX_ARRAY_LENGTH = 100

# Default for ``max_in_list_size``: IN lists with more entries than this are left
# inline in the SQL.
DEFAULT_MAX_IN_LIST_SIZE = 50
26
+
27
+
28
@dataclass
class ParameterizationContext:
    """Mutable state carried through the AST walk performed by ``ParameterizeLiterals``.

    Fields are declared in constructor order; only ``parent_stack`` is required,
    every flag/counter starts out cleared.
    """

    # Nodes pushed on entry and popped on exit: the path from the root of the
    # walk down to the node currently being visited.
    parent_stack: "list[exp.Expression]"
    # Set when a function named in ``preserve_in_functions`` is entered.
    in_function_args: bool = False
    # Flag for being inside a CASE expression.
    in_case_when: bool = False
    # Flag for being inside an ARRAY literal.
    in_array: bool = False
    # Flag for being inside an IN predicate.
    in_in_clause: bool = False
    # Count of function nodes currently entered.
    function_depth: int = 0
38
+
39
+
40
class ParameterizeLiterals(ProcessorProtocol):
    """Advanced literal parameterization using SQLGlot AST analysis.

    This enhanced version provides:
    - Context-aware parameterization based on AST position
    - Smart handling of arrays, IN clauses, and function arguments
    - Type-preserving parameter extraction
    - Configurable parameterization strategies
    - Performance optimization for query plan caching

    Each ``process()`` call walks a copy of ``context.current_expression``,
    replaces eligible literals with placeholders, and appends the extracted values
    (as ``TypedParameter`` objects) to ``context.extracted_parameters_from_pipeline``.

    Args:
        placeholder_style: Style of placeholder to use ("?", ":name", "$1", etc.).
        preserve_null: Whether to preserve NULL literals as-is.
        preserve_boolean: Whether to preserve boolean literals as-is.
        preserve_numbers_in_limit: Whether to preserve numbers in LIMIT/OFFSET clauses.
        preserve_in_functions: Function names (case-insensitive) whose arguments keep
            their literals. Defaults to COALESCE, IFNULL, NVL and ISNULL.
        parameterize_arrays: Whether to parameterize array literals.
        parameterize_in_lists: Whether to parameterize IN clause lists.
        max_string_length: String literals longer than this are left inline.
        max_array_length: Arrays with more elements than this are left inline.
        max_in_list_size: IN lists with more entries than this are left inline.
        type_preservation: Whether to preserve exact literal types.
    """

    def __init__(
        self,
        placeholder_style: str = "?",
        preserve_null: bool = True,
        preserve_boolean: bool = True,
        preserve_numbers_in_limit: bool = True,
        preserve_in_functions: Optional[list[str]] = None,
        parameterize_arrays: bool = True,
        parameterize_in_lists: bool = True,
        max_string_length: int = DEFAULT_MAX_STRING_LENGTH,
        max_array_length: int = DEFAULT_MAX_ARRAY_LENGTH,
        max_in_list_size: int = DEFAULT_MAX_IN_LIST_SIZE,
        type_preservation: bool = True,
    ) -> None:
        self.placeholder_style = placeholder_style
        self.preserve_null = preserve_null
        self.preserve_boolean = preserve_boolean
        self.preserve_numbers_in_limit = preserve_numbers_in_limit
        # Names are compared against upper-cased function names from the AST, so
        # normalize them here; otherwise e.g. ["coalesce"] would never match.
        self.preserve_in_functions = [
            name.upper() for name in (preserve_in_functions or ["COALESCE", "IFNULL", "NVL", "ISNULL"])
        ]
        self.parameterize_arrays = parameterize_arrays
        self.parameterize_in_lists = parameterize_in_lists
        self.max_string_length = max_string_length
        self.max_array_length = max_array_length
        self.max_in_list_size = max_in_list_size
        self.type_preservation = type_preservation
        # Per-run state; reset at the start of every process() call.
        self.extracted_parameters: list[Any] = []
        self._parameter_counter = 0
        self._parameter_metadata: list[dict[str, Any]] = []  # Track parameter types and context

    def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
        """Parameterize literals in ``context.current_expression``.

        Nothing is done (``expression`` is returned unchanged) when there is no
        expression or the input SQL already contained placeholders. Otherwise a copy
        of the current expression is transformed and stored back on the context, the
        extracted parameters are appended to
        ``context.extracted_parameters_from_pipeline`` and per-parameter metadata is
        stored under ``context.metadata["parameter_metadata"]``.

        Returns:
            The transformed expression, or ``expression`` when processing is skipped.
        """
        if expression is None or context.current_expression is None or context.config.input_sql_had_placeholders:
            return expression

        self.clear_parameters()

        param_context = ParameterizationContext(parent_stack=[])
        transformed_expression = self._transform_with_context(context.current_expression.copy(), param_context)
        context.current_expression = transformed_expression
        context.extracted_parameters_from_pipeline.extend(self.extracted_parameters)

        context.metadata["parameter_metadata"] = self._parameter_metadata

        return transformed_expression

    def _transform_with_context(self, node: exp.Expression, context: ParameterizationContext) -> exp.Expression:
        """Transform ``node`` (and its subtree) while keeping ``context`` up to date."""
        self._update_context(node, context, entering=True)

        if isinstance(node, (Literal, Boolean, Null)):
            # Boolean and Null are not Literal subclasses, so they are listed explicitly.
            result = self._process_literal_with_context(node, context)
        elif isinstance(node, Array) and self.parameterize_arrays:
            result = self._process_array(node, context)
        elif isinstance(node, exp.In) and self.parameterize_in_lists:
            result = self._process_in_clause(node, context)
        else:
            # Recursively process children, replacing each with its transformed form.
            for key, value in node.args.items():
                if isinstance(value, exp.Expression):
                    node.set(key, self._transform_with_context(value, context))
                elif isinstance(value, list):
                    node.set(
                        key,
                        [
                            self._transform_with_context(v, context) if isinstance(v, exp.Expression) else v
                            for v in value
                        ],
                    )
            result = node

        self._update_context(node, context, entering=False)

        return result

    def _is_preserved_function(self, node: exp.Expression) -> bool:
        """Return True if ``node`` is a function whose arguments keep their literals.

        Known functions are matched by expression class name (e.g. ``Coalesce``).
        Only ``exp.Anonymous`` functions are matched by ``node.name``: for other
        functions ``node.name`` is the text of the first *argument*, so
        ``UPPER('nvl')`` must not count as NVL.
        """
        if not isinstance(node, Func):
            return False
        if node.__class__.__name__.upper() in self.preserve_in_functions:
            return True
        return isinstance(node, exp.Anonymous) and bool(node.name) and node.name.upper() in self.preserve_in_functions

    def _update_context(self, node: exp.Expression, context: ParameterizationContext, entering: bool) -> None:
        """Push/pop ``node`` on the parent stack and maintain the context flags.

        On exit, each flag is recomputed from the remaining ancestors rather than
        cleared. Leaving a nested node therefore does not reset the state of an
        enclosing one, for example leaving COALESCE inside CONCAT, or an inner IN
        inside an outer IN.

        NOTE(review): in sqlglot ``exp.Case`` and ``exp.Array`` appear to subclass
        ``Func``; if so, the CASE/ARRAY branches below are never reached and
        ``in_case_when`` / ``in_array`` stay False. The dispatch order is kept
        unchanged so that which literals get parameterized does not change. Confirm
        before relying on those flags.
        """
        if entering:
            context.parent_stack.append(node)

            if isinstance(node, Func):
                context.function_depth += 1
                if self._is_preserved_function(node):
                    context.in_function_args = True
            elif isinstance(node, exp.Case):
                context.in_case_when = True
            elif isinstance(node, Array):
                context.in_array = True
            elif isinstance(node, exp.In):
                context.in_in_clause = True
            return

        if context.parent_stack:
            context.parent_stack.pop()
        ancestors = context.parent_stack

        if isinstance(node, Func):
            context.function_depth -= 1
            context.in_function_args = any(self._is_preserved_function(parent) for parent in ancestors)
        elif isinstance(node, exp.Case):
            context.in_case_when = any(isinstance(parent, exp.Case) for parent in ancestors)
        elif isinstance(node, Array):
            context.in_array = any(isinstance(parent, Array) for parent in ancestors)
        elif isinstance(node, exp.In):
            context.in_in_clause = any(isinstance(parent, exp.In) for parent in ancestors)

    def _process_literal_with_context(
        self, literal: exp.Expression, context: ParameterizationContext
    ) -> exp.Expression:
        """Replace ``literal`` with a placeholder unless its context says to keep it.

        Returns the literal itself when preserved. Otherwise it records a
        ``TypedParameter`` plus metadata and returns the new placeholder.
        """
        if self._should_preserve_literal_in_context(literal, context):
            return literal

        # Use optimized extraction for single-pass processing
        value, type_hint, sqlglot_type, semantic_name = self._extract_literal_value_and_type_optimized(literal, context)

        # Local import kept as in the original; presumably avoids an import cycle.
        from sqlspec.statement.parameters import TypedParameter

        typed_param = TypedParameter(
            value=value,
            sqlglot_type=sqlglot_type or exp.DataType.build("VARCHAR"),  # Fallback type
            type_hint=type_hint,
            semantic_name=semantic_name,
        )

        self.extracted_parameters.append(typed_param)
        self._parameter_metadata.append(
            {
                "index": len(self.extracted_parameters) - 1,
                "type": type_hint,
                "semantic_name": semantic_name,
                "context": self._get_context_description(context),
                # Note: We avoid calling literal.sql() for performance
            }
        )

        return self._create_placeholder(hint=semantic_name)

    def _should_preserve_literal_in_context(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
        """Return True if ``literal`` must stay inline in the SQL."""
        if self.preserve_null and isinstance(literal, Null):
            return True

        if self.preserve_boolean and isinstance(literal, Boolean):
            return True

        if context.in_function_args:
            return True

        is_literal = isinstance(literal, exp.Literal)
        for parent in context.parent_stack:
            # Preserve in schema/DDL contexts
            if isinstance(parent, (DataType, exp.ColumnDef, exp.Create, exp.Schema)):
                return True

            # Preserve numbers in LIMIT/OFFSET
            if (
                self.preserve_numbers_in_limit
                and isinstance(parent, (exp.Limit, exp.Offset))
                and is_literal
                and self._is_number_literal(literal)
            ):
                return True

            # Inside CASE, keep non-comparison literals (THEN/ELSE values). Comparison
            # operands fall through to the remaining checks instead of returning False
            # early, which used to bypass the string-length rule below.
            if isinstance(parent, exp.Case) and context.in_case_when and not isinstance(literal.parent, Binary):
                return True

        # Over-long strings are kept inline.
        return bool(
            is_literal and self._is_string_literal(literal) and len(str(literal.this)) > self.max_string_length
        )

    def _extract_literal_value_and_type(self, literal: exp.Expression) -> tuple[Any, str]:
        """Extract the Python value and a simple type hint from a SQLGlot literal.

        Returns:
            ``(value, type_hint)`` where ``type_hint`` is one of ``"null"``,
            ``"boolean"``, ``"string"``, ``"integer"``, ``"bigint"``, ``"float"``,
            ``"decimal"``, ``"numeric_string"`` or ``"unknown"``.
        """
        if isinstance(literal, Null) or literal.this is None:
            return None, "null"

        # Boolean is not an exp.Literal subclass. It must be handled before the
        # non-Literal fallback below, which would otherwise yield the string "TRUE".
        if isinstance(literal, Boolean):
            return bool(literal.this), "boolean"

        if not isinstance(literal, exp.Literal):
            return str(literal), "string"

        if isinstance(literal.this, bool):
            return literal.this, "boolean"

        if self._is_string_literal(literal):
            return str(literal.this), "string"

        if self._is_number_literal(literal):
            return self._extract_number(str(literal.this))

        # Date/time values typically arrive as plain string literals; they are
        # returned as strings and the database handles type conversion.
        return str(literal.this), "unknown"

    def _extract_number(self, text: str) -> tuple[Any, str]:
        """Convert numeric literal text into ``(value, type_hint)``.

        With ``type_preservation``:
        - Plain decimals with more than ``MAX_DECIMAL_PRECISION`` fractional digits
          stay exact strings tagged ``"decimal"``. An exponent never counts as
          fractional digits.
        - Other fractional or exponent forms become floats.
        - Integers outside the signed 32-bit range are tagged ``"bigint"``.

        Unparseable text is returned as ``"numeric_string"``.
        """
        has_exponent = "e" in text.lower()
        is_float_like = "." in text or has_exponent

        if not self.type_preservation:
            try:
                return (float(text), "float") if is_float_like else (int(text), "integer")
            except ValueError:
                return text, "numeric_string"

        if is_float_like:
            fraction = text.partition(".")[2]
            if not has_exponent and len(fraction) > MAX_DECIMAL_PRECISION:
                return text, "decimal"
            try:
                return float(text), "float"
            except ValueError:
                return text, "numeric_string"

        try:
            value = int(text)
        except ValueError:
            return text, "numeric_string"
        # The signed 32-bit range is [-2**31, 2**31 - 1].
        if not -MAX_INT32_VALUE - 1 <= value <= MAX_INT32_VALUE:
            return value, "bigint"
        return value, "integer"

    def _extract_literal_value_and_type_optimized(
        self, literal: exp.Expression, context: ParameterizationContext
    ) -> "tuple[Any, str, Optional[exp.DataType], Optional[str]]":
        """Single-pass extraction of value, type hint, SQLGlot type, and semantic name.

        Args:
            literal: The literal expression to extract from
            context: Current parameterization context with parent stack

        Returns:
            Tuple of (value, type_hint, sqlglot_type, semantic_name)
        """
        value, type_hint = self._extract_literal_value_and_type(literal)
        sqlglot_type = self._infer_sqlglot_type(type_hint, value)
        semantic_name = self._generate_semantic_name_from_context(literal, context)
        return value, type_hint, sqlglot_type, semantic_name

    @staticmethod
    def _infer_sqlglot_type(type_hint: str, value: Any) -> "Optional[exp.DataType]":
        """Infer a SQLGlot DataType from a type hint without parsing the literal.

        Args:
            type_hint: The simple type hint string
            value: The actual value for additional context

        Returns:
            SQLGlot DataType instance. DECIMAL gets precision/scale from plain
            decimal text (the sign is ignored; exponent forms get a bare DECIMAL).
            Non-empty strings get a VARCHAR length.
        """
        type_mapping = {
            "null": "NULL",
            "boolean": "BOOLEAN",
            "integer": "INT",
            "bigint": "BIGINT",
            "float": "FLOAT",
            "decimal": "DECIMAL",
            "string": "VARCHAR",
            "numeric_string": "VARCHAR",
            "unknown": "VARCHAR",
        }

        type_name = type_mapping.get(type_hint, "VARCHAR")

        if type_hint == "decimal" and isinstance(value, str):
            unsigned = value.lstrip("+-")
            if "e" in unsigned.lower():
                # Precision/scale cannot be read off the text of exponent notation.
                return exp.DataType.build(type_name)
            whole, _, fraction = unsigned.partition(".")
            precision = max(len(whole) + len(fraction), 1)
            scale = len(fraction)
            return exp.DataType.build(type_name, expressions=[exp.Literal.number(precision), exp.Literal.number(scale)])
        if type_hint == "string" and isinstance(value, str) and value:
            return exp.DataType.build(type_name, expressions=[exp.Literal.number(len(value))])

        return exp.DataType.build(type_name)

    @staticmethod
    def _generate_semantic_name_from_context(
        literal: exp.Expression, context: ParameterizationContext
    ) -> "Optional[str]":
        """Generate a semantic parameter name from the enclosing AST nodes.

        Prefers the column a literal is compared against (``col`` for
        ``col = 5``, ``col_value`` for ``col IN (...)``). Otherwise it falls back to
        the innermost WHERE/HAVING/JOIN/SELECT clause. Returns None if nothing matches.
        """
        for parent in reversed(context.parent_stack):
            if isinstance(parent, Binary):
                # Identity, not structural equality: we want this exact operand node.
                if parent.left is literal and isinstance(parent.right, exp.Column):
                    return parent.right.name
                if parent.right is literal and isinstance(parent.left, exp.Column):
                    return parent.left.name
            elif isinstance(parent, exp.In) and isinstance(parent.this, exp.Column):
                return f"{parent.this.name}_value"

        for parent in reversed(context.parent_stack):
            if isinstance(parent, exp.Where):
                return "where_value"
            if isinstance(parent, exp.Having):
                return "having_value"
            if isinstance(parent, exp.Join):
                return "join_value"
            if isinstance(parent, exp.Select):
                return "select_value"

        return None

    def _is_string_literal(self, literal: exp.Literal) -> bool:
        """Check if a literal is a string."""
        return (hasattr(literal, "is_string") and literal.is_string) or (
            isinstance(literal.this, str) and not self._is_number_literal(literal)
        )

    @staticmethod
    def _is_number_literal(literal: exp.Literal) -> bool:
        """Check if a literal is a number."""
        if getattr(literal, "is_number", False):
            return True
        # A quoted literal such as '123' is a string even though float() accepts its
        # text. Without this guard it was treated as a number, e.g. by the
        # LIMIT/OFFSET preservation rule.
        if getattr(literal, "is_string", False) or literal.this is None:
            return False
        try:
            float(str(literal.this))
        except (ValueError, TypeError):
            return False
        return True

    def _create_placeholder(self, hint: Optional[str] = None) -> exp.Expression:
        """Create the next placeholder for ``placeholder_style`` and advance the counter."""
        style = self.placeholder_style
        index = self._parameter_counter
        if style in {"?", ParameterStyle.QMARK, "qmark"}:
            placeholder = exp.Placeholder()
        elif style == ":name":
            # Column-derived hints can hold characters that are invalid in a bind
            # name (quoted identifiers with spaces or dashes); map them to "_".
            safe_hint = "".join(ch if ch.isalnum() or ch == "_" else "_" for ch in hint) if hint else ""
            param_name = f"{safe_hint}_{index}" if safe_hint else f"param_{index}"
            placeholder = exp.Placeholder(this=param_name)
        elif style in {ParameterStyle.NAMED_COLON, "named_colon"} or style.startswith(":"):
            placeholder = exp.Placeholder(this=f"param_{index}")
        elif style in {ParameterStyle.NUMERIC, "numeric"} or style.startswith("$"):
            # PostgreSQL style numbered parameters ($N, 1-based) rendered via Var.
            placeholder = exp.Var(this=f"${index + 1}")  # type: ignore[assignment]
        elif style in {ParameterStyle.NAMED_AT, "named_at"}:
            # BigQuery style @param: the "@" is added during SQL generation; names
            # use 0-based indexing for consistency with parameter arrays.
            placeholder = exp.Placeholder(this=f"param_{index}")
        elif style in {ParameterStyle.POSITIONAL_PYFORMAT, "pyformat"}:
            # Use a standard placeholder; the compile step converts it later.
            placeholder = exp.Placeholder()
        else:
            # Default to question mark
            placeholder = exp.Placeholder()

        self._parameter_counter += 1
        return placeholder

    def _process_array(self, array_node: Array, context: ParameterizationContext) -> exp.Expression:
        """Parameterize an array literal.

        Empty arrays and arrays longer than ``max_array_length`` are left as-is. An
        array made only of ``Literal`` elements becomes one array-typed parameter.
        That only happens when no element would be preserved in the current context
        (preserved function, DDL, over-long string); otherwise each element is
        processed individually so those rules still apply.
        """
        elements = array_node.expressions
        if not elements or len(elements) > self.max_array_length:
            return array_node

        all_literals = all(isinstance(element, Literal) for element in elements)
        if all_literals and not any(self._should_preserve_literal_in_context(e, context) for e in elements):
            extracted = [self._extract_literal_value_and_type(element) for element in elements]
            array_values = [value for value, _ in extracted]
            # Array element type is taken from the first element.
            element_type = extracted[0][1]

            element_sqlglot_type = self._infer_sqlglot_type(element_type, array_values[0])
            array_sqlglot_type = exp.DataType.build("ARRAY", expressions=[element_sqlglot_type])

            from sqlspec.statement.parameters import TypedParameter

            typed_param = TypedParameter(
                value=array_values,
                sqlglot_type=array_sqlglot_type,
                type_hint=f"array<{element_type}>",
                semantic_name="array_values",
            )

            self.extracted_parameters.append(typed_param)
            self._parameter_metadata.append(
                {
                    "index": len(self.extracted_parameters) - 1,
                    "type": f"array<{element_type}>",
                    "length": len(array_values),
                    "context": "array_literal",
                }
            )
            return self._create_placeholder("array")

        array_node.set(
            "expressions",
            [
                self._process_literal_with_context(element, context)
                if isinstance(element, Literal)
                else self._transform_with_context(element, context)
                for element in elements
            ],
        )
        return array_node

    def _process_in_clause(self, in_node: exp.In, context: ParameterizationContext) -> exp.Expression:
        """Parameterize an ``IN`` predicate.

        - The left operand is always transformed first. It precedes the list in the
          SQL text, so positional placeholders stay in order.
        - Subquery and UNNEST forms are transformed recursively.
        - For value lists of at most ``max_in_list_size`` entries, each literal is
          parameterized and every other entry is transformed recursively. Larger
          lists are left untouched.
        """
        for key in ("this", "query", "unnest", "field"):
            value = in_node.args.get(key)
            if isinstance(value, exp.Expression):
                in_node.set(key, self._transform_with_context(value, context))

        expressions = in_node.args.get("expressions") or []
        if not expressions or len(expressions) > self.max_in_list_size:
            return in_node

        in_node.set(
            "expressions",
            [
                self._process_literal_with_context(expr, context)
                if isinstance(expr, Literal)
                else self._transform_with_context(expr, context)
                for expr in expressions
            ],
        )
        return in_node

    def _get_context_description(self, context: ParameterizationContext) -> str:
        """Describe where a parameter came from, for the parameter metadata."""
        flags = (
            (context.in_function_args, "function_args"),
            (context.in_case_when, "case_when"),
            (context.in_array, "array"),
            (context.in_in_clause, "in_clause"),
        )
        descriptions = [label for active, label in flags if active]

        if not descriptions:
            # Fall back to the innermost enclosing clause.
            clause_labels = ((exp.Select, "select"), (exp.Where, "where"), (exp.Having, "having"), (exp.Join, "join"))
            for parent in reversed(context.parent_stack):
                label = next((name for kind, name in clause_labels if isinstance(parent, kind)), None)
                if label is not None:
                    descriptions.append(label)
                    break

        return "_".join(descriptions) if descriptions else "general"

    def get_parameters(self) -> list[Any]:
        """Get the list of extracted parameters from the last processing operation.

        Returns:
            List of parameter values extracted during the last process() call.
        """
        return self.extracted_parameters.copy()

    def get_parameter_metadata(self) -> list[dict[str, Any]]:
        """Get metadata about extracted parameters for advanced usage.

        Returns:
            List of parameter metadata dictionaries.
        """
        return self._parameter_metadata.copy()

    def clear_parameters(self) -> None:
        """Clear the extracted parameters, metadata and placeholder counter."""
        self.extracted_parameters = []
        self._parameter_counter = 0
        self._parameter_metadata = []
@@ -0,0 +1,66 @@
1
+ from typing import Optional
2
+
3
+ from sqlglot import exp
4
+
5
+ from sqlspec.statement.pipelines.base import ProcessorProtocol
6
+ from sqlspec.statement.pipelines.context import SQLProcessingContext
7
+
8
+ __all__ = ("CommentRemover",)  # Public names exported by this module.
9
+
10
+
11
class CommentRemover(ProcessorProtocol):
    """Removes standard SQL comments from expressions using SQLGlot's AST traversal.

    This transformer removes SQL comments while preserving functionality:
    - Removes line comments (-- comment)
    - Removes block comments (/* comment */)
    - Preserves string literals that contain comment-like patterns
    - Always preserves SQL hints and MySQL version comments (use HintRemover separately)
    - Uses SQLGlot's AST for reliable, context-aware comment detection

    A comment is kept when any of the following holds:
    - its text starts with ``+`` (an optimizer-hint marker);
    - its text starts with ``!`` (a MySQL version comment);
    - it contains one of the Oracle hint keywords INDEX, USE_NL, USE_HASH,
      PARALLEL, FULL, FIRST_ROWS or ALL_ROWS as a standalone word.

    For the keyword check, underscores count as separators, so ``NO_INDEX(t)`` is
    kept, while words such as "reindex" or "carefully" no longer cause an ordinary
    comment to be kept.

    Note: This transformer now focuses only on standard comments. Use HintRemover
    separately if you need to remove Oracle hints (/*+ hint */) or MySQL version
    comments (/*!50000 */).

    Args:
        enabled: Whether comment removal is enabled.
    """

    def __init__(self, enabled: bool = True) -> None:
        self.enabled = enabled

    def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
        """Remove non-hint comments from ``context.current_expression``.

        Works on a copy of the current expression, stores the cleaned tree back on
        the context, and records the number of removed comments under
        ``context.metadata["comments_removed"]``.

        Returns:
            The cleaned expression, or ``expression`` unchanged when disabled or when
            there is nothing to process.
        """
        if not self.enabled or expression is None or context.current_expression is None:
            return expression

        import re

        # Matched against upper-cased text. The lookarounds exclude ASCII letters and
        # digits only, so "_" still separates words (NO_INDEX, INDEX_FFS match)
        # while REINDEX, CAREFULLY or INDEXES do not.
        hint_keyword_pattern = re.compile(
            r"(?<![A-Z0-9])(?:INDEX|USE_NL|USE_HASH|PARALLEL|FULL|FIRST_ROWS|ALL_ROWS)(?![A-Z0-9])"
        )

        def _is_preserved(comment: str) -> bool:
            text = str(comment).strip()
            if text.startswith(("+", "!")):
                return True
            return hint_keyword_pattern.search(text.upper()) is not None

        comments_removed_count = 0

        def _remove_comments(node: exp.Expression) -> "Optional[exp.Expression]":
            nonlocal comments_removed_count
            comments = getattr(node, "comments", None)
            if comments:
                comments_to_keep = [comment for comment in comments if _is_preserved(comment)]
                removed = len(comments) - len(comments_to_keep)
                if removed:
                    comments_removed_count += removed
                    node.pop_comments()
                    if comments_to_keep:
                        node.add_comments(comments_to_keep)

            return node

        cleaned_expression = context.current_expression.transform(_remove_comments, copy=True)
        context.current_expression = cleaned_expression

        context.metadata["comments_removed"] = comments_removed_count

        return cleaned_expression