sql-error-categorizer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_error_categorizer/__init__.py +56 -0
- sql_error_categorizer/catalog/__init__.py +73 -0
- sql_error_categorizer/catalog/catalog.py +328 -0
- sql_error_categorizer/catalog/queries.py +60 -0
- sql_error_categorizer/detectors/__init__.py +88 -0
- sql_error_categorizer/detectors/base.py +39 -0
- sql_error_categorizer/detectors/complications.py +393 -0
- sql_error_categorizer/detectors/logical.py +708 -0
- sql_error_categorizer/detectors/semantic.py +493 -0
- sql_error_categorizer/detectors/syntax.py +1278 -0
- sql_error_categorizer/query/__init__.py +4 -0
- sql_error_categorizer/query/extractors.py +134 -0
- sql_error_categorizer/query/query.py +98 -0
- sql_error_categorizer/query/set_operations/__init__.py +150 -0
- sql_error_categorizer/query/set_operations/binary_set_operation.py +89 -0
- sql_error_categorizer/query/set_operations/select.py +361 -0
- sql_error_categorizer/query/set_operations/set_operation.py +45 -0
- sql_error_categorizer/query/smt.py +206 -0
- sql_error_categorizer/query/tokenized_sql.py +68 -0
- sql_error_categorizer/query/typechecking.py +242 -0
- sql_error_categorizer/query/util.py +27 -0
- sql_error_categorizer/sql_errors.py +112 -0
- sql_error_categorizer/util.py +101 -0
- sql_error_categorizer-0.1.0.dist-info/METADATA +149 -0
- sql_error_categorizer-0.1.0.dist-info/RECORD +27 -0
- sql_error_categorizer-0.1.0.dist-info/WHEEL +4 -0
- sql_error_categorizer-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import re
|
|
3
|
+
import sqlparse
|
|
4
|
+
import sqlparse.keywords
|
|
5
|
+
from typing import Callable
|
|
6
|
+
from sqlglot import exp
|
|
7
|
+
|
|
8
|
+
from .base import BaseDetector, DetectedError
|
|
9
|
+
from ..query import Query
|
|
10
|
+
from ..sql_errors import SqlErrors
|
|
11
|
+
from ..catalog import Catalog
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ComplicationDetector(BaseDetector):
|
|
15
|
+
def __init__(self,
|
|
16
|
+
*,
|
|
17
|
+
query: Query,
|
|
18
|
+
update_query: Callable[[str, str | None], None],
|
|
19
|
+
solutions: list[Query] = [],
|
|
20
|
+
):
|
|
21
|
+
super().__init__(
|
|
22
|
+
query=query,
|
|
23
|
+
solutions=solutions,
|
|
24
|
+
update_query=update_query,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
def run(self) -> list[DetectedError]:
    '''
    Runs every complication check and gathers what they report.

    The list returned by the base class `run()` is used as the starting point;
    the findings of each `com_*` check are appended to it in the order listed
    below, and the combined list is returned.
    '''

    detected: list[DetectedError] = super().run()

    for check in (
        self.com_83_unnecessary_distinct_in_select_clause,
        self.com_84_unnecessary_join,
        self.com_85_unused_correlation_name,
        self.com_86_correlation_names_are_always_identical,
        self.com_87_unnecessary_general_comparison_operator,
        self.com_88_like_without_wildcards,
        self.com_89_unnecessarily_complicated_select_in_exists_subquery,
        self.com_90_in_exists_can_be_replaced_by_comparison,
        self.com_91_unnecessary_aggregate_function,
        self.com_92_unnecessary_distinct_in_aggregate_function,
        self.com_93_unnecessary_argument_of_count,
        self.com_94_unnecessary_group_by_in_exists_subquery,
        self.com_95_group_by_with_singleton_groups,
        self.com_96_group_by_with_only_a_single_group,
        self.com_97_group_by_can_be_replaced_by_distinct,
        self.com_98_union_can_be_replaced_by_or,
        self.com_99_complication_unnecessary_column_in_order_by_clause,
        self.com_100_order_by_in_subquery,
        self.com_101_inefficient_having,
        self.com_102_inefficient_union,
        self.com_103_condition_in_the_subquery_can_be_moved_up,
        self.com_104_condition_on_left_table_in_left_outer_join,
        self.com_105_outer_join_can_be_replaced_by_inner_join,
    ):
        detected.extend(check())

    return detected
|
|
64
|
+
|
|
65
|
+
# TODO: refactor
|
|
66
|
+
def com_83_unnecessary_distinct_in_select_clause(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-83 "unnecessary DISTINCT in SELECT clause" check.

    The check is disabled and always returns an empty list.

    The earlier implementation was placed after an unconditional `return []`
    and therefore never ran, so it has been removed. It read the top-level
    `distinct` argument from the dictionaries `self.q_ast` and `self.s_ast`
    and flagged the query when that argument was set (neither None nor False)
    for the query but not for the solution. It appended plain
    `(SqlErrors.COM_83_UNNECESSARY_DISTINCT_IN_SELECT_CLAUSE, message)` tuples
    rather than `DetectedError` instances, so it must be ported, not restored
    verbatim.
    '''
    return []
|
|
94
|
+
|
|
95
|
+
# TODO: refactor
|
|
96
|
+
def com_84_unnecessary_join(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-84 "unnecessary JOIN" check.

    The check is disabled and always returns an empty list.

    The earlier implementation was placed after an unconditional `return []`
    and therefore never ran, so it has been removed. It collected the table
    names of `self.q_ast` and `self.s_ast` with `_get_from_tables`, compared
    them case-insensitively, and reported every table that appeared only in
    the query. It appended plain `(SqlErrors.COM_84_UNNECESSARY_JOIN, message)`
    tuples rather than `DetectedError` instances, and matched the reported
    name back with `startswith`, which can pick a different table sharing the
    same prefix; both points need attention when the check is ported.
    '''
    return []
|
|
125
|
+
|
|
126
|
+
# TODO: implement
|
|
127
|
+
def com_85_unused_correlation_name(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-85 "unused correlation name" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
129
|
+
|
|
130
|
+
# TODO: implement
|
|
131
|
+
def com_86_correlation_names_are_always_identical(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-86 "correlation names are always identical" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
133
|
+
|
|
134
|
+
# TODO: implement
|
|
135
|
+
def com_87_unnecessary_general_comparison_operator(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-87 "unnecessary general comparison operator" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
137
|
+
|
|
138
|
+
def com_88_like_without_wildcards(self) -> list[DetectedError]:
    '''
    Reports LIKE predicates whose literal pattern contains neither '%' nor '_'.

    Such a predicate uses no wildcard, which suggests that '=' was meant.
    Every select of `self.query.selects` that has an AST is searched for
    `exp.Like` nodes; one error carrying the rendered LIKE expression is
    produced per offending node.
    '''
    findings: list[DetectedError] = []

    for select in self.query.selects:
        tree = select.ast
        if not tree:
            continue

        for like_node in tree.find_all(exp.Like):
            pattern = like_node.args.get('expression')

            # Only literal patterns can be inspected: skip a missing pattern
            # (malformed LIKE) and non-literal ones such as column references.
            if not pattern or not isinstance(pattern, exp.Literal):
                continue

            text = pattern.this
            if any(wildcard in text for wildcard in ('%', '_')):
                continue

            findings.append(
                DetectedError(SqlErrors.COM_88_LIKE_WITHOUT_WILDCARDS, (str(like_node),))
            )

    return findings
|
|
170
|
+
|
|
171
|
+
# TODO: implement
|
|
172
|
+
def com_89_unnecessarily_complicated_select_in_exists_subquery(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-89 "unnecessarily complicated SELECT in EXISTS subquery" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
174
|
+
|
|
175
|
+
# TODO: implement
|
|
176
|
+
def com_90_in_exists_can_be_replaced_by_comparison(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-90 "IN/EXISTS can be replaced by comparison" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
178
|
+
|
|
179
|
+
# TODO: implement
|
|
180
|
+
def com_91_unnecessary_aggregate_function(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-91 "unnecessary aggregate function" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
182
|
+
|
|
183
|
+
# TODO: implement
|
|
184
|
+
def com_92_unnecessary_distinct_in_aggregate_function(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-92 "unnecessary DISTINCT in aggregate function" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
186
|
+
|
|
187
|
+
# TODO: implement
|
|
188
|
+
def com_93_unnecessary_argument_of_count(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-93 "unnecessary argument of COUNT" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
190
|
+
|
|
191
|
+
# TODO: implement
|
|
192
|
+
def com_94_unnecessary_group_by_in_exists_subquery(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-94 "unnecessary GROUP BY in EXISTS subquery" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
194
|
+
|
|
195
|
+
# TODO: implement
|
|
196
|
+
def com_95_group_by_with_singleton_groups(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-95 "GROUP BY with singleton groups" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
198
|
+
|
|
199
|
+
# TODO: implement
|
|
200
|
+
def com_96_group_by_with_only_a_single_group(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-96 "GROUP BY with only a single group" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
202
|
+
|
|
203
|
+
# TODO: implement
|
|
204
|
+
def com_97_group_by_can_be_replaced_by_distinct(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-97 "GROUP BY can be replaced by DISTINCT" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
206
|
+
|
|
207
|
+
# TODO: implement
|
|
208
|
+
def com_98_union_can_be_replaced_by_or(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-98 "UNION can be replaced by OR" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
210
|
+
|
|
211
|
+
# TODO: refactor
|
|
212
|
+
def com_99_complication_unnecessary_column_in_order_by_clause(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-99 "unnecessary column in ORDER BY clause" check.

    The check is disabled and always returns an empty list.

    The earlier implementation was placed after an unconditional `return []`
    and therefore never ran, so it has been removed. It collected the ORDER BY
    columns of `self.q_ast` and `self.s_ast` with `_get_orderby_columns`,
    compared the names case-insensitively (ignoring sort direction), and, when
    the solution's columns were a non-empty strict subset of the query's,
    reported every extra column. It appended plain
    `(SqlErrors.COM_99_UNNECESSARY_COLUMN_IN_ORDER_BY_CLAUSE, message)` tuples
    rather than `DetectedError` instances, so it must be ported, not restored
    verbatim.
    '''
    return []
|
|
239
|
+
|
|
240
|
+
# TODO: implement
|
|
241
|
+
def com_100_order_by_in_subquery(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-100 "ORDER BY in subquery" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
243
|
+
|
|
244
|
+
# TODO: implement
|
|
245
|
+
def com_101_inefficient_having(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-101 "inefficient HAVING" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
247
|
+
|
|
248
|
+
# TODO: implement
|
|
249
|
+
def com_102_inefficient_union(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-102 "inefficient UNION" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
251
|
+
|
|
252
|
+
# TODO: implement
|
|
253
|
+
def com_103_condition_in_the_subquery_can_be_moved_up(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-103 "condition in the subquery can be moved up" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
255
|
+
|
|
256
|
+
# TODO: implement
|
|
257
|
+
def com_104_condition_on_left_table_in_left_outer_join(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-104 "condition on left table in LEFT OUTER JOIN" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
259
|
+
|
|
260
|
+
# TODO: implement
|
|
261
|
+
def com_105_outer_join_can_be_replaced_by_inner_join(self) -> list[DetectedError]:
    '''
    Placeholder for the COM-105 "outer join can be replaced by inner join" check.

    No detection logic is implemented yet; an empty list is always returned.
    '''
    return list()
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
#region Utility methods
|
|
266
|
+
def _get_select_columns(self, ast: dict) -> list:
|
|
267
|
+
'''
|
|
268
|
+
Extracts a list of simple column names from a SELECT query's AST.
|
|
269
|
+
'''
|
|
270
|
+
columns = []
|
|
271
|
+
if not ast:
|
|
272
|
+
return columns
|
|
273
|
+
|
|
274
|
+
select_expressions = ast.get('args', {}).get('expressions', [])
|
|
275
|
+
|
|
276
|
+
for expr_node in select_expressions:
|
|
277
|
+
col_name = self._find_underlying_column(expr_node)
|
|
278
|
+
if col_name:
|
|
279
|
+
columns.append(col_name)
|
|
280
|
+
|
|
281
|
+
return columns
|
|
282
|
+
def _find_underlying_column(self, node: dict):
|
|
283
|
+
'''
|
|
284
|
+
Recursively traverses an expression node to find the underlying column identifier.
|
|
285
|
+
'''
|
|
286
|
+
if not isinstance(node, dict):
|
|
287
|
+
return None
|
|
288
|
+
|
|
289
|
+
node_class = node.get('class')
|
|
290
|
+
|
|
291
|
+
if node_class == 'Paren':
|
|
292
|
+
return self._find_underlying_column(node.get('args', {}).get('this'))
|
|
293
|
+
|
|
294
|
+
if node_class == 'Column':
|
|
295
|
+
try:
|
|
296
|
+
return node['args']['expression']['args']['this']
|
|
297
|
+
except (KeyError, TypeError):
|
|
298
|
+
try:
|
|
299
|
+
return node['args']['this']['args']['this']
|
|
300
|
+
except (KeyError, TypeError):
|
|
301
|
+
return None
|
|
302
|
+
|
|
303
|
+
if node_class == 'Alias':
|
|
304
|
+
return self._find_underlying_column(node.get('args', {}).get('this'))
|
|
305
|
+
def _get_from_tables(self, ast: dict, with_alias=False) -> list:
|
|
306
|
+
'''
|
|
307
|
+
Extracts a list of all table names from the FROM and JOIN clauses of a query's AST.
|
|
308
|
+
'''
|
|
309
|
+
tables = []
|
|
310
|
+
if not ast:
|
|
311
|
+
return tables
|
|
312
|
+
|
|
313
|
+
args = ast.get('args', {})
|
|
314
|
+
|
|
315
|
+
# 1. Process the main table from the 'from' clause
|
|
316
|
+
from_node = args.get('from')
|
|
317
|
+
if from_node:
|
|
318
|
+
# The actual table data is inside the 'this' argument of the 'From' node
|
|
319
|
+
main_table_node = from_node.get('args', {}).get('this')
|
|
320
|
+
if main_table_node:
|
|
321
|
+
self._collect_tables_recursive(main_table_node, tables, with_alias)
|
|
322
|
+
|
|
323
|
+
# 2. Process all tables from the 'joins' list
|
|
324
|
+
join_nodes = args.get('joins', [])
|
|
325
|
+
for join_node in join_nodes:
|
|
326
|
+
self._collect_tables_recursive(join_node, tables, with_alias)
|
|
327
|
+
|
|
328
|
+
return list(set(tables))
|
|
329
|
+
def _collect_tables_recursive(self, node: dict, tables: list, with_alias=False):
|
|
330
|
+
'''
|
|
331
|
+
Recursively traverses a FROM clause node (including joins) to collect table names.
|
|
332
|
+
'''
|
|
333
|
+
if not isinstance(node, dict):
|
|
334
|
+
return
|
|
335
|
+
|
|
336
|
+
node_class = node.get('class')
|
|
337
|
+
|
|
338
|
+
# This part handles aliased tables (e.g., "customer c") and regular tables
|
|
339
|
+
if node_class == 'Alias':
|
|
340
|
+
underlying_node = node.get('args', {}).get('this')
|
|
341
|
+
# Recurse in case the alias is on a subquery or another join
|
|
342
|
+
self._collect_tables_recursive(underlying_node, tables, with_alias)
|
|
343
|
+
|
|
344
|
+
elif node_class == 'Table':
|
|
345
|
+
try:
|
|
346
|
+
# The AST nests identifiers, so we go deep to get the name
|
|
347
|
+
table_name = node['args']['this']['args']['this']
|
|
348
|
+
alias_node = node.get('args', {}).get('alias')
|
|
349
|
+
if with_alias and alias_node:
|
|
350
|
+
alias_name = alias_node.get('args', {}).get('this', {}).get('args', {}).get('this')
|
|
351
|
+
tables.append(f"{table_name} AS {alias_name}")
|
|
352
|
+
else:
|
|
353
|
+
tables.append(table_name)
|
|
354
|
+
except (KeyError, TypeError):
|
|
355
|
+
pass
|
|
356
|
+
|
|
357
|
+
# This part handles Join nodes found in the 'joins' list
|
|
358
|
+
elif node_class == 'Join':
|
|
359
|
+
# The joined table is in the 'this' argument of the Join node
|
|
360
|
+
self._collect_tables_recursive(node.get('args', {}).get('this'), tables, with_alias)
|
|
361
|
+
# The other side of the join is already handled in the 'from' clause,
|
|
362
|
+
# but we check for 'expression' for other potential join structures.
|
|
363
|
+
if 'expression' in node.get('args', {}):
|
|
364
|
+
self._collect_tables_recursive(node.get('args', {}).get('expression'), tables, with_alias)
|
|
365
|
+
def _get_orderby_columns(self, ast: dict) -> list:
|
|
366
|
+
'''
|
|
367
|
+
Extracts a list of columns and their sort direction from an ORDER BY clause.
|
|
368
|
+
'''
|
|
369
|
+
orderby_terms = []
|
|
370
|
+
if not ast:
|
|
371
|
+
return orderby_terms
|
|
372
|
+
|
|
373
|
+
orderby_node = ast.get('args', {}).get('order')
|
|
374
|
+
if not orderby_node:
|
|
375
|
+
return orderby_terms
|
|
376
|
+
|
|
377
|
+
try:
|
|
378
|
+
for term_node in orderby_node['args']['expressions']:
|
|
379
|
+
if term_node.get('class') != 'Ordered':
|
|
380
|
+
continue
|
|
381
|
+
|
|
382
|
+
column_node = term_node.get('args', {}).get('this')
|
|
383
|
+
|
|
384
|
+
col_name = self._find_underlying_column(column_node)
|
|
385
|
+
|
|
386
|
+
if col_name:
|
|
387
|
+
direction = term_node.get('args', {}).get('direction', 'ASC').upper()
|
|
388
|
+
orderby_terms.append((col_name, direction))
|
|
389
|
+
except (KeyError, AttributeError):
|
|
390
|
+
return []
|
|
391
|
+
|
|
392
|
+
return orderby_terms
|
|
393
|
+
#endregion Utility methods
|