sqlchecker 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlchecker/__init__.py +58 -0
- sqlchecker/detectors/__init__.py +103 -0
- sqlchecker/detectors/base.py +44 -0
- sqlchecker/detectors/complications.py +375 -0
- sqlchecker/detectors/logical.py +732 -0
- sqlchecker/detectors/semantic.py +289 -0
- sqlchecker/detectors/syntax.py +1140 -0
- sqlchecker-0.3.1.dist-info/METADATA +153 -0
- sqlchecker-0.3.1.dist-info/RECORD +11 -0
- sqlchecker-0.3.1.dist-info/WHEEL +4 -0
- sqlchecker-0.3.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
'''Detector for semantic errors in SQL queries.'''
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Callable
|
|
5
|
+
from sqlglot import exp
|
|
6
|
+
from z3 import Not, Or, And
|
|
7
|
+
from sqlerrors import SqlErrors
|
|
8
|
+
from sqlscope import util
|
|
9
|
+
from sqlscope.query import Query, smt
|
|
10
|
+
|
|
11
|
+
from .base import BaseDetector, DetectedError
|
|
12
|
+
|
|
13
|
+
class SemanticErrorDetector(BaseDetector):
    '''Detector for semantic errors in SQL queries.'''

    def __init__(self,
                 *,
                 query: Query,
                 update_query: Callable[[str, str | None], None],
                 solutions: list[Query] | None = None,
                 ):
        '''
        Args:
            query: the user query to analyse.
            update_query: callback forwarded unchanged to `BaseDetector`.
            solutions: correct reference solutions, if any. When omitted, each
                instance receives its own fresh empty list (a `[]` default would
                be a single list shared by every instance).
        '''
        super().__init__(
            query=query,
            solutions=[] if solutions is None else solutions,
            update_query=update_query,
        )

    def run(self) -> list[DetectedError]:
        '''Run the base detector checks, then every semantic check, and return all detected errors.'''
        results: list[DetectedError] = super().run()

        checks = [
            self.detect_40_tautological_or_inconsistent_expression,  # ok
            self.detect_41_distinct_in_sum_or_avg,                   # ok
            self.detect_42_distinct_removing_important_duplicates,   # TODO: implement
            self.detect_45_mixing_comparison_and_null,               # TODO: refactor/implement
            self.detect_46_null_in_InAnyAll_subquery,                # TODO: implement
            self.detect_47_join_condition_on_unmatchable_column,     # TODO: implement
            self.detect_49_many_duplicates,                          # TODO: implement
            self.detect_50_constant_column_output,                   # TODO: revise and implement
            self.detect_51_duplicate_column_output,                  # ok
        ]

        for chk in checks:
            results.extend(chk())

        return results

    # ------------------------------------------------------------------ #
    # 40: tautological, inconsistent or redundant WHERE conditions
    # ------------------------------------------------------------------ #

    def detect_40_tautological_or_inconsistent_expression(self) -> list[DetectedError]:
        '''
        Detect WHERE conditions that are contradictory, tautological or redundant.

        Refer to Brass & Goldberg, 2006 for these checks (error #8). The WHERE clause
        of each SELECT is converted to disjunctive normal form C_1 OR ... OR C_n, where
        each C_i is a conjunction of atoms A_i,j, and translated to Z3. Then:

        1. the whole formula is reported as a 'contradiction' if unsatisfiable, or as a
           'tautology' if its negation is unsatisfiable;
        2. a disjunct C_i is reported as 'redundant_disjunct' if it implies the
           disjunction of the other disjuncts;
        3. an atom A_i,j is reported as 'redundant_conjunct' if it is implied by the
           other atoms of C_i together with the negation of the other disjuncts.

        SELECTs whose DNF cannot be translated to Z3 are skipped. Atoms that cannot be
        translated to a boolean Z3 expression are ignored in check (3).
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            where = select.where

            if not where:
                continue

            # Build Z3 variables from catalog
            variables = {}
            for table in select.referenced_tables:
                variables.update(smt.catalog_table_to_z3_vars(table))

            dnf = util.ast.extract_DNF(where)

            # (1) whole formula
            try:
                clauses_z3 = [smt.sql_to_z3(C, variables) for C in dnf]
                whole = Or(*clauses_z3)
            except Exception:
                continue  # skip if cannot convert to z3

            if not smt.is_satisfiable(whole):
                results.append(DetectedError(SqlErrors.IMPLIED_TAUTOLOGICAL_OR_INCONSISTENT_EXPRESSION, ('contradiction',)))
            elif not smt.is_satisfiable(Not(whole)):
                results.append(DetectedError(SqlErrors.IMPLIED_TAUTOLOGICAL_OR_INCONSISTENT_EXPRESSION, ('tautology',)))

            # NOTE(review): when (1) reports a contradiction, (2) and (3) will typically
            # fire as well, since an unsatisfiable clause trivially implies anything.
            for i, Ci in enumerate(dnf):
                Ci_z3 = clauses_z3[i]
                other_clauses = [C for k, C in enumerate(clauses_z3) if k != i]

                # z3's Or() requires at least one argument. With no other disjuncts the
                # empty disjunction is FALSE, so NOT(others) is TRUE and can be dropped.
                negated_others = [Not(Or(*other_clauses))] if other_clauses else []

                # (2) each Ci redundant?
                # With a single disjunct this reduces to "Ci is unsatisfiable",
                # which check (1) already covers.
                if negated_others and not smt.is_satisfiable(And(Ci_z3, *negated_others)):
                    results.append(DetectedError(SqlErrors.IMPLIED_TAUTOLOGICAL_OR_INCONSISTENT_EXPRESSION, ('redundant_disjunct', Ci.sql())))

                # (3) each Ai,j redundant?
                conjuncts = self._conjuncts_of(Ci)
                if len(conjuncts) < 2:
                    # Dropping the only atom of Ci turns it into TRUE, i.e. the whole
                    # condition into a tautology, which check (1) already covers.
                    continue

                conjuncts_z3 = [self._to_bool_z3(c, variables) for c in conjuncts]

                for j, (Aj, Aj_z3) in enumerate(zip(conjuncts, conjuncts_z3)):
                    if Aj_z3 is None:
                        continue
                    rest = [c for k, c in enumerate(conjuncts_z3) if k != j and c is not None]
                    formula = And(Not(Aj_z3), *rest, *negated_others)
                    if not smt.is_satisfiable(formula):
                        results.append(DetectedError(SqlErrors.IMPLIED_TAUTOLOGICAL_OR_INCONSISTENT_EXPRESSION, ('redundant_conjunct', (Ci.sql(), Aj.sql()))))

        return results

    @staticmethod
    def _conjuncts_of(clause: exp.Expression) -> list[exp.Expression]:
        '''
        Return the atoms of a DNF clause.

        `flatten()` yields the operands of same-typed nested nodes, so it is only
        meaningful on an AND chain: on e.g. `a > 5` it would yield `a` and `5`.
        Any other clause is treated as a single atom.
        '''
        node = clause.unnest()
        if isinstance(node, exp.And):
            return list(node.flatten())
        return [node]

    @staticmethod
    def _to_bool_z3(node: exp.Expression, variables: dict):
        '''Translate `node` to Z3; return None if translation fails or does not yield a boolean expression.'''
        try:
            node_z3 = smt.sql_to_z3(node, variables)
        except Exception:
            return None
        if node_z3 is None or not smt.is_bool_expr(node_z3):
            return None
        return node_z3

    # ------------------------------------------------------------------ #
    # 41: DISTINCT inside SUM/AVG
    # ------------------------------------------------------------------ #

    def detect_41_distinct_in_sum_or_avg(self) -> list[DetectedError]:
        '''
        Detect SUM(DISTINCT ...) or AVG(DISTINCT ...) in the user query.

        SUM and AVG are handled independently. If any correct solution uses
        SUM(DISTINCT ...), then SUM(DISTINCT ...) in the user query is assumed to be
        intentional and is not flagged. The same applies to AVG.
        '''
        results: list[DetectedError] = []
        aggregate_types = (exp.Sum, exp.Avg)

        # Aggregate types whose DISTINCT form appears in at least one correct solution
        allowed_distinct: set[type] = set()
        for solution in self.solutions:
            for select in solution.selects:
                ast = select.ast

                if not ast:
                    continue

                for func_type in aggregate_types:
                    if any(self._is_distinct_aggregate(func) for func in ast.find_all(func_type)):
                        allowed_distinct.add(func_type)

        # Then check the user query
        for select in self.query.selects:
            ast = select.ast

            if not ast:
                continue

            for func_type in aggregate_types:
                if func_type in allowed_distinct:
                    continue
                for func in ast.find_all(func_type):
                    if self._is_distinct_aggregate(func):
                        results.append(DetectedError(SqlErrors.DISTINCT_IN_SUM_OR_AVG, (func.sql(),)))

        return results

    @staticmethod
    def _is_distinct_aggregate(func: exp.Expression) -> bool:
        '''Return True if the aggregate's argument is a DISTINCT node, e.g. SUM(DISTINCT x).'''
        return isinstance(func.this, exp.Distinct)

    # ------------------------------------------------------------------ #
    # Not yet implemented checks
    # ------------------------------------------------------------------ #

    def detect_42_distinct_removing_important_duplicates(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_45_mixing_comparison_and_null(self) -> list[DetectedError]:
        '''
        Detect mixing `a > 0` with `a IS NOT NULL`, or `a = ''` with `a IS NULL`,
        on the same column.

        Not implemented yet: always returns an empty list.
        '''
        # TODO: implement on the sqlglot AST of each SELECT. An earlier, unreachable draft
        # regex-matched "<col> > 0 AND <col> IS NOT NULL" and "<col> = '' AND <col> IS NULL".
        # It could not work as written: it searched `self.query` (a Query object, per the
        # constructor annotation) instead of SQL text, and it appended plain tuples rather
        # than DetectedError. It also reported the IS NULL case with the IS NOT NULL
        # error code.
        return []

    def detect_46_null_in_InAnyAll_subquery(self) -> list[DetectedError]:
        '''
        Detect potential NULL/UNKNOWN results in IN/ANY/ALL subqueries when the
        subquery column is nullable.

        Heuristically assumes that if a column is not declared NOT NULL, every typical
        database state contains at least one row in which it is NULL.

        Not implemented yet: always returns an empty list.
        '''
        return []

    def detect_47_join_condition_on_unmatchable_column(self) -> list[DetectedError]:
        '''
        For each JOIN ... ON: require at least one "A.col = B.col" in the ON clause.
        For comma-style joins (FROM A, B): require at least one "A.col = B.col" in the WHERE.
        If no such predicate is found for a given join, emit SEM_2_JOIN_ON_INCORRECT_COLUMN.
        Self-joins are skipped.
        Column compatibility is checked against the catalog column metadata.

        Not implemented yet: always returns an empty list.
        '''
        return []

    def detect_49_many_duplicates(self) -> list[DetectedError]:
        '''Not implemented yet: always returns an empty list.'''
        return []

    def detect_50_constant_column_output(self) -> list[DetectedError]:
        '''
        Detect if the output of the query contains a constant value.
        Exclude constants that are likely intentional, such as SELECT 1, SELECT 'constant', etc.
        Also exclude aggregation functions that return constants, such as COUNT(*).

        Not implemented yet: always returns an empty list.
        '''
        # TODO: revise and implement. An earlier, unreachable draft reported every column
        # of `self.query.main_query.output` whose `is_constant` flag was set. That was
        # incorrect, since it selected only the intentional constants this check should
        # exclude.
        return []

    # ------------------------------------------------------------------ #
    # 51: duplicate columns in the SELECT list
    # ------------------------------------------------------------------ #

    def detect_51_duplicate_column_output(self) -> list[DetectedError]:
        '''
        Detect if the same column appears multiple times in the SELECT list.
        Columns equated in the WHERE clause count as the same column: with
        `WHERE A.id = B.id`, projecting both A.id and B.id is reported.

        Each SELECT is analysed independently. Column keys are built from table
        indexes local to that SELECT, so no state may be shared between SELECTs.
        Expressions and constants in the SELECT list are currently ignored.
        '''
        results: list[DetectedError] = []

        for select in self.query.selects:
            column_equivalences = self._column_equivalences(select)
            projected_columns: set[str] = set()

            for column in select.output.columns:
                table_idx = column.table_idx

                if table_idx is None:
                    # TODO: handle expressions and constants in the SELECT list
                    continue  # skip if no table reference (e.g. constant or computed column)

                name = f'{table_idx}.{column.real_name}'

                if name in projected_columns:
                    results.append(DetectedError(SqlErrors.DUPLICATE_COLUMN_OUTPUT, (select.referenced_tables[table_idx].name, column.real_name)))

                # Mark this column and every column equated with it as already projected
                projected_columns.add(name)
                for eq_class in column_equivalences:
                    if name in eq_class:
                        projected_columns.update(eq_class)

        return results

    @staticmethod
    def _column_equivalences(select) -> list[set[str]]:
        '''
        Group the columns equated in the WHERE clause of `select` into disjoint
        equivalence classes of "<table_idx>.<real_name>" keys.

        Equalities whose columns cannot be resolved to a table of `select` are ignored.
        '''
        classes: list[set[str]] = []

        if not select.where:
            return classes

        for left, right in util.ast.extract_column_equalities(select.where):
            left_idx = select._get_table_idx_for_column(left)
            right_idx = select._get_table_idx_for_column(right)

            if left_idx is None or right_idx is None:
                continue

            left_full = f'{left_idx}.{util.ast.column.get_real_name(left)}'
            right_full = f'{right_idx}.{util.ast.column.get_real_name(right)}'

            # Merge the new pair with every existing class that touches it. Classes stay
            # disjoint, so each column belongs to exactly one class.
            merged = {left_full, right_full}
            untouched: list[set[str]] = []
            for eq_class in classes:
                if left_full in eq_class or right_full in eq_class:
                    merged |= eq_class
                else:
                    untouched.append(eq_class)
            untouched.append(merged)
            classes = untouched

        return classes
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
|