sqlchecker 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlchecker/__init__.py +58 -0
- sqlchecker/detectors/__init__.py +103 -0
- sqlchecker/detectors/base.py +44 -0
- sqlchecker/detectors/complications.py +375 -0
- sqlchecker/detectors/logical.py +732 -0
- sqlchecker/detectors/semantic.py +289 -0
- sqlchecker/detectors/syntax.py +1140 -0
- sqlchecker-0.3.1.dist-info/METADATA +153 -0
- sqlchecker-0.3.1.dist-info/RECORD +11 -0
- sqlchecker-0.3.1.dist-info/WHEEL +4 -0
- sqlchecker-0.3.1.dist-info/licenses/LICENSE +21 -0
sqlchecker/__init__.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
'''Detect and categorize SQL errors in queries.'''
|
|
2
|
+
|
|
3
|
+
# Hidden, internal use only
|
|
4
|
+
from .detectors import BaseDetector as _BaseDetector, Detector as _Detector
|
|
5
|
+
|
|
6
|
+
# Public API
|
|
7
|
+
from sqlerrors import SqlErrors
|
|
8
|
+
from sqlscope import Catalog, build_catalog, load_catalog, build_catalog_from_postgres, build_catalog_from_sql
|
|
9
|
+
from .detectors import SyntaxErrorDetector, SemanticErrorDetector, LogicalErrorDetector, ComplicationDetector, DetectedError
|
|
10
|
+
|
|
11
|
+
def get_errors(query_str: str,
               solutions: list[str] | None = None,
               catalog: Catalog | None = None,
               search_path: str = 'public',
               solution_search_path: str = 'public',
               detectors: list[type[_BaseDetector]] | None = None,
               debug: bool = False) -> list[DetectedError]:
    '''Detect SQL errors in the given query string.

    Args:
        query_str: SQL text to analyse.
        solutions: SQL text of the reference solutions. Defaults to no solutions.
        catalog: Catalog handed to the detector. A new empty `Catalog()` is
            created for each call when omitted.
        search_path: Search path used for `query_str`.
        solution_search_path: Search path used for the solutions.
        detectors: Detector classes to run, in order. Defaults to the syntax,
            semantic, logical and complication detectors.
        debug: Forwarded to the underlying detector, which prints diagnostics.

    Returns:
        The errors reported by all detectors, in detector order.
    '''
    # Defaults are resolved per call: default values in a signature are
    # evaluated once, so a list or `Catalog()` there would be shared by
    # every invocation.
    if solutions is None:
        solutions = []
    if catalog is None:
        catalog = Catalog()
    if detectors is None:
        detectors = [
            SyntaxErrorDetector,
            SemanticErrorDetector,
            LogicalErrorDetector,
            ComplicationDetector,
        ]

    det = _Detector(query_str,
                    solutions=solutions,
                    catalog=catalog,
                    search_path=search_path,
                    solution_search_path=solution_search_path,
                    debug=debug)

    for detector in detectors:
        det.add_detector(detector)

    return det.run()
|
|
35
|
+
|
|
36
|
+
def get_error_types(query_str: str,
                    solutions: list[str] | None = None,
                    catalog: Catalog | None = None,
                    search_path: str = 'public',
                    solution_search_path: str = 'public',
                    detectors: list[type[_BaseDetector]] | None = None,
                    debug: bool = False) -> set[SqlErrors]:
    '''Detect SQL error types in the given query string.

    Takes the same arguments as `get_errors` and forwards them to it, then
    reduces the detected errors to the set of their `error` values.

    Returns:
        The distinct error types found in the query.
    '''
    # Resolve the optional arguments here (rather than forwarding `None`) so
    # that no mutable default object is shared between calls.
    if solutions is None:
        solutions = []
    if catalog is None:
        catalog = Catalog()
    if detectors is None:
        detectors = [
            SyntaxErrorDetector,
            SemanticErrorDetector,
            LogicalErrorDetector,
            ComplicationDetector,
        ]

    detected_errors = get_errors(query_str,
                                 solutions=solutions,
                                 catalog=catalog,
                                 search_path=search_path,
                                 solution_search_path=solution_search_path,
                                 detectors=detectors,
                                 debug=debug)

    return {error.error for error in detected_errors}
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
'''SQL error detectors.'''
|
|
2
|
+
|
|
3
|
+
from sqlscope import Query, Catalog
|
|
4
|
+
from .base import BaseDetector, DetectedError
|
|
5
|
+
|
|
6
|
+
# exported detectors
|
|
7
|
+
from .syntax import SyntaxErrorDetector
|
|
8
|
+
from .semantic import SemanticErrorDetector
|
|
9
|
+
from .logical import LogicalErrorDetector
|
|
10
|
+
from .complications import ComplicationDetector
|
|
11
|
+
|
|
12
|
+
class Detector:
    '''Manages and runs SQL error detectors on a query.

    Holds the shared context (catalog, search paths, solutions wrapped in
    `Query` objects), the current query, and the detector instances that
    are run against it.
    '''

    def __init__(self,
                 query: str,
                 *,
                 search_path: str = 'public',
                 solution_search_path: str = 'public',
                 solutions: list[str] | None = None,
                 catalog: Catalog | None = None,
                 detectors: list[type[BaseDetector]] | None = None,
                 debug: bool = False):
        '''
        Args:
            query: SQL text to analyse.
            search_path: Search path used when building the query.
            solution_search_path: Search path used when building the solutions.
            solutions: SQL text of the reference solutions (default: none).
            catalog: Catalog to use; a new empty `Catalog()` per instance
                when omitted.
            detectors: Detector classes to register immediately (default: none).
            debug: When true, print diagnostics to stdout.
        '''
        # Optional arguments are resolved here instead of in the signature so
        # that instances never share a default list or a default catalog.
        if solutions is None:
            solutions = []
        if detectors is None:
            detectors = []

        # Context data: they don't need to be parsed again if the query changes
        self.search_path = search_path
        self.solution_search_path = solution_search_path
        self.catalog = Catalog() if catalog is None else catalog
        self.solutions = [
            Query(sol, catalog=self.catalog, search_path=self.solution_search_path)
            for sol in solutions
        ]
        self.detectors: list[BaseDetector] = []
        self.debug = debug

        self.set_query(query)

        # NOTE: Add detectors after setting the query to ensure they are correctly initialized
        for detector_cls in detectors:
            self.add_detector(detector_cls)

    def set_query(self, query: str, reason: str | None = None) -> None:
        '''Set a new query, re-parse it, and update all detectors. Doesn't affect detected errors.

        Args:
            query: New SQL text.
            reason: Optional explanation, only used in the debug output.
        '''

        if self.debug:
            print('=' * 20)
            if reason:
                print(f'Updating query ({reason}):\n{query}')
            else:
                print(f'Updating query:\n{query}')
            print('=' * 20)

        self.query = Query(query, catalog=self.catalog, search_path=self.search_path)

        # Update all detectors with the new query and parse results
        for detector in self.detectors:
            detector.query = self.query
            detector.update_query = lambda new_query, reason=None: self.set_query(new_query, reason)

    def add_detector(self, detector_cls: type[BaseDetector]) -> None:
        '''Instantiate `detector_cls` for the current query and register it.'''

        # NOTE: the detector receives the very same `self.solutions` list and
        # `self.query` object as every other detector; no copies are made.
        # TODO: check whether copies are needed to guard against
        # modifications during detection.
        detector = detector_cls(
            query=self.query,
            solutions=self.solutions,
            update_query=lambda new_query, reason=None: self.set_query(new_query, reason),
        )

        self.detectors.append(detector)

    def run(self) -> list[DetectedError]:
        '''
        Run all detectors and return a list of detected errors.
        This function can return duplicate errors, as well as additional information on the detected errors.
        '''

        if self.debug:
            print('===== Query =====')
            print(self.query.sql)

            print('===== search_path =====')
            print(self.search_path)

            print('===== solution_search_path =====')
            print(self.solution_search_path)

            print('===== Solutions =====')
            print('\n-----\n'.join(sol.sql for sol in self.solutions))

            print('===== Catalog =====')
            print(self.catalog)

        results: list[DetectedError] = []

        # Detectors run in registration order; their results are concatenated.
        for detector in self.detectors:
            errors = detector.run()

            if self.debug:
                print(f'===== Detected errors from {detector.__class__.__name__} =====')
                for error in errors:
                    print(error)

            results.extend(errors)

        return results
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
'''Base classes for SQL error detectors.'''
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
|
|
7
|
+
from sqlerrors import SqlErrors
|
|
8
|
+
from sqlscope.query import Query
|
|
9
|
+
|
|
10
|
+
@dataclass(repr=False)
class DetectedError:
    '''Represents a detected SQL error with its type and associated data.'''

    # Kind of error that was detected.
    error: SqlErrors
    # Extra, detector-specific details; empty when there is nothing to report.
    data: tuple[Any, ...] = field(default_factory=tuple)

    def __repr__(self) -> str:
        return f"DetectedError({self.error.value} - {self.error.name}: {self.data})"

    def __str__(self) -> str:
        '''Return `[code] NAME`, followed by the data tuple when it is not empty.'''
        if self.data:
            return f'[{self.error.value:3}] {self.error.name}: {self.data}'
        return f'[{self.error.value:3}] {self.error.name}'

    def __hash__(self) -> int:
        '''Hash on `(error, data)`, or on `error` alone if `data` is unhashable.'''
        try:
            return hash((self.error, self.data))
        except TypeError:
            # `data` may hold unhashable items (e.g. a set or a list). Hashing
            # the error alone keeps the instance usable in sets/dicts and is
            # still consistent with the generated `__eq__`: equal instances
            # always have equal `error`.
            return hash(self.error)
|
|
27
|
+
|
|
28
|
+
class BaseDetector(ABC):
|
|
29
|
+
'''Abstract base class for SQL error detectors.'''
|
|
30
|
+
|
|
31
|
+
def __init__(self, *,
|
|
32
|
+
query: Query,
|
|
33
|
+
solutions: list[Query] = [],
|
|
34
|
+
update_query: Callable[[str, str | None], None],
|
|
35
|
+
):
|
|
36
|
+
self.query = query
|
|
37
|
+
self.solutions = solutions
|
|
38
|
+
self.update_query = update_query
|
|
39
|
+
|
|
40
|
+
@abstractmethod
|
|
41
|
+
def run(self) -> list[DetectedError]:
|
|
42
|
+
'''Run the detector and return a list of detected errors with their descriptions'''
|
|
43
|
+
return []
|
|
44
|
+
|
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
'''Detector for complications in SQL queries.'''
|
|
2
|
+
|
|
3
|
+
from typing import Callable
|
|
4
|
+
from sqlglot import exp
|
|
5
|
+
from sqlerrors import SqlErrors
|
|
6
|
+
from sqlscope.catalog import ConstraintType, ConstraintColumn
|
|
7
|
+
from sqlscope import Query
|
|
8
|
+
from sqlscope import util
|
|
9
|
+
|
|
10
|
+
from .base import BaseDetector, DetectedError
|
|
11
|
+
|
|
12
|
+
class ComplicationDetector(BaseDetector):
    '''Detector for complications in SQL queries.

    Each `detect_<code>_<name>` method implements one check and returns the
    errors it found. Several checks are placeholders that always return an
    empty list.
    '''

    def __init__(self,
                 *,
                 query: Query,
                 update_query: Callable[[str, str | None], None],
                 solutions: list[Query] | None = None,
        ):
        '''
        Args:
            query: Query to analyse.
            update_query: Callback taking the new query text and an optional reason.
            solutions: Reference solutions; a new empty list when omitted.
        '''
        super().__init__(
            query=query,
            # Resolved here so that no default list is shared between instances.
            solutions=[] if solutions is None else solutions,
            update_query=update_query,
        )

    def run(self) -> list[DetectedError]:
        '''
        Executes all complication checks and returns a list of the detected errors.
        '''

        results: list[DetectedError] = super().run()

        # Checks run in this order; their results are concatenated.
        checks = [
            self.detect_82_unnecessary_complication,                             # umbrella term: always empty
            self.detect_83_unnecessary_distinct_in_select_clause,                # implemented
            self.detect_84_unnecessary_table_reference,                          # TODO: refactor/implement
            self.detect_85_unused_correlation_name,                              # TODO: implement
            self.detect_86_tables_have_same_data,                                # TODO: implement
            self.detect_125_correlation_name_identical_to_table_name,            # TODO: implement
            self.detect_87_unnecessary_general_comparison_operator,              # TODO: implement
            self.detect_88_like_without_wildcards,                               # implemented
            self.detect_89_unnecessarily_complicated_select_in_exists_subquery,  # TODO: implement
            self.detect_90_in_exists_can_be_replaced_by_comparison,              # TODO: implement
            self.detect_91_unnecessary_aggregate_function,                       # TODO: implement
            self.detect_92_unnecessary_distinct_in_aggregate_function,           # implemented
            self.detect_93_unnecessary_argument_of_count,                        # TODO: implement (currently a stub)
            self.detect_94_unnecessary_group_by_in_exists_subquery,              # TODO: implement
            self.detect_95_group_by_with_singleton_groups,                       # implemented
            self.detect_96_group_by_with_only_a_single_group,                    # TODO: implement
            self.detect_97_group_by_can_be_replaced_by_distinct,                 # implemented
            self.detect_98_union_can_be_replaced_by_or,                          # TODO: implement
            self.detect_99_unnecessary_column_in_order_by_clause,                # TODO: refactor/implement
            self.detect_100_order_by_in_subquery,                                # implemented
            self.detect_101_inefficient_having,                                  # TODO: implement
            self.detect_102_inefficient_union,                                   # TODO: implement
            self.detect_103_condition_in_the_subquery_can_be_moved_up,           # TODO: implement
            self.detect_104_outer_join_can_be_replaced_by_inner_join,            # TODO: implement
            self.detect_126_unused_cte,                                          # implemented
        ]

        for chk in checks:
            results.extend(chk())

        return results

    def detect_82_unnecessary_complication(self) -> list[DetectedError]:
        '''NOTE: this is an umbrella term, so it can't be directly detected.'''
        return []

    def detect_83_unnecessary_distinct_in_select_clause(self) -> list[DetectedError]:
        '''
        Flags a SELECT DISTINCT clause that is unnecessary because the selected
        columns are already unique due to existing constraints.
        '''
        result: list[DetectedError] = []

        for select in self.query.selects:
            if not select.distinct:
                continue

            # Ignore the constraint contributed by DISTINCT itself: any other
            # unique constraint on the output makes DISTINCT redundant.
            constraints = [c for c in select.output.unique_constraints if c.constraint_type != ConstraintType.DISTINCT]

            if constraints:
                result.append(DetectedError(SqlErrors.UNNECESSARY_DISTINCT_IN_SELECT_CLAUSE, (select.sql,)))

        return result

    def detect_84_unnecessary_table_reference(self) -> list[DetectedError]:
        '''
        Intended to flag a query that joins to a table not present in the
        correct solution. Not implemented: always returns no errors.

        TODO: refactor/implement. An earlier draft was removed because it was
        unreachable and relied on helpers (`q_ast`, `s_ast`,
        `_get_from_tables`) that this class does not define.
        '''
        return []

    def detect_85_unused_correlation_name(self) -> list[DetectedError]:
        '''Placeholder for check 85: not implemented, always returns no errors.'''
        return []

    def detect_86_tables_have_same_data(self) -> list[DetectedError]:
        '''Placeholder for check 86: not implemented, always returns no errors.'''
        return []

    def detect_125_correlation_name_identical_to_table_name(self) -> list[DetectedError]:
        '''Placeholder for check 125: not implemented, always returns no errors.'''
        return []

    def detect_87_unnecessary_general_comparison_operator(self) -> list[DetectedError]:
        '''Placeholder for check 87: not implemented, always returns no errors.'''
        return []

    def detect_88_like_without_wildcards(self) -> list[DetectedError]:
        '''
        Flags queries where the LIKE operator is used without wildcards ('%' or '_').
        This indicates a potential misunderstanding, where the '=' operator should
        have been used instead.
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            ast = select.ast

            if not ast:
                continue

            for like in ast.find_all(exp.Like):
                pattern_expr = like.args.get('expression')

                if not pattern_expr:
                    # Malformed LIKE expression
                    continue

                if not isinstance(pattern_expr, exp.Literal):
                    # Some other expression type, e.g., a column reference
                    continue

                pattern_value = pattern_expr.this
                if '%' not in pattern_value and '_' not in pattern_value:
                    results.append(DetectedError(SqlErrors.LIKE_WITHOUT_WILDCARDS, (str(like),)))

        return results

    def detect_89_unnecessarily_complicated_select_in_exists_subquery(self) -> list[DetectedError]:
        '''Placeholder for check 89: not implemented, always returns no errors.'''
        return []

    def detect_90_in_exists_can_be_replaced_by_comparison(self) -> list[DetectedError]:
        '''Placeholder for check 90: not implemented, always returns no errors.'''
        return []

    def detect_91_unnecessary_aggregate_function(self) -> list[DetectedError]:
        '''Placeholder for check 91: not implemented, always returns no errors.'''
        return []

    def detect_92_unnecessary_distinct_in_aggregate_function(self) -> list[DetectedError]:
        '''MIN and MAX never require DISTINCT. For other aggregate functions, DISTINCT is unnecessary if the argument is unique.

        An argument counts as unique here when it is a literal, or a column
        that alone forms a UNIQUE constraint of the select.
        '''

        results: list[DetectedError] = []

        for select in self.query.selects:
            select = select.strip_subqueries()

            if not select.ast:
                continue

            for agg_func in select.ast.find_all(exp.AggFunc):
                if not isinstance(agg_func.this, exp.Distinct):
                    continue

                if isinstance(agg_func, (exp.Min, exp.Max)):
                    results.append(DetectedError(SqlErrors.UNNECESSARY_DISTINCT_IN_AGGREGATE_FUNCTION, (str(agg_func),)))
                    continue

                arg_expr = agg_func.this.expressions    # `.this` is the DISTINCT, `.expressions` are the arguments
                if not arg_expr:
                    continue

                for expr in arg_expr:
                    # Check if the argument is a constant literal
                    if isinstance(expr, exp.Literal):
                        results.append(DetectedError(SqlErrors.UNNECESSARY_DISTINCT_IN_AGGREGATE_FUNCTION, (str(agg_func),)))
                        continue

                    # Check if the argument is a column
                    if isinstance(expr, exp.Column):
                        column_name = util.ast.column.get_real_name(expr)

                        # Check if the column has a UNIQUE constraint
                        unique_constraints = [c for c in select.all_constraints if c.constraint_type == ConstraintType.UNIQUE]
                        for constraint in unique_constraints:
                            if { ConstraintColumn(column_name, table_idx=select._get_table_idx_for_column(expr)) } == constraint.columns:
                                results.append(DetectedError(SqlErrors.UNNECESSARY_DISTINCT_IN_AGGREGATE_FUNCTION, (str(agg_func),)))
                                # One matching constraint is enough for this column.
                                break
        return results

    def detect_93_unnecessary_argument_of_count(self) -> list[DetectedError]:
        '''Placeholder for check 93: not implemented, always returns no errors.'''
        return []

    def detect_94_unnecessary_group_by_in_exists_subquery(self) -> list[DetectedError]:
        '''Placeholder for check 94: not implemented, always returns no errors.'''
        return []

    def detect_95_group_by_with_singleton_groups(self) -> list[DetectedError]:
        '''
        Flags GROUP BY clauses on singleton groups due to the presence
        of UNIQUE constraints on the grouped columns.
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            if not select.group_by:
                continue

            group_by_constraint = next((c for c in select.all_constraints if c.constraint_type == ConstraintType.GROUP_BY), None)
            if not group_by_constraint:
                # No GROUP BY constraint found, meaning the GROUP BY clause is invalid. Skip.
                continue

            constraints = [c for c in select.all_constraints if c.constraint_type == ConstraintType.UNIQUE]

            for constraint in constraints:
                # Grouping by a superset of a unique key yields one row per group.
                if constraint.columns.issubset(group_by_constraint.columns):
                    results.append(DetectedError(SqlErrors.GROUP_BY_WITH_SINGLETON_GROUPS, (group_by_constraint, constraint)))
                    # Report each select at most once.
                    break

        return results

    def detect_96_group_by_with_only_a_single_group(self) -> list[DetectedError]:
        '''Placeholder for check 96: not implemented, always returns no errors.'''
        return []

    def detect_97_group_by_can_be_replaced_by_distinct(self) -> list[DetectedError]:
        '''
        Flags GROUP BY clauses that can be replaced by SELECT DISTINCT.
        This occurs when all selected columns are included in the GROUP BY clause
        and there are no aggregate functions in the SELECT list.
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            select = select.strip_subqueries()

            if not select.group_by:
                continue

            if not select.ast:
                continue

            # Any aggregate in the SELECT list means GROUP BY is doing real work.
            has_agg_func = any(
                list(expression.find_all(exp.AggFunc))
                for expression in select.ast.expressions
            )
            if has_agg_func:
                continue

            select_columns: list[exp.Column] = []
            for expression in select.ast.expressions:
                select_columns.extend(expression.find_all(exp.Column))

            group_by_columns: list[exp.Column] = []
            for expression in select.group_by:
                group_by_columns.extend(expression.find_all(exp.Column))

            # Columns are compared as (real name, table index) pairs.
            select_col_names = {(util.ast.column.get_real_name(col), select._get_table_idx_for_column(col)) for col in select_columns}
            group_by_col_names = {(util.ast.column.get_real_name(col), select._get_table_idx_for_column(col)) for col in group_by_columns}

            if select_col_names == group_by_col_names:
                results.append(DetectedError(SqlErrors.GROUP_BY_CAN_BE_REPLACED_WITH_DISTINCT, (select_col_names,)))

        return results

    def detect_98_union_can_be_replaced_by_or(self) -> list[DetectedError]:
        '''Placeholder for check 98: not implemented, always returns no errors.'''
        return []

    def detect_99_unnecessary_column_in_order_by_clause(self) -> list[DetectedError]:
        '''
        Intended to flag when the ORDER BY clause contains unnecessary columns
        in addition to the required ones. Not implemented: always returns no
        errors.

        TODO: refactor/implement. An earlier draft was removed because it was
        unreachable and relied on helpers (`q_ast`, `s_ast`,
        `_get_orderby_columns`) that this class does not define.
        '''
        return []

    def detect_100_order_by_in_subquery(self) -> list[DetectedError]:
        '''
        Flags when a subquery contains an ORDER BY clause.
        Subqueries with both ORDER BY and LIMIT are considered valid.
        '''

        results: list[DetectedError] = []

        # nested subqueries are checked multiple times, so track which have been checked
        checked_subqueries: set[str] = set()

        for select in self.query.selects:
            for subquery, clause, depth in select.subqueries:
                if subquery.sql in checked_subqueries:
                    continue

                checked_subqueries.add(subquery.sql)
                if subquery.order_by and not subquery.limit:
                    results.append(DetectedError(SqlErrors.ORDER_BY_IN_SUBQUERY, (subquery.sql,)))

        return results

    def detect_101_inefficient_having(self) -> list[DetectedError]:
        '''Placeholder for check 101: not implemented, always returns no errors.'''
        return []

    def detect_102_inefficient_union(self) -> list[DetectedError]:
        '''Placeholder for check 102: not implemented, always returns no errors.'''
        return []

    def detect_103_condition_in_the_subquery_can_be_moved_up(self) -> list[DetectedError]:
        '''Placeholder for check 103: not implemented, always returns no errors.'''
        return []

    def detect_104_outer_join_can_be_replaced_by_inner_join(self) -> list[DetectedError]:
        '''Placeholder for check 104: not implemented, always returns no errors.'''
        return []

    def detect_126_unused_cte(self) -> list[DetectedError]:
        '''Flags CTEs of the query that no select references as a table.'''
        results: list[DetectedError] = []

        if not self.query.ctes:
            return results

        # Indices of the CTEs referenced by at least one select.
        used_cte_indices: set[int] = set()
        for select in self.query.selects:
            for table in select.referenced_tables:
                if table.cte_idx is not None:
                    used_cte_indices.add(table.cte_idx)

        # Report the unused CTEs in declaration order.
        for cte_idx in range(len(self.query.ctes)):
            if cte_idx not in used_cte_indices:
                results.append(DetectedError(SqlErrors.UNUSED_CTE, (self.query.ctes[cte_idx].sql,)))

        return results
|