sql-error-categorizer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,493 @@
1
+ import difflib
2
+ import re
3
+ import sqlparse
4
+ import sqlparse.keywords
5
+ from typing import Callable
6
+ from sqlglot import exp
7
+ from z3 import Solver, Not, unsat, Or, And, BoolSort, is_expr
8
+ import sqlglot
9
+
10
+
11
+ from .base import BaseDetector, DetectedError
12
+ from ..query import Query, smt, extract_DNF
13
+ from ..sql_errors import SqlErrors
14
+ from ..catalog import Catalog
15
+
16
+ class SemanticErrorDetector(BaseDetector):
17
def __init__(self,
             *,
             query: Query,
             update_query: Callable[[str, str | None], None],
             solutions: list[Query] | None = None,
             ):
    '''
    Create a detector for semantic errors.

    All arguments are keyword-only and are forwarded to `BaseDetector.__init__`.

    Args:
        query: The query to analyse.
        update_query: Callback handed to the base class unchanged.
        solutions: Correct queries the checks compare the analysed query against.
            When omitted (or `None`), a new empty list is used.
    '''
    super().__init__(
        query=query,
        # Build the fallback per call: a `[]` default in the signature is a single
        # list object shared by every detector created without `solutions`.
        solutions=[] if solutions is None else solutions,
        update_query=update_query,
    )
28
+
29
def run(self) -> list[DetectedError]:
    '''
    Run the base detector, then every semantic check from SEM-39 to SEM-51.

    Returns:
        The errors returned by `super().run()`, followed by the errors of each
        `sem_*` check in ascending check number.
    '''
    findings: list[DetectedError] = super().run()

    # The tuple fixes the execution order, and therefore the order of the findings.
    for check in (
        self.sem_39_and_instead_of_or,
        self.sem_40_tautological_or_inconsistent_expression,
        self.sem_41_distinct_in_sum_or_avg,
        self.sem_42_distinct_removing_important_duplicates,
        self.sem_43_wildcards_without_like,
        self.sem_44_incorrect_wildcard,
        self.sem_45_mixing_comparison_and_null,
        self.sem_46_null_in_in_subquery,
        self.sem_47_join_on_incorrect_column,
        self.sem_48_missing_join,
        self.sem_49_duplicate_rows,
        self.sem_50_constant_column_output,
        self.sem_51_duplicate_column_output,
    ):
        findings += check()

    return findings
52
+
53
def sem_39_and_instead_of_or(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-39: AND written where OR was meant in a WHERE condition,
    which produces an empty result set.

    No detection logic exists yet, so the returned list is always empty.
    '''
    findings: list[DetectedError] = []
    return findings
56
+
57
def sem_40_tautological_or_inconsistent_expression(self) -> list[DetectedError]:
    '''
    Detect WHERE conditions that are inconsistent, tautological or partly redundant.

    The WHERE clause of each SELECT is turned into a list of disjuncts by
    `extract_DNF`, translated with `smt.sql_to_z3` and tested with
    `smt.is_satisfiable` (checks from Brass & Goldberg, 2006, error #8):

    1. Whole formula: reported as 'contradiction' when it is not satisfiable,
       otherwise as 'tautology' when its negation is not satisfiable.
    2. Each disjunct Ci: reported as 'redundant_disjunct' when
       `Ci AND NOT(other disjuncts)` is not satisfiable.
    3. Each boolean part Aj yielded by `Ci.flatten()`: reported as
       'redundant_conjunct' when
       `NOT Aj AND (other boolean parts) AND NOT(other disjuncts)` is not satisfiable.

    SELECTs without a WHERE clause, or whose disjuncts cannot be translated to Z3,
    are skipped.

    Returns:
        One `DetectedError` per finding; its data tuple starts with the finding kind.
    '''
    results: list[DetectedError] = []
    error = SqlErrors.SEM_40_TAUTOLOGICAL_OR_INCONSISTENT_EXPRESSION

    for select in self.query.selects:
        where = select.where

        if not where:
            continue

        # Build Z3 variables from catalog
        variables = {}
        for table in select.referenced_tables:
            variables.update(smt.catalog_table_to_z3_vars(table))

        dnf = extract_DNF(where)

        # (1) whole formula
        # Each disjunct is translated exactly once and reused by checks (2) and (3).
        try:
            disjuncts_z3 = [smt.sql_to_z3(C, variables) for C in dnf]
            whole = Or(*disjuncts_z3)
        except Exception:
            continue  # skip if cannot convert to z3

        if not smt.is_satisfiable(whole):
            results.append(DetectedError(error, ('contradiction',)))
        elif not smt.is_satisfiable(Not(whole)):
            results.append(DetectedError(error, ('tautology',)))

        for i, Ci in enumerate(dnf):
            Ci_z3 = disjuncts_z3[i]

            # "No other disjunct holds" depends only on i, so build it once per
            # disjunct instead of once per conjunct.
            others = Or(*[d for k, d in enumerate(disjuncts_z3) if k != i])
            not_others = Not(others)

            # (2) each Ci redundant?
            if not smt.is_satisfiable(And(Ci_z3, not_others)):
                results.append(DetectedError(error, ('redundant_disjunct', Ci.sql())))

            # (3) each Ai,j redundant?
            conjuncts = list(Ci.flatten())
            conjuncts_z3 = [smt.sql_to_z3(c, variables) for c in conjuncts]
            # flatten() can also yield non-boolean nodes; those are neither tested
            # nor used as premises.
            is_bool = [smt.is_bool_expr(z) for z in conjuncts_z3]

            for j, Aj in enumerate(conjuncts):
                if not is_bool[j]:
                    continue
                rest = [z for k, z in enumerate(conjuncts_z3) if k != j and is_bool[k]]
                formula = And(Not(conjuncts_z3[j]), *rest, not_others)
                if not smt.is_satisfiable(formula):
                    results.append(DetectedError(error, ('redundant_conjunct', (Ci.sql(), Aj.sql()))))

    return results
106
+
107
def sem_41_distinct_in_sum_or_avg(self) -> list[DetectedError]:
    '''
    Report SUM(DISTINCT ...) and AVG(DISTINCT ...) calls found in the user query.

    SUM and AVG are handled independently: an aggregate is reported only if no
    solution applies DISTINCT inside that same aggregate, because in that case
    the user query is unlikely to be wrong.

    Returns:
        One `DetectedError` per offending call, carrying the SQL of the call.
        Within each SELECT, SUM findings come before AVG findings.
    '''

    def parsed_selects(query):
        # ASTs of the SELECTs of `query`, skipping those without an AST.
        for select in query.selects:
            tree = select.ast
            if tree:
                yield tree

    def distinct_calls(tree, aggregate):
        # Calls of `aggregate` in `tree` whose direct argument is a DISTINCT node.
        return [
            call
            for call in tree.find_all(aggregate)
            if call.this and isinstance(call.this, exp.Distinct)
        ]

    aggregates = (exp.Sum, exp.Avg)

    # Aggregates that some solution uses with DISTINCT; these are never reported.
    tolerated = set()
    for solution in self.solutions:
        for tree in parsed_selects(solution):
            for aggregate in aggregates:
                if distinct_calls(tree, aggregate):
                    tolerated.add(aggregate)

    results: list[DetectedError] = []
    for tree in parsed_selects(self.query):
        for aggregate in aggregates:
            if aggregate in tolerated:
                continue
            for call in distinct_calls(tree, aggregate):
                results.append(DetectedError(SqlErrors.SEM_41_DISTINCT_IN_SUM_OR_AVG, (call.sql(),)))

    return results
157
+
158
+
159
+ # TODO: implement
160
def sem_42_distinct_removing_important_duplicates(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-42 (DISTINCT removing important duplicates).

    Not implemented yet: the returned list is always empty.
    '''
    findings: list[DetectedError] = []
    return findings
162
+
163
def sem_43_wildcards_without_like(self) -> list[DetectedError]:
    '''
    Report `=` comparisons against a literal containing a LIKE wildcard
    ('%' or '_'), e.g. `= '%...%'` written instead of LIKE.

    A wildcard character is tolerated when some solution also compares with `=`
    against a literal containing that character, since the user query is then
    unlikely to be wrong.

    Returns:
        At most one `DetectedError` per comparison, carrying the comparison text.
    '''
    wildcards = '_%'

    # Wildcards that appear in a literal operand of an `=` comparison in any solution.
    tolerated = set()
    for solution in self.solutions:
        for select in solution.selects:
            tree = select.ast
            if not tree:
                continue
            for comparison in tree.find_all(exp.EQ):
                for operand in (comparison.this, comparison.expression):
                    if not isinstance(operand, exp.Literal):
                        continue
                    for wildcard in wildcards:
                        if has_character(operand, wildcard):
                            tolerated.add(wildcard)

    # Characters that still count as a mistake; has_character matches any of them
    # (and nothing at all when the string is empty).
    suspicious = ''.join(w for w in wildcards if w not in tolerated)

    results: list[DetectedError] = []
    for select in self.query.selects:
        tree = select.ast
        if not tree:
            continue
        for comparison in tree.find_all(exp.EQ):
            operands = (comparison.this, comparison.expression)
            # A comparison is reported once, whichever side or wildcard triggers it.
            if any(isinstance(o, exp.Literal) and has_character(o, suspicious) for o in operands):
                results.append(DetectedError(SqlErrors.SEM_43_WILDCARDS_WITHOUT_LIKE, (str(comparison),)))

    return results
227
+
228
def sem_44_incorrect_wildcard(self) -> list[DetectedError]:
    '''
    Report LIKE patterns that appear to use the wrong wildcard:

    - '*' or '?' (not SQL wildcards),
    - '_' where the solutions only use '%',
    - '%' where the solutions only use '_'.

    A character that also occurs in a LIKE pattern of some solution is not
    reported. Without solutions only '*' and '?' are reported, once per LIKE.

    Returns:
        One `DetectedError` per triggered rule, carrying the LIKE expression text
        (a single LIKE can therefore be reported more than once).
    '''
    symbols = '_%*?'

    def literal_like_patterns(query):
        # (LIKE node, pattern) pairs of `query` whose pattern is a literal.
        for select in query.selects:
            tree = select.ast
            if not tree:
                continue
            for like in tree.find_all(exp.Like):
                pattern = like.expression
                if isinstance(pattern, exp.Literal):
                    yield like, pattern

    def flag(like):
        return DetectedError(SqlErrors.SEM_44_INCORRECT_WILDCARD, (str(like),))

    # Characters of interest that occur in LIKE patterns of the solutions.
    in_solutions = {
        symbol
        for solution in self.solutions
        for _, pattern in literal_like_patterns(solution)
        for symbol in symbols
        if has_character(pattern, symbol)
    }
    only_percent_expected = '%' in in_solutions and '_' not in in_solutions
    only_underscore_expected = '_' in in_solutions and '%' not in in_solutions

    results: list[DetectedError] = []
    for like, pattern in literal_like_patterns(self.query):
        used = {symbol: has_character(pattern, symbol) for symbol in symbols}

        if not self.solutions:
            # Nothing to compare against: fall back to detecting '*' or '?' usage.
            if used['*'] or used['?']:
                results.append(flag(like))
            continue

        triggered_rules = (
            # '*' absent from the solutions: most likely meant as '%'
            '*' not in in_solutions and used['*'],
            # '?' absent from the solutions: most likely meant as '_'
            '?' not in in_solutions and used['?'],
            # '_' instead of '%'
            only_percent_expected and used['_'] and not used['%'],
            # '%' instead of '_'
            only_underscore_expected and used['%'] and not used['_'],
        )
        for triggered in triggered_rules:
            if triggered:
                results.append(flag(like))

    return results
306
+
307
+ # TODO: refactor
308
def sem_45_mixing_comparison_and_null(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-45: a comparison mixed with a NULL test on the same column,
    i.e. `a > 0 AND a IS NOT NULL` or `a = '' AND a IS NULL`.

    The check is currently disabled and always returns an empty list.

    TODO: reimplement on top of the parsed query. The earlier regex-based draft
    was unreachable (it followed the early return) and has been removed: it ran
    `re.search` directly on `self.query` and appended plain `(error, text)` tuples
    instead of `DetectedError` objects, so it could not simply be re-enabled.
    '''
    return []
330
+
331
+ #TODO: implement
332
def sem_46_null_in_in_subquery(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-46: possible NULL/UNKNOWN results of IN/ANY/ALL
    subqueries whose selected column is nullable.

    Intended heuristic: a column not declared NOT NULL is assumed to contain a
    NULL in at least one row of every typical database state.

    Not implemented yet: the returned list is always empty.
    '''
    findings: list[DetectedError] = []
    return findings
337
+
338
+ # TODO: implement
339
def sem_47_join_on_incorrect_column(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-47 (join on incorrect column).

    Intended behaviour:
    - every `JOIN ... ON` needs at least one `A.col = B.col` predicate in its ON clause;
    - comma-style joins (`FROM A, B`) need at least one `A.col = B.col` in the WHERE;
    - a join without such a predicate is reported;
    - self-joins are skipped;
    - column compatibility is judged from the catalog column_metadata.

    NOTE(review): the original note named the error SEM_2_JOIN_ON_INCORRECT_COLUMN;
    confirm the `SqlErrors` member to emit before implementing.

    Not implemented yet: the returned list is always empty.
    '''
    findings: list[DetectedError] = []
    return findings
348
+
349
+ # TODO: implement
350
def sem_48_missing_join(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-48 (missing join).

    Not implemented yet: the returned list is always empty.
    '''
    findings: list[DetectedError] = []
    return findings
352
+
353
+ # TODO: implement
354
def sem_49_duplicate_rows(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-49 (duplicate rows).

    Not implemented yet: the returned list is always empty.
    '''
    findings: list[DetectedError] = []
    return findings
356
+
357
+ # TODO: refactor
358
def sem_50_constant_column_output(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-50: a SELECT-list column that is constrained to a constant.

    Intended behaviour:
    - If WHERE has A = c and A is in SELECT, warn.
    - If WHERE has A = c and also A = B, then both A and B in SELECT should warn.

    The check is currently disabled and always returns an empty list.

    TODO: reimplement on top of the parsed query. The earlier regex-based draft
    was unreachable (it followed the early return) and has been removed: it read
    `self.query_map`, ran `re.search` directly on `self.query` and appended plain
    `(error, message)` tuples instead of `DetectedError` objects. Its approach,
    kept here as a starting point:
      1. collect the plain columns of the SELECT list (no `*`, no function calls);
      2. collect `column = constant` predicates of the outer WHERE, ignoring subqueries;
      3. link columns compared to each other with `=`;
      4. propagate "constant" across those links (breadth-first);
      5. report every selected column that ends up constant.
    '''
    return []
440
+
441
+ # TODO: refactor
442
def sem_51_duplicate_column_output(self) -> list[DetectedError]:
    '''
    Placeholder for SEM-51: the same column or expression appearing more than
    once in the SELECT list.

    The check is currently disabled and always returns an empty list.

    TODO: reimplement on top of the parsed query. The earlier draft was
    unreachable (it followed the early return) and has been removed: it read
    `self.query_map` and appended plain `(error, message)` tuples instead of
    `DetectedError` objects. Its approach, kept here as a starting point:
      1. normalise every SELECT-list expression (strip the alias, trim, lower-case);
      2. count the occurrences of each normalised expression;
      3. report every expression that occurs more than once.
    '''
    return []
479
+
480
+
481
+ # region Helper methods
482
def has_character(literal: exp.Literal, chars: str) -> bool:
    '''
    Tell whether the value of `literal` contains at least one character of `chars`.

    Returns False when `literal.this` is not a string or when `chars` is empty.
    '''
    text = literal.this
    if isinstance(text, str):
        # The two character collections overlap exactly when some character of
        # `chars` occurs in the text.
        return not set(chars).isdisjoint(text)
    return False
493
+ # endregion