sqlspec 0.12.2__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +16 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +17 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +17 -29
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +32 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +18 -9
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +44 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/METADATA +1 -1
- sqlspec-0.13.1.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Removes SQL comments and hints from expressions."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
|
4
|
+
|
|
5
|
+
from sqlglot import exp
|
|
6
|
+
|
|
7
|
+
from sqlspec.protocols import ProcessorProtocol
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
11
|
+
|
|
12
|
+
# Public API of this module: only the comment/hint remover is exported.
__all__ = ("CommentAndHintRemover",)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CommentAndHintRemover(ProcessorProtocol):
    """Remove SQL comments and/or optimizer hints from a parsed expression.

    The work is done with sqlglot's ``Expression.transform`` on a copy of the
    input, so the caller's tree is not modified. Two things are handled:

    * parsed ``exp.Hint`` nodes, which are dropped when ``remove_hints`` is set;
    * comments attached to nodes. Each comment is classified by ``_is_hint``.
      Hint-like comments are governed by ``remove_hints``; all other comments
      are governed by ``remove_comments``.

    Args:
        enabled: When False, ``process`` returns its input unchanged.
        remove_comments: Drop ordinary (non-hint) comments.
        remove_hints: Drop ``exp.Hint`` nodes and hint-like comments.
    """

    # Comment text starting with one of these is treated as a hint:
    # "+" is Oracle's hint marker (``/*+ ... */`` / ``--+ ...``) and "!" is a
    # MySQL ``/*! ... */`` executable comment.
    _HINT_PREFIXES = ("+", "!")

    # Oracle optimizer hint names recognised inside a comment's text.
    _HINT_KEYWORDS = ("INDEX", "USE_NL", "USE_HASH", "PARALLEL", "FULL", "FIRST_ROWS", "ALL_ROWS")

    # A keyword only counts when it is not embedded in a longer alphanumeric
    # word ("reindex", "fully", "unparalleled" do not match). Underscores are
    # allowed as neighbours so compound hint names such as NO_INDEX or
    # INDEX_ASC are still recognised.
    _HINT_KEYWORD_PATTERN = r"(?<![A-Z0-9])(?:" + "|".join(_HINT_KEYWORDS) + r")(?![A-Z0-9])"

    def __init__(self, enabled: bool = True, remove_comments: bool = True, remove_hints: bool = False) -> None:
        self.enabled = enabled
        self.remove_comments = remove_comments
        self.remove_hints = remove_hints

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Return a copy of ``expression`` with comments and/or hints stripped.

        Stores the number of removed comments and hints in
        ``context.metadata["comments_removed"]`` and
        ``context.metadata["hints_removed"]``. When the processor is disabled
        or ``expression`` is None, the input is returned as-is and no metadata
        is written.
        """
        if not self.enabled or expression is None:
            return expression

        comments_removed_count = 0
        hints_removed_count = 0

        def _remove_comments_and_hints(node: "exp.Expression") -> "Optional[exp.Expression]":
            nonlocal comments_removed_count, hints_removed_count

            # Returning None drops the hint node from the transformed tree.
            if self.remove_hints and isinstance(node, exp.Hint):
                hints_removed_count += 1
                return None

            comments = getattr(node, "comments", None)
            if not comments:
                return node

            kept = []
            for comment in comments:
                if self._is_hint(str(comment)):
                    if self.remove_hints:
                        hints_removed_count += 1
                        continue
                elif self.remove_comments:
                    comments_removed_count += 1
                    continue
                kept.append(comment)

            # Only rewrite the node's comment list when something was dropped.
            if len(kept) != len(comments):
                node.pop_comments()
                if kept:
                    node.add_comments(kept)

            return node

        cleaned_expression = expression.transform(_remove_comments_and_hints, copy=True)

        context.metadata["comments_removed"] = comments_removed_count
        context.metadata["hints_removed"] = hints_removed_count

        return cleaned_expression

    def _is_hint(self, comment_text: str) -> bool:
        """Heuristically decide whether a comment's text is an optimizer hint.

        A comment is a hint when its (stripped) text starts with one of
        ``_HINT_PREFIXES``. It is also a hint when it mentions one of
        ``_HINT_KEYWORDS`` as a stand-alone word, compared case-insensitively.

        This is still a heuristic: plain prose such as "full scan" matches the
        ``FULL`` keyword.
        """
        import re  # local import keeps the module-level import block unchanged

        text = comment_text.strip()
        if text.startswith(self._HINT_PREFIXES):
            return True
        return re.search(self._HINT_KEYWORD_PATTERN, text.upper()) is not None
|
|
@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Optional
|
|
|
6
6
|
from sqlglot import expressions as exp
|
|
7
7
|
|
|
8
8
|
from sqlspec.exceptions import RiskLevel
|
|
9
|
-
from sqlspec.
|
|
9
|
+
from sqlspec.protocols import ProcessorProtocol
|
|
10
|
+
from sqlspec.statement.pipelines.context import ValidationError
|
|
10
11
|
|
|
11
12
|
if TYPE_CHECKING:
|
|
12
13
|
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
@@ -36,7 +37,7 @@ class DMLSafetyConfig:
|
|
|
36
37
|
max_affected_rows: "Optional[int]" = None # Limit for DML operations
|
|
37
38
|
|
|
38
39
|
|
|
39
|
-
class DMLSafetyValidator(
|
|
40
|
+
class DMLSafetyValidator(ProcessorProtocol):
|
|
40
41
|
"""Unified validator for DML/DDL safety checks.
|
|
41
42
|
|
|
42
43
|
This validator consolidates:
|
|
@@ -52,9 +53,31 @@ class DMLSafetyValidator(BaseValidator):
|
|
|
52
53
|
Args:
|
|
53
54
|
config: Configuration for safety validation
|
|
54
55
|
"""
|
|
55
|
-
super().__init__()
|
|
56
56
|
self.config = config or DMLSafetyConfig()
|
|
57
57
|
|
|
58
|
+
def process(
    self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
) -> "Optional[exp.Expression]":
    """Run this validator as a pipeline step (implements ProcessorProtocol).

    Calls ``self.validate`` on a non-None expression, which records any
    findings on ``context``. The expression itself is passed through
    unchanged; a None input yields None.
    """
    if expression is not None:
        self.validate(expression, context)
    return expression
|
|
66
|
+
|
|
67
|
+
def add_error(
    self,
    context: "SQLProcessingContext",
    message: str,
    code: str,
    risk_level: RiskLevel,
    expression: "Optional[exp.Expression]" = None,
) -> None:
    """Append a ``ValidationError`` for this validator to ``context``.

    The error's ``processor`` field is set to this validator's class name.

    Args:
        context: Processing context whose ``validation_errors`` list receives the error.
        message: Description of the problem.
        code: Error code string stored on the ``ValidationError``.
        risk_level: Risk level attached to the error.
        expression: Expression the error refers to, if any.
    """
    processor_name = self.__class__.__name__
    finding = ValidationError(
        message=message,
        code=code,
        risk_level=risk_level,
        processor=processor_name,
        expression=expression,
    )
    context.validation_errors.append(finding)
|
|
80
|
+
|
|
58
81
|
def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
|
|
59
82
|
"""Validate SQL statement for safety issues.
|
|
60
83
|
|
|
@@ -66,7 +89,6 @@ class DMLSafetyValidator(BaseValidator):
|
|
|
66
89
|
category = self._categorize_statement(expression)
|
|
67
90
|
operation = self._get_operation_type(expression)
|
|
68
91
|
|
|
69
|
-
# Check DDL restrictions
|
|
70
92
|
if category == StatementCategory.DDL and self.config.prevent_ddl:
|
|
71
93
|
if operation not in self.config.allowed_ddl_operations:
|
|
72
94
|
self.add_error(
|
|
@@ -77,7 +99,6 @@ class DMLSafetyValidator(BaseValidator):
|
|
|
77
99
|
expression=expression,
|
|
78
100
|
)
|
|
79
101
|
|
|
80
|
-
# Check DML safety
|
|
81
102
|
elif category == StatementCategory.DML:
|
|
82
103
|
if operation in self.config.require_where_clause and not self._has_where_clause(expression):
|
|
83
104
|
self.add_error(
|
|
@@ -88,7 +109,6 @@ class DMLSafetyValidator(BaseValidator):
|
|
|
88
109
|
expression=expression,
|
|
89
110
|
)
|
|
90
111
|
|
|
91
|
-
# Check affected row limits
|
|
92
112
|
if self.config.max_affected_rows:
|
|
93
113
|
estimated_rows = self._estimate_affected_rows(expression)
|
|
94
114
|
if estimated_rows > self.config.max_affected_rows:
|
|
@@ -100,7 +120,6 @@ class DMLSafetyValidator(BaseValidator):
|
|
|
100
120
|
expression=expression,
|
|
101
121
|
)
|
|
102
122
|
|
|
103
|
-
# Check DCL restrictions
|
|
104
123
|
elif category == StatementCategory.DCL and self.config.prevent_dcl:
|
|
105
124
|
self.add_error(
|
|
106
125
|
context,
|
|
@@ -187,10 +206,8 @@ class DMLSafetyValidator(BaseValidator):
|
|
|
187
206
|
|
|
188
207
|
where = expression.args.get("where")
|
|
189
208
|
if where:
|
|
190
|
-
# Check for primary key or unique conditions
|
|
191
209
|
if self._has_unique_condition(where):
|
|
192
210
|
return 1
|
|
193
|
-
# Check for indexed conditions
|
|
194
211
|
if self._has_indexed_condition(where):
|
|
195
212
|
return 100 # Rough estimate
|
|
196
213
|
|
|
@@ -230,8 +247,10 @@ class DMLSafetyValidator(BaseValidator):
|
|
|
230
247
|
return False
|
|
231
248
|
# Look for common indexed column patterns
|
|
232
249
|
for condition in where.find_all(exp.Predicate):
|
|
233
|
-
if
|
|
234
|
-
|
|
250
|
+
if isinstance(condition, (exp.EQ, exp.GT, exp.GTE, exp.LT, exp.LTE, exp.NEQ)) and isinstance(
|
|
251
|
+
condition.left, exp.Column
|
|
252
|
+
):
|
|
253
|
+
col_name = condition.left.name.lower()
|
|
235
254
|
# Common indexed columns
|
|
236
255
|
if col_name in {"created_at", "updated_at", "email", "username", "status", "type"}:
|
|
237
256
|
return True
|
|
@@ -251,20 +270,16 @@ class DMLSafetyValidator(BaseValidator):
|
|
|
251
270
|
|
|
252
271
|
# For DML statements
|
|
253
272
|
if isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
|
|
254
|
-
if
|
|
273
|
+
if expression.this:
|
|
255
274
|
table_expr = expression.this
|
|
256
275
|
if isinstance(table_expr, exp.Table):
|
|
257
276
|
tables.append(table_expr.name)
|
|
258
277
|
|
|
259
278
|
# For DDL statements
|
|
260
|
-
elif (
|
|
261
|
-
isinstance(expression, (exp.Create, exp.Drop, exp.Alter))
|
|
262
|
-
and hasattr(expression, "this")
|
|
263
|
-
and expression.this
|
|
264
|
-
):
|
|
279
|
+
elif isinstance(expression, (exp.Create, exp.Drop, exp.Alter)) and expression.this:
|
|
265
280
|
# For CREATE TABLE, the table is in expression.this.this
|
|
266
281
|
if isinstance(expression, exp.Create) and isinstance(expression.this, exp.Schema):
|
|
267
|
-
if
|
|
282
|
+
if expression.this.this:
|
|
268
283
|
table_expr = expression.this.this
|
|
269
284
|
if isinstance(table_expr, exp.Table):
|
|
270
285
|
tables.append(table_expr.name)
|
|
@@ -6,9 +6,9 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
|
|
6
6
|
from sqlglot import exp
|
|
7
7
|
|
|
8
8
|
from sqlspec.exceptions import MissingParameterError, RiskLevel, SQLValidationError
|
|
9
|
-
from sqlspec.
|
|
10
|
-
from sqlspec.statement.pipelines.
|
|
11
|
-
from sqlspec.
|
|
9
|
+
from sqlspec.protocols import ProcessorProtocol
|
|
10
|
+
from sqlspec.statement.pipelines.context import ValidationError
|
|
11
|
+
from sqlspec.utils.type_guards import is_dict
|
|
12
12
|
|
|
13
13
|
if TYPE_CHECKING:
|
|
14
14
|
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
@@ -73,12 +73,15 @@ class ParameterStyleValidator(ProcessorProtocol):
|
|
|
73
73
|
config = context.config
|
|
74
74
|
param_info = context.parameter_info
|
|
75
75
|
|
|
76
|
-
#
|
|
76
|
+
# Check if parameters were normalized by looking for param_ placeholders
|
|
77
|
+
# This happens when Oracle numeric parameters (:1, :2) are normalized
|
|
78
|
+
is_normalized = param_info and any(p.name and p.name.startswith("param_") for p in param_info)
|
|
79
|
+
|
|
80
|
+
# First check parameter styles if configured (skip if normalized)
|
|
77
81
|
has_style_errors = False
|
|
78
|
-
if config.allowed_parameter_styles is not None and param_info:
|
|
82
|
+
if not is_normalized and config.allowed_parameter_styles is not None and param_info:
|
|
79
83
|
unique_styles = {p.style for p in param_info}
|
|
80
84
|
|
|
81
|
-
# Check for mixed styles first (before checking individual styles)
|
|
82
85
|
if len(unique_styles) > 1 and not config.allow_mixed_parameter_styles:
|
|
83
86
|
detected_style_strs = [str(s) for s in unique_styles]
|
|
84
87
|
detected_styles = ", ".join(sorted(detected_style_strs))
|
|
@@ -95,7 +98,6 @@ class ParameterStyleValidator(ProcessorProtocol):
|
|
|
95
98
|
context.validation_errors.append(error)
|
|
96
99
|
has_style_errors = True
|
|
97
100
|
|
|
98
|
-
# Check for disallowed styles
|
|
99
101
|
disallowed_styles = {str(s) for s in unique_styles if not config.validate_parameter_style(s)}
|
|
100
102
|
if disallowed_styles:
|
|
101
103
|
disallowed_str = ", ".join(sorted(disallowed_styles))
|
|
@@ -276,13 +278,84 @@ class ParameterStyleValidator(ProcessorProtocol):
|
|
|
276
278
|
) -> None:
|
|
277
279
|
"""Handle validation for named parameters."""
|
|
278
280
|
missing: list[str] = []
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
281
|
+
|
|
282
|
+
# Check if we have normalized parameters (e.g., param_0)
|
|
283
|
+
is_normalized = any(p.name and p.name.startswith("param_") for p in param_info)
|
|
284
|
+
|
|
285
|
+
if is_normalized and hasattr(context, "extra_info"):
|
|
286
|
+
# For normalized parameters, we need to check against the original placeholder mapping
|
|
287
|
+
placeholder_map = context.extra_info.get("placeholder_map", {})
|
|
288
|
+
|
|
289
|
+
# Check if we have Oracle numeric keys in merged_params
|
|
290
|
+
all_numeric_keys = all(key.isdigit() for key in merged_params)
|
|
291
|
+
|
|
292
|
+
if all_numeric_keys:
|
|
293
|
+
# Parameters were provided as list and converted to Oracle numeric dict {"1": val1, "2": val2}
|
|
294
|
+
for i, _p in enumerate(param_info):
|
|
295
|
+
normalized_name = f"param_{i}"
|
|
296
|
+
original_key = placeholder_map.get(normalized_name)
|
|
297
|
+
|
|
298
|
+
if original_key is not None:
|
|
299
|
+
# Check using the original key (e.g., "1", "2" for Oracle)
|
|
300
|
+
original_key_str = str(original_key)
|
|
301
|
+
if original_key_str not in merged_params or merged_params[original_key_str] is None:
|
|
302
|
+
if original_key_str.isdigit():
|
|
303
|
+
missing.append(f":{original_key}")
|
|
304
|
+
else:
|
|
305
|
+
missing.append(f":{original_key}")
|
|
306
|
+
else:
|
|
307
|
+
# Check if all params follow param_N pattern
|
|
308
|
+
all_param_keys = all(key.startswith("param_") and key[6:].isdigit() for key in merged_params)
|
|
309
|
+
|
|
310
|
+
if all_param_keys:
|
|
311
|
+
# This was originally a list converted to dict with param_N keys
|
|
312
|
+
for i, _p in enumerate(param_info):
|
|
313
|
+
normalized_name = f"param_{i}"
|
|
314
|
+
if normalized_name not in merged_params or merged_params[normalized_name] is None:
|
|
315
|
+
# Get original parameter style from placeholder map
|
|
316
|
+
original_key = placeholder_map.get(normalized_name)
|
|
317
|
+
if original_key is not None:
|
|
318
|
+
original_key_str = str(original_key)
|
|
319
|
+
if original_key_str.isdigit():
|
|
320
|
+
missing.append(f":{original_key}")
|
|
321
|
+
else:
|
|
322
|
+
missing.append(f":{original_key}")
|
|
323
|
+
else:
|
|
324
|
+
# Mixed parameter names, check using placeholder map
|
|
325
|
+
for i, _p in enumerate(param_info):
|
|
326
|
+
normalized_name = f"param_{i}"
|
|
327
|
+
original_key = placeholder_map.get(normalized_name)
|
|
328
|
+
|
|
329
|
+
if original_key is not None:
|
|
330
|
+
# For mixed params, check both normalized and original keys
|
|
331
|
+
original_key_str = str(original_key)
|
|
332
|
+
|
|
333
|
+
# First check with normalized name
|
|
334
|
+
found = normalized_name in merged_params and merged_params[normalized_name] is not None
|
|
335
|
+
|
|
336
|
+
# If not found, check with original key
|
|
337
|
+
if not found:
|
|
338
|
+
found = (
|
|
339
|
+
original_key_str in merged_params and merged_params[original_key_str] is not None
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
if not found:
|
|
343
|
+
# Format the missing parameter based on original style
|
|
344
|
+
if original_key_str.isdigit():
|
|
345
|
+
# It was an Oracle numeric parameter (e.g., :1)
|
|
346
|
+
missing.append(f":{original_key}")
|
|
347
|
+
else:
|
|
348
|
+
# It was a named parameter (e.g., :status)
|
|
349
|
+
missing.append(f":{original_key}")
|
|
350
|
+
else:
|
|
351
|
+
# Regular parameter validation
|
|
352
|
+
for p in param_info:
|
|
353
|
+
param_name = p.name
|
|
354
|
+
if param_name not in merged_params or merged_params.get(param_name) is None:
|
|
355
|
+
is_synthetic = any(key.startswith(("arg_", "param_")) for key in merged_params)
|
|
356
|
+
is_named_style = p.style.value not in {"qmark", "numeric"}
|
|
357
|
+
if (not is_synthetic or is_named_style) and param_name:
|
|
358
|
+
missing.append(param_name)
|
|
286
359
|
|
|
287
360
|
if missing:
|
|
288
361
|
msg = f"Missing required parameters: {', '.join(missing)}"
|
|
@@ -18,7 +18,9 @@ from sqlglot.optimizer import (
|
|
|
18
18
|
)
|
|
19
19
|
|
|
20
20
|
from sqlspec.exceptions import RiskLevel
|
|
21
|
-
from sqlspec.
|
|
21
|
+
from sqlspec.protocols import ProcessorProtocol
|
|
22
|
+
from sqlspec.statement.pipelines.context import ValidationError
|
|
23
|
+
from sqlspec.utils.type_guards import has_expressions
|
|
22
24
|
|
|
23
25
|
if TYPE_CHECKING:
|
|
24
26
|
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
@@ -126,7 +128,7 @@ class PerformanceAnalysis:
|
|
|
126
128
|
potential_improvement: float = 0.0
|
|
127
129
|
|
|
128
130
|
|
|
129
|
-
class PerformanceValidator(
|
|
131
|
+
class PerformanceValidator(ProcessorProtocol):
|
|
130
132
|
"""Comprehensive query performance validator.
|
|
131
133
|
|
|
132
134
|
Validates query performance by detecting:
|
|
@@ -143,9 +145,31 @@ class PerformanceValidator(BaseValidator):
|
|
|
143
145
|
Args:
|
|
144
146
|
config: Configuration for performance validation
|
|
145
147
|
"""
|
|
146
|
-
super().__init__()
|
|
147
148
|
self.config = config or PerformanceConfig()
|
|
148
149
|
|
|
150
|
+
def process(
    self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
) -> "Optional[exp.Expression]":
    """Run this validator as a pipeline step (implements ProcessorProtocol).

    Calls ``self.validate`` on a non-None expression, which records any
    findings on ``context``. The expression itself is passed through
    unchanged; a None input yields None.
    """
    if expression is not None:
        self.validate(expression, context)
    return expression
|
|
158
|
+
|
|
159
|
+
def add_error(
    self,
    context: "SQLProcessingContext",
    message: str,
    code: str,
    risk_level: RiskLevel,
    expression: "Optional[exp.Expression]" = None,
) -> None:
    """Append a ``ValidationError`` for this validator to ``context``.

    The error's ``processor`` field is set to this validator's class name.

    Args:
        context: Processing context whose ``validation_errors`` list receives the error.
        message: Description of the problem.
        code: Error code string stored on the ``ValidationError``.
        risk_level: Risk level attached to the error.
        expression: Expression the error refers to, if any.
    """
    processor_name = self.__class__.__name__
    finding = ValidationError(
        message=message,
        code=code,
        risk_level=risk_level,
        processor=processor_name,
        expression=expression,
    )
    context.validation_errors.append(finding)
|
|
172
|
+
|
|
149
173
|
def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
|
|
150
174
|
"""Validate SQL statement for performance issues.
|
|
151
175
|
|
|
@@ -167,7 +191,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
167
191
|
if self.config.enable_optimization_analysis:
|
|
168
192
|
self._analyze_optimization_opportunities(expression, analysis, context)
|
|
169
193
|
|
|
170
|
-
# Check for cartesian products
|
|
171
194
|
if self.config.warn_on_cartesian:
|
|
172
195
|
cartesian_issues = self._check_cartesian_products(analysis)
|
|
173
196
|
for issue in cartesian_issues:
|
|
@@ -179,7 +202,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
179
202
|
expression=expression,
|
|
180
203
|
)
|
|
181
204
|
|
|
182
|
-
# Check join complexity
|
|
183
205
|
if analysis.join_count > self.config.max_joins:
|
|
184
206
|
self.add_error(
|
|
185
207
|
context,
|
|
@@ -189,7 +211,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
189
211
|
expression=expression,
|
|
190
212
|
)
|
|
191
213
|
|
|
192
|
-
# Check subquery depth
|
|
193
214
|
if analysis.max_subquery_depth > self.config.max_subqueries:
|
|
194
215
|
self.add_error(
|
|
195
216
|
context,
|
|
@@ -213,7 +234,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
213
234
|
# Calculate overall complexity score
|
|
214
235
|
complexity_score = self._calculate_complexity(analysis)
|
|
215
236
|
|
|
216
|
-
# Build metadata
|
|
217
237
|
context.metadata[self.__class__.__name__] = {
|
|
218
238
|
"complexity_score": complexity_score,
|
|
219
239
|
"join_analysis": {
|
|
@@ -260,7 +280,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
260
280
|
analysis.current_subquery_depth = max(analysis.current_subquery_depth, depth + 1)
|
|
261
281
|
analysis.max_subquery_depth = max(analysis.max_subquery_depth, analysis.current_subquery_depth)
|
|
262
282
|
|
|
263
|
-
# Check if correlated
|
|
264
283
|
if self._is_correlated_subquery(expr):
|
|
265
284
|
analysis.correlated_subqueries += 1
|
|
266
285
|
|
|
@@ -270,7 +289,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
270
289
|
join_type = expr.args.get("kind", "INNER").upper()
|
|
271
290
|
analysis.join_types[join_type] = analysis.join_types.get(join_type, 0) + 1
|
|
272
291
|
|
|
273
|
-
# Extract join condition
|
|
274
292
|
condition = expr.args.get("on")
|
|
275
293
|
left_table = self._get_table_name(expr.parent) if expr.parent else "unknown"
|
|
276
294
|
right_table = self._get_table_name(expr.this)
|
|
@@ -287,10 +305,10 @@ class PerformanceValidator(BaseValidator):
|
|
|
287
305
|
analysis.where_conditions += len(list(expr.find_all(exp.Predicate)))
|
|
288
306
|
|
|
289
307
|
elif isinstance(expr, exp.Group):
|
|
290
|
-
analysis.group_by_columns += len(expr.expressions) if
|
|
308
|
+
analysis.group_by_columns += len(expr.expressions) if has_expressions(expr) else 0
|
|
291
309
|
|
|
292
310
|
elif isinstance(expr, exp.Order):
|
|
293
|
-
analysis.order_by_columns += len(expr.expressions) if
|
|
311
|
+
analysis.order_by_columns += len(expr.expressions) if has_expressions(expr) else 0
|
|
294
312
|
|
|
295
313
|
elif isinstance(expr, exp.Distinct):
|
|
296
314
|
analysis.distinct_operations += 1
|
|
@@ -302,13 +320,15 @@ class PerformanceValidator(BaseValidator):
|
|
|
302
320
|
analysis.select_star_count += 1
|
|
303
321
|
|
|
304
322
|
# Recursive traversal
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
323
|
+
expr_args = getattr(expr, "args", None)
|
|
324
|
+
if expr_args is not None and isinstance(expr_args, dict):
|
|
325
|
+
for child in expr_args.values():
|
|
326
|
+
if isinstance(child, exp.Expression):
|
|
327
|
+
self._analyze_expression(child, analysis, depth)
|
|
328
|
+
elif isinstance(child, list):
|
|
329
|
+
for item in child:
|
|
330
|
+
if isinstance(item, exp.Expression):
|
|
331
|
+
self._analyze_expression(item, analysis, depth)
|
|
312
332
|
|
|
313
333
|
def _check_cartesian_products(self, analysis: PerformanceAnalysis) -> "list[PerformanceIssue]":
|
|
314
334
|
"""Detect potential cartesian products from join analysis.
|
|
@@ -335,11 +355,9 @@ class PerformanceValidator(BaseValidator):
|
|
|
335
355
|
)
|
|
336
356
|
)
|
|
337
357
|
else:
|
|
338
|
-
# Build join graph
|
|
339
358
|
join_graph[condition.left_table].add(condition.right_table)
|
|
340
359
|
join_graph[condition.right_table].add(condition.left_table)
|
|
341
360
|
|
|
342
|
-
# Check for disconnected tables (implicit cartesian)
|
|
343
361
|
if len(analysis.tables) > 1:
|
|
344
362
|
connected = self._find_connected_components(join_graph, analysis.tables)
|
|
345
363
|
if len(connected) > 1:
|
|
@@ -595,7 +613,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
595
613
|
|
|
596
614
|
for opt_type, optimizer, description in optimizations:
|
|
597
615
|
try:
|
|
598
|
-
# Apply the optimization
|
|
599
616
|
optimized = optimizer(expression.copy(), dialect=context.dialect) # type: ignore[operator]
|
|
600
617
|
|
|
601
618
|
if optimized is None:
|
|
@@ -623,7 +640,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
623
640
|
else:
|
|
624
641
|
improvement = 0.0
|
|
625
642
|
|
|
626
|
-
# Only add if improvement meets threshold
|
|
627
643
|
if improvement >= self.config.optimization_threshold:
|
|
628
644
|
opportunities.append(
|
|
629
645
|
OptimizationOpportunity(
|
|
@@ -636,7 +652,6 @@ class PerformanceValidator(BaseValidator):
|
|
|
636
652
|
)
|
|
637
653
|
)
|
|
638
654
|
|
|
639
|
-
# Update the best optimization if this is better
|
|
640
655
|
if improvement > cumulative_improvement:
|
|
641
656
|
best_optimized = optimized
|
|
642
657
|
cumulative_improvement = improvement
|