sql-error-categorizer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,393 @@
1
+ import difflib
2
+ import re
3
+ import sqlparse
4
+ import sqlparse.keywords
5
+ from typing import Callable
6
+ from sqlglot import exp
7
+
8
+ from .base import BaseDetector, DetectedError
9
+ from ..query import Query
10
+ from ..sql_errors import SqlErrors
11
+ from ..catalog import Catalog
12
+
13
+
14
+ class ComplicationDetector(BaseDetector):
15
def __init__(self,
             *,
             query: Query,
             update_query: Callable[[str, str | None], None],
             solutions: list[Query] | None = None,
             ):
    '''
    Create the detector; every argument is forwarded to BaseDetector.

    Args:
        query: Forwarded to the base class as `query`.
        update_query: Callback forwarded to the base class as `update_query`.
        solutions: Forwarded to the base class as `solutions`. When omitted
            (or None), a new empty list is created for this instance.
    '''
    # A literal `[]` default is evaluated once at definition time and would be
    # shared by every instance created without `solutions`; build a fresh
    # list per call instead.
    if solutions is None:
        solutions = []

    super().__init__(
        query=query,
        solutions=solutions,
        update_query=update_query,
    )
26
+
27
def run(self) -> list[DetectedError]:
    '''
    Run the base-class checks, then every complication check in order.

    Returns the list produced by the base class `run()`, extended with the
    errors reported by each `com_*` check.
    '''
    detected: list[DetectedError] = super().run()

    # Checks are executed in catalogue order (COM-83 .. COM-105).
    complication_checks = (
        self.com_83_unnecessary_distinct_in_select_clause,
        self.com_84_unnecessary_join,
        self.com_85_unused_correlation_name,
        self.com_86_correlation_names_are_always_identical,
        self.com_87_unnecessary_general_comparison_operator,
        self.com_88_like_without_wildcards,
        self.com_89_unnecessarily_complicated_select_in_exists_subquery,
        self.com_90_in_exists_can_be_replaced_by_comparison,
        self.com_91_unnecessary_aggregate_function,
        self.com_92_unnecessary_distinct_in_aggregate_function,
        self.com_93_unnecessary_argument_of_count,
        self.com_94_unnecessary_group_by_in_exists_subquery,
        self.com_95_group_by_with_singleton_groups,
        self.com_96_group_by_with_only_a_single_group,
        self.com_97_group_by_can_be_replaced_by_distinct,
        self.com_98_union_can_be_replaced_by_or,
        self.com_99_complication_unnecessary_column_in_order_by_clause,
        self.com_100_order_by_in_subquery,
        self.com_101_inefficient_having,
        self.com_102_inefficient_union,
        self.com_103_condition_in_the_subquery_can_be_moved_up,
        self.com_104_condition_on_left_table_in_left_outer_join,
        self.com_105_outer_join_can_be_replaced_by_inner_join,
    )

    for check in complication_checks:
        detected.extend(check())

    return detected
64
+
65
+ # TODO: refactor
66
def com_83_unnecessary_distinct_in_select_clause(self) -> list[DetectedError]:
    '''
    Flags the unnecessary use of DISTINCT by comparing the proposed query
    against the correct solution.

    The check is currently disabled and always returns an empty list.

    The previous draft sat after an unconditional `return` and could never
    run; it also appended plain `(error, message)` tuples instead of
    `DetectedError` objects. It has been removed. The intended rule was:
    report COM_83 when the proposed query's top-level SELECT carries a
    `distinct` argument (anything other than None/False) while the
    solution's does not.
    '''
    # TODO: re-implement using DetectedError once the comparison against the
    # solution query has been refactored.
    return []
94
+
95
+ # TODO: refactor
96
def com_84_unnecessary_join(self) -> list[DetectedError]:
    '''
    Flags a query that joins to a table not present in the correct solution.

    The check is currently disabled and always returns an empty list.

    The previous draft sat after an unconditional `return` and could never
    run; it also appended plain `(error, message)` tuples instead of
    `DetectedError` objects. It has been removed. The intended rule was:
    collect the FROM/JOIN table names of the proposed query and of the
    solution (compared case-insensitively) and report COM_84 for every
    table that appears only in the proposed query.
    '''
    # TODO: re-implement using DetectedError once the comparison against the
    # solution query has been refactored.
    return []
125
+
126
+ # TODO: implement
127
def com_85_unused_correlation_name(self) -> list[DetectedError]:
    '''
    Placeholder for the "unused correlation name" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
129
+
130
+ # TODO: implement
131
def com_86_correlation_names_are_always_identical(self) -> list[DetectedError]:
    '''
    Placeholder for the "correlation names are always identical" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
133
+
134
+ # TODO: implement
135
def com_87_unnecessary_general_comparison_operator(self) -> list[DetectedError]:
    '''
    Placeholder for the "unnecessary general comparison operator" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
137
+
138
def com_88_like_without_wildcards(self) -> list[DetectedError]:
    '''
    Report every LIKE whose pattern is a literal with neither '%' nor '_'.

    A LIKE without wildcards suggests that the '=' operator should have been
    used instead. Patterns that are not literals (for example a column
    reference) are ignored.

    Returns one DetectedError per offending LIKE, carrying the LIKE
    expression rendered as a string.
    '''
    detected: list[DetectedError] = []

    for select in self.query.selects:
        tree = select.ast
        if not tree:
            continue

        for like_node in tree.find_all(exp.Like):
            pattern = like_node.args.get('expression')

            # Skip malformed LIKEs (no pattern) and non-literal patterns.
            if not pattern or not isinstance(pattern, exp.Literal):
                continue

            pattern_text = pattern.this
            if '%' in pattern_text or '_' in pattern_text:
                continue

            detected.append(
                DetectedError(SqlErrors.COM_88_LIKE_WITHOUT_WILDCARDS, (str(like_node),))
            )

    return detected
170
+
171
+ # TODO: implement
172
def com_89_unnecessarily_complicated_select_in_exists_subquery(self) -> list[DetectedError]:
    '''
    Placeholder for the "unnecessarily complicated SELECT in EXISTS
    subquery" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
174
+
175
+ # TODO: implement
176
def com_90_in_exists_can_be_replaced_by_comparison(self) -> list[DetectedError]:
    '''
    Placeholder for the "IN/EXISTS can be replaced by comparison" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
178
+
179
+ # TODO: implement
180
def com_91_unnecessary_aggregate_function(self) -> list[DetectedError]:
    '''
    Placeholder for the "unnecessary aggregate function" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
182
+
183
+ # TODO: implement
184
def com_92_unnecessary_distinct_in_aggregate_function(self) -> list[DetectedError]:
    '''
    Placeholder for the "unnecessary DISTINCT in aggregate function" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
186
+
187
+ # TODO: implement
188
def com_93_unnecessary_argument_of_count(self) -> list[DetectedError]:
    '''
    Placeholder for the "unnecessary argument of COUNT" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
190
+
191
+ # TODO: implement
192
def com_94_unnecessary_group_by_in_exists_subquery(self) -> list[DetectedError]:
    '''
    Placeholder for the "unnecessary GROUP BY in EXISTS subquery" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
194
+
195
+ # TODO: implement
196
def com_95_group_by_with_singleton_groups(self) -> list[DetectedError]:
    '''
    Placeholder for the "GROUP BY with singleton groups" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
198
+
199
+ # TODO: implement
200
def com_96_group_by_with_only_a_single_group(self) -> list[DetectedError]:
    '''
    Placeholder for the "GROUP BY with only a single group" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
202
+
203
+ # TODO: implement
204
def com_97_group_by_can_be_replaced_by_distinct(self) -> list[DetectedError]:
    '''
    Placeholder for the "GROUP BY can be replaced by DISTINCT" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
206
+
207
+ # TODO: implement
208
def com_98_union_can_be_replaced_by_or(self) -> list[DetectedError]:
    '''
    Placeholder for the "UNION can be replaced by OR" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
210
+
211
+ # TODO: refactor
212
def com_99_complication_unnecessary_column_in_order_by_clause(self) -> list[DetectedError]:
    '''
    Flags when the ORDER BY clause contains unnecessary columns in addition
    to the required ones.

    The check is currently disabled and always returns an empty list.

    The previous draft sat after an unconditional `return` and could never
    run; it also appended plain `(error, message)` tuples instead of
    `DetectedError` objects. It has been removed. The intended rule was:
    compare the ORDER BY column names of the proposed query and of the
    solution case-insensitively; when the solution's columns are a non-empty
    proper subset of the query's, report COM_99 for each extra column.
    '''
    # TODO: re-implement using DetectedError once the comparison against the
    # solution query has been refactored.
    return []
239
+
240
+ # TODO: implement
241
def com_100_order_by_in_subquery(self) -> list[DetectedError]:
    '''
    Placeholder for the "ORDER BY in subquery" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
243
+
244
+ # TODO: implement
245
def com_101_inefficient_having(self) -> list[DetectedError]:
    '''
    Placeholder for the "inefficient HAVING" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
247
+
248
+ # TODO: implement
249
def com_102_inefficient_union(self) -> list[DetectedError]:
    '''
    Placeholder for the "inefficient UNION" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
251
+
252
+ # TODO: implement
253
def com_103_condition_in_the_subquery_can_be_moved_up(self) -> list[DetectedError]:
    '''
    Placeholder for the "condition in the subquery can be moved up" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
255
+
256
+ # TODO: implement
257
def com_104_condition_on_left_table_in_left_outer_join(self) -> list[DetectedError]:
    '''
    Placeholder for the "condition on left table in LEFT OUTER JOIN" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
259
+
260
+ # TODO: implement
261
def com_105_outer_join_can_be_replaced_by_inner_join(self) -> list[DetectedError]:
    '''
    Placeholder for the "outer join can be replaced by inner join" check.

    Not implemented yet, so no errors are ever reported.
    '''
    findings: list[DetectedError] = []
    return findings
263
+
264
+
265
+ #region Utility methods
266
+ def _get_select_columns(self, ast: dict) -> list:
267
+ '''
268
+ Extracts a list of simple column names from a SELECT query's AST.
269
+ '''
270
+ columns = []
271
+ if not ast:
272
+ return columns
273
+
274
+ select_expressions = ast.get('args', {}).get('expressions', [])
275
+
276
+ for expr_node in select_expressions:
277
+ col_name = self._find_underlying_column(expr_node)
278
+ if col_name:
279
+ columns.append(col_name)
280
+
281
+ return columns
282
+ def _find_underlying_column(self, node: dict):
283
+ '''
284
+ Recursively traverses an expression node to find the underlying column identifier.
285
+ '''
286
+ if not isinstance(node, dict):
287
+ return None
288
+
289
+ node_class = node.get('class')
290
+
291
+ if node_class == 'Paren':
292
+ return self._find_underlying_column(node.get('args', {}).get('this'))
293
+
294
+ if node_class == 'Column':
295
+ try:
296
+ return node['args']['expression']['args']['this']
297
+ except (KeyError, TypeError):
298
+ try:
299
+ return node['args']['this']['args']['this']
300
+ except (KeyError, TypeError):
301
+ return None
302
+
303
+ if node_class == 'Alias':
304
+ return self._find_underlying_column(node.get('args', {}).get('this'))
305
+ def _get_from_tables(self, ast: dict, with_alias=False) -> list:
306
+ '''
307
+ Extracts a list of all table names from the FROM and JOIN clauses of a query's AST.
308
+ '''
309
+ tables = []
310
+ if not ast:
311
+ return tables
312
+
313
+ args = ast.get('args', {})
314
+
315
+ # 1. Process the main table from the 'from' clause
316
+ from_node = args.get('from')
317
+ if from_node:
318
+ # The actual table data is inside the 'this' argument of the 'From' node
319
+ main_table_node = from_node.get('args', {}).get('this')
320
+ if main_table_node:
321
+ self._collect_tables_recursive(main_table_node, tables, with_alias)
322
+
323
+ # 2. Process all tables from the 'joins' list
324
+ join_nodes = args.get('joins', [])
325
+ for join_node in join_nodes:
326
+ self._collect_tables_recursive(join_node, tables, with_alias)
327
+
328
+ return list(set(tables))
329
+ def _collect_tables_recursive(self, node: dict, tables: list, with_alias=False):
330
+ '''
331
+ Recursively traverses a FROM clause node (including joins) to collect table names.
332
+ '''
333
+ if not isinstance(node, dict):
334
+ return
335
+
336
+ node_class = node.get('class')
337
+
338
+ # This part handles aliased tables (e.g., "customer c") and regular tables
339
+ if node_class == 'Alias':
340
+ underlying_node = node.get('args', {}).get('this')
341
+ # Recurse in case the alias is on a subquery or another join
342
+ self._collect_tables_recursive(underlying_node, tables, with_alias)
343
+
344
+ elif node_class == 'Table':
345
+ try:
346
+ # The AST nests identifiers, so we go deep to get the name
347
+ table_name = node['args']['this']['args']['this']
348
+ alias_node = node.get('args', {}).get('alias')
349
+ if with_alias and alias_node:
350
+ alias_name = alias_node.get('args', {}).get('this', {}).get('args', {}).get('this')
351
+ tables.append(f"{table_name} AS {alias_name}")
352
+ else:
353
+ tables.append(table_name)
354
+ except (KeyError, TypeError):
355
+ pass
356
+
357
+ # This part handles Join nodes found in the 'joins' list
358
+ elif node_class == 'Join':
359
+ # The joined table is in the 'this' argument of the Join node
360
+ self._collect_tables_recursive(node.get('args', {}).get('this'), tables, with_alias)
361
+ # The other side of the join is already handled in the 'from' clause,
362
+ # but we check for 'expression' for other potential join structures.
363
+ if 'expression' in node.get('args', {}):
364
+ self._collect_tables_recursive(node.get('args', {}).get('expression'), tables, with_alias)
365
+ def _get_orderby_columns(self, ast: dict) -> list:
366
+ '''
367
+ Extracts a list of columns and their sort direction from an ORDER BY clause.
368
+ '''
369
+ orderby_terms = []
370
+ if not ast:
371
+ return orderby_terms
372
+
373
+ orderby_node = ast.get('args', {}).get('order')
374
+ if not orderby_node:
375
+ return orderby_terms
376
+
377
+ try:
378
+ for term_node in orderby_node['args']['expressions']:
379
+ if term_node.get('class') != 'Ordered':
380
+ continue
381
+
382
+ column_node = term_node.get('args', {}).get('this')
383
+
384
+ col_name = self._find_underlying_column(column_node)
385
+
386
+ if col_name:
387
+ direction = term_node.get('args', {}).get('direction', 'ASC').upper()
388
+ orderby_terms.append((col_name, direction))
389
+ except (KeyError, AttributeError):
390
+ return []
391
+
392
+ return orderby_terms
393
+ #endregion Utility methods