sqlchecker 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,289 @@
1
+ '''Detector for semantic errors in SQL queries.'''
2
+
3
+ import re
4
+ from typing import Callable
5
+ from sqlglot import exp
6
+ from z3 import Not, Or, And
7
+ from sqlerrors import SqlErrors
8
+ from sqlscope import util
9
+ from sqlscope.query import Query, smt
10
+
11
+ from .base import BaseDetector, DetectedError
12
+
13
class SemanticErrorDetector(BaseDetector):
    '''
    Detector for semantic errors in SQL queries.

    Every ``detect_NN_*`` method implements one check and returns the list of
    errors it found. `run` executes the checks inherited from the base
    detector first and then all the checks defined here.
    '''

    def __init__(self,
                 *,
                 query: Query,
                 update_query: Callable[[str, str | None], None],
                 solutions: list[Query] | None = None,
                 ):
        '''
        Create the detector.

        Args:
            query: The user query to analyse.
            update_query: Callback forwarded unchanged to the base detector.
            solutions: Reference (correct) queries. When omitted, an empty
                list is used.
        '''
        super().__init__(
            query=query,
            # Build a fresh list per instance: a literal `[]` default would be
            # one object shared (and possibly mutated) by every detector.
            solutions=[] if solutions is None else solutions,
            update_query=update_query,
        )

    def run(self) -> list[DetectedError]:
        '''Run the base-class checks followed by every semantic check, and return all detected errors.'''
        results: list[DetectedError] = super().run()

        checks = [
            self.detect_40_tautological_or_inconsistent_expression,     # ok
            self.detect_41_distinct_in_sum_or_avg,                      # ok
            self.detect_42_distinct_removing_important_duplicates,      # TODO: implement
            self.detect_45_mixing_comparison_and_null,                  # TODO: refactor/implement
            self.detect_46_null_in_InAnyAll_subquery,                   # TODO: implement
            self.detect_47_join_condition_on_unmatchable_column,        # TODO: implement
            self.detect_49_many_duplicates,                             # TODO: implement
            self.detect_50_constant_column_output,                      # TODO: revise and implement
            self.detect_51_duplicate_column_output,                     # ok
        ]

        for chk in checks:
            results.extend(chk())

        return results

    def detect_40_tautological_or_inconsistent_expression(self) -> list[DetectedError]:
        '''
        Detect WHERE conditions that are contradictory, tautological, or that contain redundant parts.

        The WHERE condition of each SELECT is put in disjunctive normal form
        (C1 OR C2 OR ...) and translated to Z3. Following Brass & Goldberg, 2006
        (error #8), three groups of checks are performed:

        1. the whole formula is unsatisfiable ('contradiction') or its negation
           is unsatisfiable ('tautology');
        2. a disjunct Ci can never hold unless another disjunct holds
           ('redundant_disjunct');
        3. a conjunct Ai,j of Ci is implied by the other conjuncts of Ci whenever
           no other disjunct holds ('redundant_conjunct').

        SELECTs without a WHERE clause, or whose condition cannot be translated
        to Z3, are skipped.
        '''
        results: list[DetectedError] = []
        error = SqlErrors.IMPLIED_TAUTOLOGICAL_OR_INCONSISTENT_EXPRESSION

        for select in self.query.selects:
            where = select.where

            if not where:
                continue

            # Build Z3 variables from catalog
            variables = {}
            for table in select.referenced_tables:
                variables.update(smt.catalog_table_to_z3_vars(table))

            dnf = util.ast.extract_DNF(where)

            # (1) whole formula
            try:
                # Each disjunct is converted once and reused by checks (2) and (3).
                disjuncts_z3 = [smt.sql_to_z3(C, variables) for C in dnf]
                whole = Or(*disjuncts_z3)
            except Exception:
                continue    # skip if cannot convert to z3

            if not smt.is_satisfiable(whole):
                results.append(DetectedError(error, ('contradiction',)))
            elif not smt.is_satisfiable(Not(whole)):
                results.append(DetectedError(error, ('tautology',)))

            for i, Ci in enumerate(dnf):
                Ci_z3 = disjuncts_z3[i]
                other_disjuncts = disjuncts_z3[:i] + disjuncts_z3[i + 1:]

                # Constraint "no other disjunct holds". With a single disjunct there
                # is nothing to negate: the constraint is trivially true, so it is
                # left out instead of calling Or() with no arguments.
                none_of_the_others = [Not(Or(*other_disjuncts))] if other_disjuncts else []

                # (2) each Ci redundant?
                if not smt.is_satisfiable(And(Ci_z3, *none_of_the_others)):
                    results.append(DetectedError(error, ('redundant_disjunct', Ci.sql())))

                # (3) each Ai,j redundant?
                conjuncts = list(Ci.flatten())
                conjuncts_z3 = [smt.sql_to_z3(A, variables) for A in conjuncts]
                # Only boolean sub-expressions can take part in the formula.
                is_bool = [smt.is_bool_expr(A_z3) for A_z3 in conjuncts_z3]

                for j, Aj in enumerate(conjuncts):
                    if not is_bool[j]:
                        continue

                    rest = [A_z3 for k, A_z3 in enumerate(conjuncts_z3) if k != j and is_bool[k]]

                    # Aj is redundant when it cannot be false while the remaining
                    # conjuncts hold and no other disjunct holds.
                    formula = And(Not(conjuncts_z3[j]), *rest, *none_of_the_others)
                    if not smt.is_satisfiable(formula):
                        results.append(DetectedError(error, ('redundant_conjunct', (Ci.sql(), Aj.sql()))))

        return results

    @staticmethod
    def _distinct_aggregates(ast, func_type) -> list:
        '''Return the `func_type` nodes found in `ast` whose argument is a DISTINCT expression.'''
        return [
            func for func in ast.find_all(func_type)
            if func.this and isinstance(func.this, exp.Distinct)
        ]

    def detect_41_distinct_in_sum_or_avg(self) -> list[DetectedError]:
        '''
        Detect SUM(DISTINCT ...) or AVG(DISTINCT ...)

        If the correct query uses SUM(DISTINCT ...) or AVG(DISTINCT ...), then
        the user query is unlikely to be incorrect, so we do not flag it.
        SUM and AVG are handled independently of each other.
        '''
        results: list[DetectedError] = []

        # Aggregates for which some correct solution already uses DISTINCT
        allowed: set = set()

        # First check the correct solutions
        for solution in self.solutions:
            for select in solution.selects:
                ast = select.ast

                if not ast:
                    continue

                for func_type in (exp.Sum, exp.Avg):
                    if self._distinct_aggregates(ast, func_type):
                        allowed.add(func_type)

        # Then check the user query (SUM occurrences are reported before AVG ones)
        for select in self.query.selects:
            ast = select.ast

            if not ast:
                continue

            for func_type in (exp.Sum, exp.Avg):
                if func_type in allowed:
                    continue

                for func in self._distinct_aggregates(ast, func_type):
                    results.append(DetectedError(SqlErrors.DISTINCT_IN_SUM_OR_AVG, (func.sql(),)))

        return results

    def detect_42_distinct_removing_important_duplicates(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_45_mixing_comparison_and_null(self) -> list[DetectedError]:
        '''
        Detect mixing of >0 with IS NOT NULL or empty string with IS NULL on the same column

        Not implemented yet: always returns an empty list.
        '''
        # TODO: implement on top of the parsed query. The patterns to detect are
        #   a > 0 AND a IS NOT NULL
        #   a = '' AND a IS NULL
        # The previous draft was unreachable: it applied regular expressions to
        # `self.query` as if it were a string and collected plain tuples rather
        # than DetectedError instances.
        return []

    def detect_46_null_in_InAnyAll_subquery(self) -> list[DetectedError]:
        '''
        Detect potential NULL/UNKNOWN in IN/ANY/ALL subqueries when subquery column is nullable.
        Heuristically assume that if a column is not declared as NOT NULL, then every typical
        database state contains at least one row in which it is null.

        Not implemented yet: always returns an empty list.
        '''
        return []

    def detect_47_join_condition_on_unmatchable_column(self) -> list[DetectedError]:
        '''
        For each JOIN … ON: require at least one “A.col = B.col” in the ON clause.
        For comma-style joins (FROM A, B): require at least one “A.col = B.col” in the WHERE.
        If no such predicate is found for a given join, emit SEM_2_JOIN_ON_INCORRECT_COLUMN.
        If the join operation is a self-join, then skip the check.
        Check based on the content of the catalog column_metadata the compatibility of the columns.

        Not implemented yet: always returns an empty list.
        '''
        return []

    def detect_49_many_duplicates(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_50_constant_column_output(self) -> list[DetectedError]:
        '''
        Detect if the output of the query contains a constant value.
        Exclude constants that are likely intentional, such as SELECT 1, SELECT 'constant', etc.
        Also exclude aggregation functions that return constants, such as COUNT(*), SUM(*), etc.

        Not implemented yet: always returns an empty list.
        '''
        # TODO: revise and implement. Flagging every output column whose
        # `is_constant` attribute is set (the previous, unreachable draft) is
        # incorrect, since it selects only intentional constants.
        return []

    @staticmethod
    def _equate_columns(classes: list[set[str]], left: str, right: str) -> None:
        '''Record in `classes` (a list of disjoint equivalence classes) that `left` and `right` are equal.'''
        left_class = next((c for c in classes if left in c), None)
        right_class = next((c for c in classes if right in c), None)

        if left_class is None and right_class is None:
            classes.append({left, right})
        elif left_class is None:
            right_class.add(left)
        elif right_class is None:
            left_class.add(right)
        elif left_class is not right_class:
            # Merge the two classes; the absorbed one is dropped by identity
            left_class.update(right_class)
            classes[:] = [c for c in classes if c is not right_class]
        # else: both columns already belong to the same class, nothing to do

    def detect_51_duplicate_column_output(self) -> list[DetectedError]:
        '''
        Detects if the same column or expression appears multiple times in the SELECT list.
        Also include columns that are equated.

        Columns are identified as '<table index>.<real column name>'. Table
        indices are relative to the SELECT they appear in, so both the projected
        columns and the equivalence classes are tracked separately for each SELECT.
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            projected_columns: set[str] = set()

            # list of equivalence classes of columns that are equated in the WHERE clause
            # (e.g. A.id = B.id means A.id and B.id are equivalent for the purpose of duplicate detection)
            column_equivalences: list[set[str]] = []

            if select.where:
                for left, right in util.ast.extract_column_equalities(select.where):
                    left_name = util.ast.column.get_real_name(left)
                    left_idx = select._get_table_idx_for_column(left)

                    right_name = util.ast.column.get_real_name(right)
                    right_idx = select._get_table_idx_for_column(right)

                    # Equalities involving a column with no table index are ignored
                    if left_idx is None or right_idx is None:
                        continue

                    self._equate_columns(
                        column_equivalences,
                        f'{left_idx}.{left_name}',
                        f'{right_idx}.{right_name}',
                    )

            for column in select.output.columns:
                table_idx = column.table_idx

                if table_idx is None:
                    # TODO: handle expressions and constants in the SELECT list
                    continue    # skip if no table reference (e.g. constant or computed column)

                name = f'{table_idx}.{column.real_name}'

                if name in projected_columns:
                    results.append(DetectedError(
                        SqlErrors.DUPLICATE_COLUMN_OUTPUT,
                        (select.referenced_tables[table_idx].name, column.real_name),
                    ))

                # Remember the column and everything equated to it, so that projecting
                # an equivalent column later is reported as a duplicate too.
                projected_columns.add(name)
                for eq_class in column_equivalences:
                    if name in eq_class:
                        projected_columns.update(eq_class)

        return results
287
+
288
+
289
+