sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -644
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -462
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +217 -451
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +418 -498
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +592 -634
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +393 -436
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +549 -942
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -550
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +732 -733
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +243 -426
- sqlspec/base.py +220 -825
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -330
- sqlspec/mixins.py +0 -306
- sqlspec/statement.py +0 -378
- sqlspec-0.11.0.dist-info/RECORD +0 -69
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,623 @@
|
|
|
1
|
+
"""Replaces literals in SQL with placeholders and extracts them using SQLGlot AST."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
|
|
6
|
+
from sqlglot import exp
|
|
7
|
+
from sqlglot.expressions import Array, Binary, Boolean, DataType, Func, Literal, Null
|
|
8
|
+
|
|
9
|
+
from sqlspec.statement.parameters import ParameterStyle
|
|
10
|
+
from sqlspec.statement.pipelines.base import ProcessorProtocol
|
|
11
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
12
|
+
|
|
13
|
+
__all__ = ("ParameterizationContext", "ParameterizeLiterals")
|
|
14
|
+
|
|
15
|
+
# Constants for magic values and literal parameterization
|
|
16
|
+
# Number of fractional digits above which a numeric literal is kept as a
# decimal string instead of being converted to a float.
MAX_DECIMAL_PRECISION = 6
# Largest signed 32-bit integer; integer literals with a larger magnitude are
# tagged "bigint" rather than "integer".
MAX_INT32_VALUE = 2**31 - 1
DEFAULT_MAX_STRING_LENGTH = 1_000
"""Default maximum string length for literal parameterization."""

DEFAULT_MAX_ARRAY_LENGTH = 100
"""Default maximum array length for literal parameterization."""

DEFAULT_MAX_IN_LIST_SIZE = 50
"""Default maximum IN clause list size before parameterization."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class ParameterizationContext:
    """Mutable state carried through one AST traversal of the parameterizer.

    The traversal pushes each visited node onto ``parent_stack`` when it is
    entered and pops it when it is left; the boolean flags and the depth
    counter are toggled by the same enter/leave hook.
    """

    # Nodes currently being visited, outermost first (includes the node itself).
    parent_stack: list[exp.Expression]
    # True while inside a function whose literals must be kept as-is.
    in_function_args: bool = False
    # Flag associated with CASE expressions.
    in_case_when: bool = False
    # Flag associated with array literals.
    in_array: bool = False
    # Flag associated with IN predicates.
    in_in_clause: bool = False
    # Number of function nodes currently open on the stack.
    function_depth: int = 0
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ParameterizeLiterals(ProcessorProtocol):
|
|
41
|
+
"""Advanced literal parameterization using SQLGlot AST analysis.
|
|
42
|
+
|
|
43
|
+
This enhanced version provides:
|
|
44
|
+
- Context-aware parameterization based on AST position
|
|
45
|
+
- Smart handling of arrays, IN clauses, and function arguments
|
|
46
|
+
- Type-preserving parameter extraction
|
|
47
|
+
- Configurable parameterization strategies
|
|
48
|
+
- Performance optimization for query plan caching
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
placeholder_style: Style of placeholder to use ("?", ":name", "$1", etc.).
|
|
52
|
+
preserve_null: Whether to preserve NULL literals as-is.
|
|
53
|
+
preserve_boolean: Whether to preserve boolean literals as-is.
|
|
54
|
+
preserve_numbers_in_limit: Whether to preserve numbers in LIMIT/OFFSET clauses.
|
|
55
|
+
preserve_in_functions: List of function names where literals should be preserved.
|
|
56
|
+
parameterize_arrays: Whether to parameterize array literals.
|
|
57
|
+
parameterize_in_lists: Whether to parameterize IN clause lists.
|
|
58
|
+
max_string_length: Maximum string length to parameterize.
|
|
59
|
+
max_array_length: Maximum array length to parameterize.
|
|
60
|
+
max_in_list_size: Maximum IN list size to parameterize.
|
|
61
|
+
type_preservation: Whether to preserve exact literal types.
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
def __init__(
    self,
    placeholder_style: str = "?",
    preserve_null: bool = True,
    preserve_boolean: bool = True,
    preserve_numbers_in_limit: bool = True,
    preserve_in_functions: Optional[list[str]] = None,
    parameterize_arrays: bool = True,
    parameterize_in_lists: bool = True,
    max_string_length: int = DEFAULT_MAX_STRING_LENGTH,
    max_array_length: int = DEFAULT_MAX_ARRAY_LENGTH,
    max_in_list_size: int = DEFAULT_MAX_IN_LIST_SIZE,
    type_preservation: bool = True,
) -> None:
    """Store the parameterization options and reset the per-run state.

    Args:
        placeholder_style: Style of placeholder to use ("?", ":name", "$1", etc.).
        preserve_null: Whether to preserve NULL literals as-is.
        preserve_boolean: Whether to preserve boolean literals as-is.
        preserve_numbers_in_limit: Whether to preserve numbers in LIMIT/OFFSET clauses.
        preserve_in_functions: Function names whose literal arguments are preserved.
            ``None`` selects the defaults (COALESCE, IFNULL, NVL, ISNULL); an
            explicitly empty list disables function-based preservation.
        parameterize_arrays: Whether to parameterize array literals.
        parameterize_in_lists: Whether to parameterize IN clause lists.
        max_string_length: Maximum string length to parameterize.
        max_array_length: Maximum array length to parameterize.
        max_in_list_size: Maximum IN list size to parameterize.
        type_preservation: Whether to preserve exact literal types.
    """
    self.placeholder_style = placeholder_style
    self.preserve_null = preserve_null
    self.preserve_boolean = preserve_boolean
    self.preserve_numbers_in_limit = preserve_numbers_in_limit
    # Test against None instead of truthiness: ``[]`` is a valid request to
    # preserve literals in no functions and must not fall back to the defaults.
    if preserve_in_functions is None:
        preserve_in_functions = ["COALESCE", "IFNULL", "NVL", "ISNULL"]
    self.preserve_in_functions = preserve_in_functions
    self.parameterize_arrays = parameterize_arrays
    self.parameterize_in_lists = parameterize_in_lists
    self.max_string_length = max_string_length
    self.max_array_length = max_array_length
    self.max_in_list_size = max_in_list_size
    self.type_preservation = type_preservation
    # Per-run state; process() resets all three before each traversal.
    self.extracted_parameters: list[Any] = []
    self._parameter_counter = 0
    self._parameter_metadata: list[dict[str, Any]] = []  # Track parameter types and context
|
|
92
|
+
|
|
93
|
+
def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
    """Replace literals in the context's current expression with placeholders.

    The input is returned untouched when there is no expression to work on or
    when the original SQL already contained placeholders. Otherwise a copy of
    ``context.current_expression`` is rewritten, stored back on the context,
    and the extracted parameters and their metadata are published on it.
    """
    if expression is None:
        return expression
    if context.current_expression is None:
        return expression
    if context.config.input_sql_had_placeholders:
        return expression

    # Start every run from a clean slate; the processor instance is reused.
    self.extracted_parameters = []
    self._parameter_counter = 0
    self._parameter_metadata = []

    working_copy = context.current_expression.copy()
    rewritten = self._transform_with_context(working_copy, ParameterizationContext(parent_stack=[]))

    context.current_expression = rewritten
    context.extracted_parameters_from_pipeline.extend(self.extracted_parameters)
    context.metadata["parameter_metadata"] = self._parameter_metadata

    return rewritten
|
|
110
|
+
|
|
111
|
+
def _transform_with_context(self, node: exp.Expression, context: ParameterizationContext) -> exp.Expression:
    """Rewrite ``node`` and its descendants, keeping ``context`` in sync.

    Literal, Boolean and Null nodes go to the literal handler; arrays and IN
    predicates go to their dedicated handlers when those features are enabled;
    every other node has its child expressions rewritten in place.
    """
    self._update_context(node, context, entering=True)

    if isinstance(node, (Literal, Boolean, Null)):
        # Boolean and Null are not Literal subclasses but share the same handler.
        replacement = self._process_literal_with_context(node, context)
    elif self.parameterize_arrays and isinstance(node, Array):
        replacement = self._process_array(node, context)
    elif self.parameterize_in_lists and isinstance(node, exp.In):
        replacement = self._process_in_clause(node, context)
    else:
        for arg_name, child in node.args.items():
            if isinstance(child, exp.Expression):
                node.set(arg_name, self._transform_with_context(child, context))
            elif isinstance(child, list):
                rewritten_items = []
                for item in child:
                    if isinstance(item, exp.Expression):
                        item = self._transform_with_context(item, context)
                    rewritten_items.append(item)
                node.set(arg_name, rewritten_items)
        replacement = node

    # Always leave with the node that was entered, even if it was replaced.
    self._update_context(node, context, entering=False)

    return replacement
|
|
145
|
+
|
|
146
|
+
def _update_context(self, node: exp.Expression, context: ParameterizationContext, entering: bool) -> None:
    """Update the traversal context when a node is entered or left.

    Args:
        node: The AST node being entered or left.
        context: Traversal state to mutate.
        entering: True when descending into ``node``, False when leaving it.
    """

    def is_preserved_function(func_node: exp.Expression) -> bool:
        # A function matches by its expression class name or by its own name.
        if func_node.__class__.__name__.upper() in self.preserve_in_functions:
            return True
        return bool(func_node.name) and func_node.name.upper() in self.preserve_in_functions

    if entering:
        context.parent_stack.append(node)

        # NOTE(review): sqlglot's Case and Array appear to derive from Func; if so,
        # the Func branch handles them first and the Case/Array branches below are
        # never taken. Branch order is intentionally left unchanged - confirm.
        if isinstance(node, Func):
            context.function_depth += 1
            if is_preserved_function(node):
                context.in_function_args = True
        elif isinstance(node, exp.Case):
            context.in_case_when = True
        elif isinstance(node, Array):
            context.in_array = True
        elif isinstance(node, exp.In):
            context.in_in_clause = True
    else:
        if context.parent_stack:
            context.parent_stack.pop()

        if isinstance(node, Func):
            context.function_depth -= 1
            # Recompute from the remaining ancestors instead of waiting for the
            # depth to reach zero: otherwise leaving a preserved function nested
            # in another function keeps the flag set for later sibling arguments.
            context.in_function_args = any(
                isinstance(ancestor, Func) and is_preserved_function(ancestor) for ancestor in context.parent_stack
            )
        elif isinstance(node, exp.Case):
            context.in_case_when = False
        elif isinstance(node, Array):
            context.in_array = False
        elif isinstance(node, exp.In):
            context.in_in_clause = False
|
|
179
|
+
|
|
180
|
+
def _process_literal_with_context(
    self, literal: exp.Expression, context: ParameterizationContext
) -> exp.Expression:
    """Turn one literal into a placeholder unless its context says to keep it.

    Returns the literal unchanged when it must be preserved; otherwise records
    a typed parameter plus its metadata and returns the placeholder node.
    """
    if self._should_preserve_literal_in_context(literal, context):
        return literal

    # Value, type hint, SQLGlot type and semantic name come from one helper call.
    extracted = self._extract_literal_value_and_type_optimized(literal, context)
    value, type_hint, sqlglot_type, semantic_name = extracted

    from sqlspec.statement.parameters import TypedParameter

    if not sqlglot_type:
        sqlglot_type = exp.DataType.build("VARCHAR")  # Fallback type

    # Index the new parameter will occupy once appended.
    position = len(self.extracted_parameters)
    self.extracted_parameters.append(
        TypedParameter(value=value, sqlglot_type=sqlglot_type, type_hint=type_hint, semantic_name=semantic_name)
    )

    # literal.sql() is deliberately not recorded here, for performance.
    metadata_entry = {
        "index": position,
        "type": type_hint,
        "semantic_name": semantic_name,
        "context": self._get_context_description(context),
    }
    self._parameter_metadata.append(metadata_entry)

    return self._create_placeholder(hint=semantic_name)
|
|
215
|
+
|
|
216
|
+
def _should_preserve_literal_in_context(self, literal: exp.Expression, context: ParameterizationContext) -> bool:
    """Decide whether ``literal`` must stay inline instead of becoming a parameter.

    Checks run in this order: NULL and boolean preservation options, the
    preserved-function flag, each node on the parent stack (DDL/schema nodes,
    LIMIT/OFFSET numbers, CASE), and finally the maximum string length.
    """
    if isinstance(literal, Null) and self.preserve_null:
        return True

    if isinstance(literal, Boolean) and self.preserve_boolean:
        return True

    if context.in_function_args:
        return True

    ddl_node_types = (DataType, exp.ColumnDef, exp.Create, exp.Schema)
    for ancestor in context.parent_stack:
        # Literals inside schema/DDL definitions are never parameterized.
        if isinstance(ancestor, ddl_node_types):
            return True

        # Numbers inside LIMIT/OFFSET stay inline when configured.
        if self.preserve_numbers_in_limit and isinstance(ancestor, (exp.Limit, exp.Offset)):
            if isinstance(literal, exp.Literal) and self._is_number_literal(literal):
                return True

        # Inside CASE, keep the literal unless it is a direct operand of a binary expression.
        if isinstance(ancestor, exp.Case) and context.in_case_when:
            return not isinstance(literal.parent, Binary)

    # Strings longer than the configured maximum are left in the SQL text.
    if not (isinstance(literal, exp.Literal) and self._is_string_literal(literal)):
        return False
    return len(str(literal.this)) > self.max_string_length
|
|
257
|
+
|
|
258
|
+
def _extract_literal_value_and_type(self, literal: exp.Expression) -> tuple[Any, str]:
    """Extract the Python value and a type hint from a SQLGlot literal node.

    Args:
        literal: A Literal, Boolean or Null node (other nodes are stringified).

    Returns:
        ``(value, type_hint)`` where ``type_hint`` is one of "null", "boolean",
        "string", "integer", "bigint", "float", "decimal", "numeric_string"
        or "unknown".
    """
    if isinstance(literal, Null) or literal.this is None:
        return None, "null"

    # Boolean is not a Literal subclass, so it must be recognised before the
    # non-Literal guard below; otherwise it would be returned as a string.
    if isinstance(literal, Boolean):
        return literal.this, "boolean"

    # Anything that is not a Literal has no typed payload to inspect.
    if not isinstance(literal, exp.Literal):
        return str(literal), "string"

    if isinstance(literal.this, bool):
        return literal.this, "boolean"

    if self._is_string_literal(literal):
        return str(literal.this), "string"

    if not self._is_number_literal(literal):
        # Date/time values arrive as string literals and are handled above;
        # whatever remains is passed through as text.
        return str(literal.this), "unknown"

    value_str = str(literal.this)

    if not self.type_preservation:
        # Simple conversion: only distinguish float from int.
        try:
            if "." in value_str:
                return float(literal.this), "float"
            return int(literal.this), "integer"
        except ValueError:
            return value_str, "numeric_string"

    if "." in value_str or "e" in value_str.lower():
        try:
            # Many fractional digits would lose precision as a float, so such
            # values are kept as text and tagged "decimal".
            decimal_places = len(value_str.split(".")[1]) if "." in value_str else 0
            if decimal_places > MAX_DECIMAL_PRECISION:
                return value_str, "decimal"
            return float(literal.this), "float"
        except (ValueError, IndexError):
            return value_str, "numeric_string"

    try:
        value = int(literal.this)
    except ValueError:
        return value_str, "numeric_string"
    # Values beyond the 32-bit range are tagged separately.
    if abs(value) > MAX_INT32_VALUE:
        return value, "bigint"
    return value, "integer"
|
|
311
|
+
|
|
312
|
+
def _extract_literal_value_and_type_optimized(
    self, literal: exp.Expression, context: ParameterizationContext
) -> "tuple[Any, str, Optional[exp.DataType], Optional[str]]":
    """Collect everything needed to build a typed parameter for ``literal``.

    Delegates to the value/type extractor, the SQLGlot type inference helper
    and the semantic-name helper, without calling ``literal.sql()``.

    Args:
        literal: The literal expression to extract from
        context: Current parameterization context with parent stack

    Returns:
        Tuple of (value, type_hint, sqlglot_type, semantic_name)
    """
    value, type_hint = self._extract_literal_value_and_type(literal)
    return (
        value,
        type_hint,
        self._infer_sqlglot_type(type_hint, value),
        self._generate_semantic_name_from_context(literal, context),
    )
|
|
337
|
+
|
|
338
|
+
@staticmethod
def _infer_sqlglot_type(type_hint: str, value: Any) -> "Optional[exp.DataType]":
    """Infer a SQLGlot DataType from a type hint without parsing SQL.

    Args:
        type_hint: The simple type hint string
        value: The actual value, used to size DECIMAL and VARCHAR types

    Returns:
        SQLGlot DataType instance; unknown hints map to VARCHAR
    """
    from decimal import Decimal, InvalidOperation

    type_mapping = {
        "null": "NULL",
        "boolean": "BOOLEAN",
        "integer": "INT",
        "bigint": "BIGINT",
        "float": "FLOAT",
        "decimal": "DECIMAL",
        "string": "VARCHAR",
        "numeric_string": "VARCHAR",
        "unknown": "VARCHAR",
    }

    type_name = type_mapping.get(type_hint, "VARCHAR")

    if type_hint == "decimal" and isinstance(value, str):
        # A sign character is not a digit and must not count towards precision.
        unsigned = value.lstrip("+-")
        precision: Optional[int] = None
        scale = 0

        if "e" in unsigned.lower():
            # Exponent notation: counting characters around "." would include the
            # exponent text, so derive the digit counts arithmetically instead.
            try:
                _, digits, exponent = Decimal(unsigned).as_tuple()
            except InvalidOperation:
                digits, exponent = (), None
            if isinstance(exponent, int):
                scale = max(-exponent, 0)
                # At least one integer digit, matching fixed-point text like "0.001".
                precision = max(len(digits) + exponent, 1) + scale

        if precision is None:
            # Plain fixed-point text: count digits on each side of the point.
            parts = unsigned.split(".")
            scale = len(parts[1]) if len(parts) > 1 else 0
            precision = len(parts[0]) + scale

        return exp.DataType.build(type_name, expressions=[exp.Literal.number(precision), exp.Literal.number(scale)])

    if type_hint == "string" and isinstance(value, str):
        # Size VARCHAR to the literal's length; empty strings get the bare type.
        length = len(value)
        if length > 0:
            return exp.DataType.build(type_name, expressions=[exp.Literal.number(length)])

    # Default case - just the type name
    return exp.DataType.build(type_name)
|
|
378
|
+
|
|
379
|
+
@staticmethod
def _generate_semantic_name_from_context(
    literal: exp.Expression, context: ParameterizationContext
) -> "Optional[str]":
    """Derive a descriptive parameter name from the literal's ancestors.

    Args:
        literal: The literal being parameterized
        context: Current context with parent stack

    Returns:
        The column name the literal is compared with, ``<column>_value`` for
        an IN predicate on a column, a clause-based name such as
        ``where_value``, or None when nothing on the stack matches.
    """
    innermost_first = list(reversed(context.parent_stack))

    # First pass: the nearest comparison or IN predicate involving a column.
    for ancestor in innermost_first:
        if isinstance(ancestor, Binary):
            if ancestor.left == literal and isinstance(ancestor.right, exp.Column):
                return ancestor.right.name
            if ancestor.right == literal and isinstance(ancestor.left, exp.Column):
                return ancestor.left.name
            continue
        if isinstance(ancestor, exp.In) and ancestor.this and isinstance(ancestor.this, exp.Column):
            return f"{ancestor.this.name}_value"

    # Second pass: fall back to the nearest enclosing clause.
    clause_labels = (
        (exp.Where, "where_value"),
        (exp.Having, "having_value"),
        (exp.Join, "join_value"),
        (exp.Select, "select_value"),
    )
    for ancestor in innermost_first:
        for clause_type, label in clause_labels:
            if isinstance(ancestor, clause_type):
                return label

    return None
|
|
417
|
+
|
|
418
|
+
def _is_string_literal(self, literal: exp.Literal) -> bool:
|
|
419
|
+
"""Check if a literal is a string."""
|
|
420
|
+
# Check if it's explicitly a string literal
|
|
421
|
+
return (hasattr(literal, "is_string") and literal.is_string) or (
|
|
422
|
+
isinstance(literal.this, str) and not self._is_number_literal(literal)
|
|
423
|
+
)
|
|
424
|
+
|
|
425
|
+
@staticmethod
|
|
426
|
+
def _is_number_literal(literal: exp.Literal) -> bool:
|
|
427
|
+
"""Check if a literal is a number."""
|
|
428
|
+
# Check if it's explicitly a number literal
|
|
429
|
+
if hasattr(literal, "is_number") and literal.is_number:
|
|
430
|
+
return True
|
|
431
|
+
if literal.this is None:
|
|
432
|
+
return False
|
|
433
|
+
# Try to determine if it's numeric by attempting conversion
|
|
434
|
+
try:
|
|
435
|
+
float(str(literal.this))
|
|
436
|
+
except (ValueError, TypeError):
|
|
437
|
+
return False
|
|
438
|
+
return True
|
|
439
|
+
|
|
440
|
+
def _create_placeholder(self, hint: Optional[str] = None) -> exp.Expression:
    """Build the placeholder node for the next parameter and advance the counter.

    ``placeholder_style`` may be a literal prefix ("?", ":name", "$1") or a
    style name / ``ParameterStyle`` member. Only the ":name" style embeds
    ``hint`` in the generated parameter name.
    """
    style = self.placeholder_style
    position = self._parameter_counter
    generic_name = f"param_{position}"

    node: exp.Expression
    if style in {"?", ParameterStyle.QMARK, "qmark"}:
        node = exp.Placeholder()
    elif style == ":name":
        # Prefix the name with the semantic hint when one is available.
        node = exp.Placeholder(this=f"{hint}_{position}" if hint else generic_name)
    elif style in {ParameterStyle.NAMED_COLON, "named_colon"} or style.startswith(":"):
        node = exp.Placeholder(this=generic_name)
    elif style in {ParameterStyle.NUMERIC, "numeric"} or style.startswith("$"):
        # PostgreSQL-style $N parameters are 1-based and emitted as a Var.
        node = exp.Var(this=f"${position + 1}")
    elif style in {ParameterStyle.NAMED_AT, "named_at"}:
        # BigQuery-style @param: the "@" is added during SQL generation, and the
        # name stays 0-based for consistency with parameter arrays.
        node = exp.Placeholder(this=generic_name)
    else:
        # pyformat is converted later by the compile step, and unknown styles
        # default to a question mark; both use a bare placeholder here.
        node = exp.Placeholder()

    # Advance only after the placeholder has been built from the old counter.
    self._parameter_counter = position + 1
    return node
|
|
476
|
+
|
|
477
|
+
def _process_array(self, array_node: Array, context: ParameterizationContext) -> exp.Expression:
    """Parameterize an array literal.

    Empty arrays and arrays longer than ``max_array_length`` are returned
    unchanged. An array made only of literals becomes a single parameter typed
    from its first element; otherwise each element is rewritten individually.
    """
    elements = array_node.expressions
    if not elements or len(elements) > self.max_array_length:
        return array_node

    if all(isinstance(item, Literal) for item in elements):
        extracted = [self._extract_literal_value_and_type(item) for item in elements]
        array_values = [value for value, _ in extracted]
        # The first element decides the element type for the whole array.
        element_type = extracted[0][1]

        element_sqlglot_type = self._infer_sqlglot_type(element_type, array_values[0])
        array_sqlglot_type = exp.DataType.build("ARRAY", expressions=[element_sqlglot_type])

        from sqlspec.statement.parameters import TypedParameter

        array_type_hint = f"array<{element_type}>"
        position = len(self.extracted_parameters)
        self.extracted_parameters.append(
            TypedParameter(
                value=array_values,
                sqlglot_type=array_sqlglot_type,
                type_hint=array_type_hint,
                semantic_name="array_values",
            )
        )
        self._parameter_metadata.append(
            {
                "index": position,
                "type": array_type_hint,
                "length": len(array_values),
                "context": "array_literal",
            }
        )
        # The entire array is replaced by one placeholder.
        return self._create_placeholder("array")

    # Mixed content: literals go to the literal handler, the rest is traversed.
    rewritten = [
        self._process_literal_with_context(item, context)
        if isinstance(item, Literal)
        else self._transform_with_context(item, context)
        for item in elements
    ]
    array_node.set("expressions", rewritten)
    return array_node
|
|
539
|
+
|
|
540
|
+
def _process_in_clause(self, in_node: exp.In, context: ParameterizationContext) -> exp.Expression:
|
|
541
|
+
"""Process IN clause for intelligent parameterization."""
|
|
542
|
+
# Check if it's a subquery IN clause (has 'query' in args)
|
|
543
|
+
if in_node.args.get("query"):
|
|
544
|
+
# Don't parameterize subqueries, just process them recursively
|
|
545
|
+
in_node.set("query", self._transform_with_context(in_node.args["query"], context))
|
|
546
|
+
return in_node
|
|
547
|
+
|
|
548
|
+
# Check if it has literal expressions (the values on the right side)
|
|
549
|
+
if "expressions" not in in_node.args or not in_node.args["expressions"]:
|
|
550
|
+
return in_node
|
|
551
|
+
|
|
552
|
+
# Check if the IN list is too large
|
|
553
|
+
expressions = in_node.args["expressions"]
|
|
554
|
+
if len(expressions) > self.max_in_list_size:
|
|
555
|
+
# Consider alternative strategies for large IN lists
|
|
556
|
+
return in_node
|
|
557
|
+
|
|
558
|
+
# Process the expressions in the IN clause
|
|
559
|
+
has_literals = any(isinstance(expr, Literal) for expr in expressions)
|
|
560
|
+
|
|
561
|
+
if has_literals:
|
|
562
|
+
# Transform literals in the IN list
|
|
563
|
+
new_expressions = []
|
|
564
|
+
for expr in expressions:
|
|
565
|
+
if isinstance(expr, Literal):
|
|
566
|
+
new_expressions.append(self._process_literal_with_context(expr, context))
|
|
567
|
+
else:
|
|
568
|
+
new_expressions.append(self._transform_with_context(expr, context))
|
|
569
|
+
|
|
570
|
+
# Update the IN node's expressions using set method
|
|
571
|
+
in_node.set("expressions", new_expressions)
|
|
572
|
+
|
|
573
|
+
return in_node
|
|
574
|
+
|
|
575
|
+
def _get_context_description(self, context: ParameterizationContext) -> str:
|
|
576
|
+
"""Get a description of the current parameterization context."""
|
|
577
|
+
descriptions = []
|
|
578
|
+
|
|
579
|
+
if context.in_function_args:
|
|
580
|
+
descriptions.append("function_args")
|
|
581
|
+
if context.in_case_when:
|
|
582
|
+
descriptions.append("case_when")
|
|
583
|
+
if context.in_array:
|
|
584
|
+
descriptions.append("array")
|
|
585
|
+
if context.in_in_clause:
|
|
586
|
+
descriptions.append("in_clause")
|
|
587
|
+
|
|
588
|
+
if not descriptions:
|
|
589
|
+
# Try to determine from parent stack
|
|
590
|
+
for parent in reversed(context.parent_stack):
|
|
591
|
+
if isinstance(parent, exp.Select):
|
|
592
|
+
descriptions.append("select")
|
|
593
|
+
break
|
|
594
|
+
if isinstance(parent, exp.Where):
|
|
595
|
+
descriptions.append("where")
|
|
596
|
+
break
|
|
597
|
+
if isinstance(parent, exp.Join):
|
|
598
|
+
descriptions.append("join")
|
|
599
|
+
break
|
|
600
|
+
|
|
601
|
+
return "_".join(descriptions) if descriptions else "general"
|
|
602
|
+
|
|
603
|
+
def get_parameters(self) -> list[Any]:
    """Return a snapshot of the parameters extracted so far.

    Returns:
        A new list with the entries of ``self.extracted_parameters``. The copy
        is shallow: changing the returned list does not affect the internal
        one, but the entries themselves are shared.
    """
    return list(self.extracted_parameters)
|
|
610
|
+
|
|
611
|
+
def get_parameter_metadata(self) -> list[dict[str, Any]]:
    """Return a snapshot of the metadata recorded for extracted parameters.

    Returns:
        A new list with the entries of ``self._parameter_metadata``. The copy
        is shallow: the metadata dictionaries themselves are shared with the
        internal list.
    """
    return list(self._parameter_metadata)
|
|
618
|
+
|
|
619
|
+
def clear_parameters(self) -> None:
    """Reset all parameter-extraction state.

    Rebinds the extracted-parameter list and the metadata list to fresh empty
    lists and sets the parameter counter back to zero. Lists handed out
    earlier are not modified.
    """
    self._parameter_counter = 0
    self.extracted_parameters, self._parameter_metadata = [], []
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from sqlglot import exp
|
|
4
|
+
|
|
5
|
+
from sqlspec.statement.pipelines.base import ProcessorProtocol
|
|
6
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
7
|
+
|
|
8
|
+
# Names exported by a star-import of this module.
__all__ = ("CommentRemover",)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CommentRemover(ProcessorProtocol):
    """Strip ordinary SQL comments from an expression tree.

    Comments attached to SQLGlot nodes are dropped unless they look like
    something that must survive:

    - comments whose text contains one of the optimizer-hint keywords in
      ``_HINT_KEYWORDS`` (case-insensitive substring match), and
    - comments whose text starts with ``!`` (MySQL version comments such as
      ``/*!50000 ... */``).

    NOTE: hint detection is a plain substring test, so an ordinary comment that
    happens to contain e.g. "index" or "full" is kept as well.

    Use HintRemover separately if hints or MySQL version comments should also
    be removed.

    Args:
        enabled: Whether comment removal is enabled.
    """

    # Substrings that mark a comment as an optimizer hint (compared upper-cased).
    _HINT_KEYWORDS = ("INDEX", "USE_NL", "USE_HASH", "PARALLEL", "FULL", "FIRST_ROWS", "ALL_ROWS")

    def __init__(self, enabled: bool = True) -> None:
        self.enabled = enabled

    def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
        """Remove non-hint comments from ``context.current_expression``.

        The transformation runs on a copy of ``context.current_expression``;
        the cleaned copy is stored back on the context and the number of
        removed comments is written to ``context.metadata["comments_removed"]``.

        Args:
            expression: The incoming expression. It is only checked for
                ``None``; the tree that gets cleaned is
                ``context.current_expression``.
            context: Processing context holding the current expression and
                the metadata mapping.

        Returns:
            ``expression`` unchanged when the processor is disabled or there is
            nothing to process, otherwise the cleaned copy of
            ``context.current_expression``.
        """
        if not self.enabled or expression is None or context.current_expression is None:
            return expression

        hint_keywords = self._HINT_KEYWORDS
        comments_removed_count = 0

        def _should_keep(comment: object) -> bool:
            """Return True for hint-like comments and ``!``-prefixed comments."""
            comment_text = str(comment).strip()
            if comment_text.startswith("!"):
                return True
            upper_text = comment_text.upper()
            return any(keyword in upper_text for keyword in hint_keywords)

        def _remove_comments(node: exp.Expression) -> "Optional[exp.Expression]":
            nonlocal comments_removed_count
            comments = getattr(node, "comments", None)
            if comments:
                comments_to_keep = [comment for comment in comments if _should_keep(comment)]
                removed = len(comments) - len(comments_to_keep)
                if removed:
                    comments_removed_count += removed
                    # Drop everything, then re-attach only the survivors.
                    node.pop_comments()
                    if comments_to_keep:
                        node.add_comments(comments_to_keep)
            return node

        cleaned_expression = context.current_expression.transform(_remove_comments, copy=True)
        context.current_expression = cleaned_expression
        context.metadata["comments_removed"] = comments_removed_count

        return cleaned_expression
|