sqlchecker 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlchecker/__init__.py ADDED
@@ -0,0 +1,58 @@
1
+ '''Detect and categorize SQL errors in queries.'''
2
+
3
+ # Hidden, internal use only
4
+ from .detectors import BaseDetector as _BaseDetector, Detector as _Detector
5
+
6
+ # Public API
7
+ from sqlerrors import SqlErrors
8
+ from sqlscope import Catalog, build_catalog, load_catalog, build_catalog_from_postgres, build_catalog_from_sql
9
+ from .detectors import SyntaxErrorDetector, SemanticErrorDetector, LogicalErrorDetector, ComplicationDetector, DetectedError
10
+
11
def get_errors(query_str: str,
               solutions: list[str] | None = None,
               catalog: Catalog | None = None,
               search_path: str = 'public',
               solution_search_path: str = 'public',
               detectors: list[type[_BaseDetector]] | None = None,
               debug: bool = False) -> list[DetectedError]:
    '''
    Detect SQL errors in the given query string.

    Args:
        query_str: SQL text of the query to analyse.
        solutions: SQL text of the reference solutions. Defaults to no solutions.
        catalog: Catalog forwarded to the detector. A fresh, empty `Catalog()` is
            created for each call when omitted.
        search_path: Search path used when parsing `query_str`.
        solution_search_path: Search path used when parsing `solutions`.
        detectors: Detector classes to run, in order. Defaults to the syntax,
            semantic, logical and complication detectors.
        debug: When True, the underlying detector prints debug information.

    Returns:
        Every error reported by the selected detectors, in detector order.
        The list may contain duplicates.
    '''
    # Defaults are resolved per call: a mutable object in the signature would be
    # created once at import time and shared by every caller.
    if solutions is None:
        solutions = []
    if catalog is None:
        catalog = Catalog()
    if detectors is None:
        detectors = [
            SyntaxErrorDetector,
            SemanticErrorDetector,
            LogicalErrorDetector,
            ComplicationDetector,
        ]

    det = _Detector(query_str,
                    solutions=solutions,
                    catalog=catalog,
                    search_path=search_path,
                    solution_search_path=solution_search_path,
                    debug=debug)

    for detector in detectors:
        det.add_detector(detector)

    return det.run()
35
+
36
def get_error_types(query_str: str,
                    solutions: list[str] | None = None,
                    catalog: Catalog | None = None,
                    search_path: str = 'public',
                    solution_search_path: str = 'public',
                    detectors: list[type[_BaseDetector]] | None = None,
                    debug: bool = False) -> set[SqlErrors]:
    '''
    Detect SQL error types in the given query string.

    Takes the same arguments as `get_errors`, but returns only the distinct
    error types and drops the data attached to each detected error.

    Args:
        query_str: SQL text of the query to analyse.
        solutions: SQL text of the reference solutions. Defaults to no solutions.
        catalog: Catalog forwarded to the detector. A fresh, empty `Catalog()` is
            created for each call when omitted.
        search_path: Search path used when parsing `query_str`.
        solution_search_path: Search path used when parsing `solutions`.
        detectors: Detector classes to run, in order. Defaults to the syntax,
            semantic, logical and complication detectors.
        debug: When True, the underlying detector prints debug information.

    Returns:
        The set of `SqlErrors` members found in the query.
    '''
    # Resolve the defaults here (rather than forwarding None) so that no mutable
    # object is shared between calls and `get_errors` always gets concrete values.
    if solutions is None:
        solutions = []
    if catalog is None:
        catalog = Catalog()
    if detectors is None:
        detectors = [
            SyntaxErrorDetector,
            SemanticErrorDetector,
            LogicalErrorDetector,
            ComplicationDetector,
        ]

    detected_errors = get_errors(query_str,
                                 solutions=solutions,
                                 catalog=catalog,
                                 search_path=search_path,
                                 solution_search_path=solution_search_path,
                                 detectors=detectors,
                                 debug=debug)

    return {error.error for error in detected_errors}
@@ -0,0 +1,103 @@
1
+ '''SQL error detectors.'''
2
+
3
+ from sqlscope import Query, Catalog
4
+ from .base import BaseDetector, DetectedError
5
+
6
+ # exported detectors
7
+ from .syntax import SyntaxErrorDetector
8
+ from .semantic import SemanticErrorDetector
9
+ from .logical import LogicalErrorDetector
10
+ from .complications import ComplicationDetector
11
+
12
class Detector:
    '''Manages and runs SQL error detectors on a query.'''

    def __init__(self,
                 query: str,
                 *,
                 search_path: str = 'public',
                 solution_search_path: str = 'public',
                 solutions: list[str] | None = None,
                 catalog: Catalog | None = None,
                 detectors: list[type[BaseDetector]] | None = None,
                 debug: bool = False):
        '''
        Parse the query and the solutions, then register the given detectors.

        Args:
            query: SQL text of the query to analyse.
            search_path: Search path used when parsing `query`.
            solution_search_path: Search path used when parsing `solutions`.
            solutions: SQL text of the reference solutions. Defaults to none.
            catalog: Catalog used for parsing. A fresh `Catalog()` is created
                when omitted, so instances never share a default catalog.
            detectors: Detector classes to instantiate and register. Defaults to none.
            debug: When True, debug information is printed to stdout.
        '''

        # Context data: they don't need to be parsed again if the query changes
        self.search_path = search_path
        self.solution_search_path = solution_search_path
        self.catalog = Catalog() if catalog is None else catalog
        self.solutions = [
            Query(sol, catalog=self.catalog, search_path=self.solution_search_path)
            for sol in (solutions if solutions is not None else ())
        ]
        self.detectors: list[BaseDetector] = []
        self.debug = debug

        self.set_query(query)

        # NOTE: Add detectors after setting the query to ensure they are correctly initialized
        for detector_cls in (detectors if detectors is not None else ()):
            self.add_detector(detector_cls)

    def set_query(self, query: str, reason: str | None = None) -> None:
        '''
        Set a new query, re-parse it, and update all detectors. Doesn't affect detected errors.

        Args:
            query: SQL text of the new query.
            reason: Optional explanation, only used in the debug output.
        '''

        if self.debug:
            print('=' * 20)
            if reason:
                print(f'Updating query ({reason}):\n{query}')
            else:
                print(f'Updating query:\n{query}')
            print('=' * 20)

        self.query = Query(query, catalog=self.catalog, search_path=self.search_path)

        # Update all detectors with the new query and parse results
        for detector in self.detectors:
            detector.query = self.query
            detector.update_query = lambda new_query, reason=None: self.set_query(new_query, reason)

    def add_detector(self, detector_cls: type[BaseDetector]) -> None:
        '''
        Instantiate `detector_cls` and append it to the list of detectors.

        The new detector receives the current parsed query, the parsed solutions
        and a callback that routes query updates back through `set_query`.
        '''

        # NOTE: `self.query` and `self.solutions` are passed by reference, not
        # copied, so every detector sees the same objects.
        # TODO: check if copies are needed to guard against modifications during detection
        detector = detector_cls(
            query=self.query,
            solutions=self.solutions,
            update_query=lambda new_query, reason=None: self.set_query(new_query, reason),
        )

        self.detectors.append(detector)

    def run(self) -> list[DetectedError]:
        '''
        Run all detectors and return a list of detected errors.
        This function can return duplicate errors, as well as additional information on the detected errors.
        '''

        if self.debug:
            print('===== Query =====')
            print(self.query.sql)

            print('===== search_path =====')
            print(self.search_path)

            print('===== solution_search_path =====')
            print(self.solution_search_path)

            print('===== Solutions =====')
            print('\n-----\n'.join(sol.sql for sol in self.solutions))

            print('===== Catalog =====')
            print(self.catalog)

        results: list[DetectedError] = []

        # Detectors run in registration order; their results are concatenated as-is.
        for detector in self.detectors:
            errors = detector.run()

            if self.debug:
                print(f'===== Detected errors from {detector.__class__.__name__} =====')
                for error in errors:
                    print(error)

            results.extend(errors)

        return results
@@ -0,0 +1,44 @@
1
+ '''Base classes for SQL error detectors.'''
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, Callable
6
+
7
+ from sqlerrors import SqlErrors
8
+ from sqlscope.query import Query
9
+
10
@dataclass(repr=False)
class DetectedError:
    '''A single detected SQL error: its type plus any data attached by the detector.'''

    # Error type, a member of `SqlErrors`.
    error: SqlErrors
    # Detector-specific payload; empty when the detector attaches nothing.
    data: tuple[Any, ...] = field(default_factory=tuple)

    def __repr__(self):
        '''Return a debug representation with the error value, name and data.'''
        return f"DetectedError({self.error.value} - {self.error.name}: {self.data})"

    def __str__(self) -> str:
        '''Return a readable one-line form; the data is omitted when it is empty.'''
        if not self.data:
            return f'[{self.error.value:3}] {self.error.name}'
        return f'[{self.error.value:3}] {self.error.name}: {self.data}'

    def __hash__(self) -> int:
        '''Hash on the same fields the generated `__eq__` compares.'''
        key = (self.error, self.data)
        return hash(key)
27
+
28
class BaseDetector(ABC):
    '''Abstract base class for SQL error detectors.'''

    def __init__(self, *,
                 query: Query,
                 solutions: list[Query] | None = None,
                 update_query: Callable[[str, str | None], None],
                 ):
        '''
        Store the context shared by all detectors.

        Args:
            query: Parsed query to analyse.
            solutions: Parsed reference solutions. Defaults to a new empty list.
            update_query: Callback taking the new SQL text and an optional reason,
                used by a detector to request that the query be replaced.
        '''
        self.query = query
        # A fresh list per instance: the argument is stored as-is, so a mutable
        # default in the signature would be aliased by every detector.
        self.solutions = [] if solutions is None else solutions
        self.update_query = update_query

    @abstractmethod
    def run(self) -> list[DetectedError]:
        '''Run the detector and return a list of detected errors with their descriptions'''
        # Subclasses may call `super().run()` to obtain an empty list to extend.
        return []
44
+
@@ -0,0 +1,375 @@
1
+ '''Detector for complications in SQL queries.'''
2
+
3
+ from typing import Callable
4
+ from sqlglot import exp
5
+ from sqlerrors import SqlErrors
6
+ from sqlscope.catalog import ConstraintType, ConstraintColumn
7
+ from sqlscope import Query
8
+ from sqlscope import util
9
+
10
+ from .base import BaseDetector, DetectedError
11
+
12
class ComplicationDetector(BaseDetector):
    '''Detector for complications in SQL queries.'''

    def __init__(self,
                 *,
                 query: Query,
                 update_query: Callable[[str, str | None], None],
                 solutions: list[Query] | None = None,
                 ):
        '''
        Args:
            query: Parsed query to analyse.
            update_query: Callback used to request that the query be replaced.
            solutions: Parsed reference solutions. Defaults to a new empty list.
        '''
        super().__init__(
            query=query,
            # Resolved here so a single default list is never shared between instances.
            solutions=[] if solutions is None else solutions,
            update_query=update_query,
        )

    def run(self) -> list[DetectedError]:
        '''
        Executes all complication checks and returns a list of identified misconceptions.

        Checks run in the order listed below and their results are concatenated.
        '''

        results: list[DetectedError] = super().run()

        checks = [
            self.detect_82_unnecessary_complication,                            # umbrella term, never reported
            self.detect_83_unnecessary_distinct_in_select_clause,               # ok
            self.detect_84_unnecessary_table_reference,                         # TODO: refactor/implement
            self.detect_85_unused_correlation_name,                             # TODO: implement
            self.detect_86_tables_have_same_data,                               # TODO: implement
            self.detect_125_correlation_name_identical_to_table_name,           # TODO: implement
            self.detect_87_unnecessary_general_comparison_operator,             # TODO: implement
            self.detect_88_like_without_wildcards,                              # ok
            self.detect_89_unnecessarily_complicated_select_in_exists_subquery, # TODO: implement
            self.detect_90_in_exists_can_be_replaced_by_comparison,             # TODO: implement
            self.detect_91_unnecessary_aggregate_function,                      # TODO: implement
            self.detect_92_unnecessary_distinct_in_aggregate_function,          # ok
            self.detect_93_unnecessary_argument_of_count,                       # TODO: implement
            self.detect_94_unnecessary_group_by_in_exists_subquery,             # TODO: implement
            self.detect_95_group_by_with_singleton_groups,                      # ok
            self.detect_96_group_by_with_only_a_single_group,                   # TODO: implement
            self.detect_97_group_by_can_be_replaced_by_distinct,                # ok
            self.detect_98_union_can_be_replaced_by_or,                         # TODO: implement
            self.detect_99_unnecessary_column_in_order_by_clause,               # TODO: refactor/implement
            self.detect_100_order_by_in_subquery,                               # ok
            self.detect_101_inefficient_having,                                 # TODO: implement
            self.detect_102_inefficient_union,                                  # TODO: implement
            self.detect_103_condition_in_the_subquery_can_be_moved_up,          # TODO: implement
            self.detect_104_outer_join_can_be_replaced_by_inner_join,           # TODO: implement
            self.detect_126_unused_cte,                                         # ok
        ]

        for chk in checks:
            results.extend(chk())

        return results

    def detect_82_unnecessary_complication(self) -> list[DetectedError]:
        '''NOTE: this is an umbrella term, so it can't be directly detected.'''
        return []

    def detect_83_unnecessary_distinct_in_select_clause(self) -> list[DetectedError]:
        '''
        Flags a SELECT DISTINCT clause that is unnecessary because the selected
        columns are already unique due to existing constraints.
        '''
        result: list[DetectedError] = []

        for select in self.query.selects:
            if not select.distinct:
                continue

            # Ignore constraints of type DISTINCT: only the remaining unique
            # constraints on the output can make the DISTINCT keyword redundant.
            constraints = [
                c for c in select.output.unique_constraints
                if c.constraint_type != ConstraintType.DISTINCT
            ]

            if constraints:
                result.append(DetectedError(SqlErrors.UNNECESSARY_DISTINCT_IN_SELECT_CLAUSE, (select.sql,)))

        return result

    def detect_84_unnecessary_table_reference(self) -> list[DetectedError]:
        '''
        Intended to flag a query that joins to a table not present in the correct solution.

        Not implemented yet: always returns an empty list.
        '''
        # TODO: refactor/implement. The previous draft compared the lower-cased
        # FROM tables of the query against those of the solution and reported the
        # extra ones; it relied on helpers that this class does not define and
        # was unreachable, so it has been removed.
        return []

    def detect_85_unused_correlation_name(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_86_tables_have_same_data(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_125_correlation_name_identical_to_table_name(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_87_unnecessary_general_comparison_operator(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_88_like_without_wildcards(self) -> list[DetectedError]:
        '''
        Flags queries where the LIKE operator is used without wildcards ('%' or '_').
        This indicates a potential misunderstanding, where the '=' operator should
        have been used instead.

        Only literal patterns are inspected; the reported data is the text of the
        LIKE expression.
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            ast = select.ast

            if not ast:
                continue

            for like in ast.find_all(exp.Like):
                pattern_expr = like.args.get('expression')

                if not pattern_expr:
                    # Malformed LIKE expression
                    continue

                if not isinstance(pattern_expr, exp.Literal):
                    # Some other expression type, e.g., a column reference
                    continue

                pattern_value = pattern_expr.this
                if '%' in pattern_value or '_' in pattern_value:
                    continue

                results.append(DetectedError(SqlErrors.LIKE_WITHOUT_WILDCARDS, (str(like),)))

        return results

    def detect_89_unnecessarily_complicated_select_in_exists_subquery(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_90_in_exists_can_be_replaced_by_comparison(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_91_unnecessary_aggregate_function(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_92_unnecessary_distinct_in_aggregate_function(self) -> list[DetectedError]:
        '''
        MIN and MAX never require DISTINCT. For other aggregate functions, DISTINCT is
        unnecessary if the argument is unique.

        An argument counts as unique here when it is a literal, or a column that is
        the only column of a UNIQUE constraint. Subqueries are stripped from each
        SELECT before it is inspected.
        '''

        results: list[DetectedError] = []

        for select in self.query.selects:
            stripped = select.strip_subqueries()

            if not stripped.ast:
                continue

            for agg_func in stripped.ast.find_all(exp.AggFunc):
                if not isinstance(agg_func.this, exp.Distinct):
                    continue

                if isinstance(agg_func, (exp.Min, exp.Max)):
                    results.append(DetectedError(SqlErrors.UNNECESSARY_DISTINCT_IN_AGGREGATE_FUNCTION, (str(agg_func),)))
                    continue

                arg_expr = agg_func.this.expressions    # `.this` is the DISTINCT, `.expressions` are the arguments
                if not arg_expr:
                    continue

                for expr in arg_expr:
                    # Check if the argument is a constant literal
                    if isinstance(expr, exp.Literal):
                        results.append(DetectedError(SqlErrors.UNNECESSARY_DISTINCT_IN_AGGREGATE_FUNCTION, (str(agg_func),)))
                        continue

                    # Anything other than a plain column is not analysed
                    if not isinstance(expr, exp.Column):
                        continue

                    column_name = util.ast.column.get_real_name(expr)

                    # Check if the column has a UNIQUE constraint
                    unique_constraints = [c for c in stripped.all_constraints if c.constraint_type == ConstraintType.UNIQUE]
                    for constraint in unique_constraints:
                        if {ConstraintColumn(column_name, table_idx=stripped._get_table_idx_for_column(expr))} == constraint.columns:
                            results.append(DetectedError(SqlErrors.UNNECESSARY_DISTINCT_IN_AGGREGATE_FUNCTION, (str(agg_func),)))
                            # One matching constraint is enough; report this argument once
                            break

        return results

    def detect_93_unnecessary_argument_of_count(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_94_unnecessary_group_by_in_exists_subquery(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_95_group_by_with_singleton_groups(self) -> list[DetectedError]:
        '''
        Flags GROUP BY clauses on singleton groups due to the presence
        of UNIQUE constraints on the grouped columns.

        The reported data is the GROUP BY constraint together with the first
        UNIQUE constraint whose columns are all among the grouped columns.
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            if not select.group_by:
                continue

            group_by_constraint = next(
                (c for c in select.all_constraints if c.constraint_type == ConstraintType.GROUP_BY),
                None,
            )
            if not group_by_constraint:
                # No GROUP BY constraint found, meaning the GROUP BY clause is invalid. Skip.
                continue

            unique_constraints = [c for c in select.all_constraints if c.constraint_type == ConstraintType.UNIQUE]

            for constraint in unique_constraints:
                if constraint.columns.issubset(group_by_constraint.columns):
                    results.append(DetectedError(SqlErrors.GROUP_BY_WITH_SINGLETON_GROUPS, (group_by_constraint, constraint)))
                    # Report each SELECT at most once
                    break

        return results

    def detect_96_group_by_with_only_a_single_group(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_97_group_by_can_be_replaced_by_distinct(self) -> list[DetectedError]:
        '''
        Flags GROUP BY clauses that can be replaced by SELECT DISTINCT.
        This occurs when all selected columns are included in the GROUP BY clause
        and there are no aggregate functions in the SELECT list.

        Subqueries are stripped from each SELECT before it is inspected.
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            stripped = select.strip_subqueries()

            if not stripped.group_by:
                continue

            if not stripped.ast:
                continue

            # Any aggregate in the SELECT list means GROUP BY is doing real work
            if any(list(expression.find_all(exp.AggFunc)) for expression in stripped.ast.expressions):
                continue

            select_columns: list[exp.Column] = [
                col
                for expression in stripped.ast.expressions
                for col in expression.find_all(exp.Column)
            ]
            group_by_columns: list[exp.Column] = [
                col
                for expression in stripped.group_by
                for col in expression.find_all(exp.Column)
            ]

            # Columns are compared as (real name, table index) pairs
            select_col_names = {
                (util.ast.column.get_real_name(col), stripped._get_table_idx_for_column(col))
                for col in select_columns
            }
            group_by_col_names = {
                (util.ast.column.get_real_name(col), stripped._get_table_idx_for_column(col))
                for col in group_by_columns
            }

            if select_col_names == group_by_col_names:
                # NOTE(review): the payload is a `set`, which is unhashable, so hashing
                # this DetectedError (its __hash__ hashes `data`) would raise TypeError.
                # Kept as-is to preserve the reported data; consider a frozenset.
                results.append(DetectedError(SqlErrors.GROUP_BY_CAN_BE_REPLACED_WITH_DISTINCT, (select_col_names,)))

        return results

    def detect_98_union_can_be_replaced_by_or(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_99_unnecessary_column_in_order_by_clause(self) -> list[DetectedError]:
        '''
        Intended to flag when the ORDER BY clause contains unnecessary columns in
        addition to the required ones.

        Not implemented yet: always returns an empty list.
        '''
        # TODO: refactor/implement. The previous draft reported ORDER BY columns of
        # the query that were missing from the solution's ORDER BY when the
        # solution's columns were a strict subset; it relied on helpers that this
        # class does not define and was unreachable, so it has been removed.
        return []

    def detect_100_order_by_in_subquery(self) -> list[DetectedError]:
        '''
        Flags when a subquery contains an ORDER BY clause.
        Subqueries with both ORDER BY and LIMIT are considered valid.
        '''

        results: list[DetectedError] = []

        # nested subqueries are checked multiple times, so track which have been checked
        checked_subqueries: set[str] = set()

        for select in self.query.selects:
            for subquery, _clause, _depth in select.subqueries:
                if subquery.sql in checked_subqueries:
                    continue

                checked_subqueries.add(subquery.sql)
                if subquery.order_by and not subquery.limit:
                    results.append(DetectedError(SqlErrors.ORDER_BY_IN_SUBQUERY, (subquery.sql,)))

        return results

    def detect_101_inefficient_having(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_102_inefficient_union(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_103_condition_in_the_subquery_can_be_moved_up(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_104_outer_join_can_be_replaced_by_inner_join(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_126_unused_cte(self) -> list[DetectedError]:
        '''
        Flags CTEs that are not referenced by any SELECT of the query.

        A CTE counts as used when some referenced table carries its index in
        `cte_idx`. Unused CTEs are reported in declaration order with their SQL text.
        '''
        results: list[DetectedError] = []

        if not self.query.ctes:
            return results

        used_cte_indexes: set[int] = set()
        for select in self.query.selects:
            for table in select.referenced_tables:
                if table.cte_idx is not None:
                    used_cte_indexes.add(table.cte_idx)

        for cte_idx in range(len(self.query.ctes)):
            if cte_idx not in used_cte_indexes:
                results.append(DetectedError(SqlErrors.UNUSED_CTE, (self.query.ctes[cte_idx].sql,)))

        return results