sqlchecker 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlchecker/__init__.py +58 -0
- sqlchecker/detectors/__init__.py +103 -0
- sqlchecker/detectors/base.py +44 -0
- sqlchecker/detectors/complications.py +375 -0
- sqlchecker/detectors/logical.py +732 -0
- sqlchecker/detectors/semantic.py +289 -0
- sqlchecker/detectors/syntax.py +1140 -0
- sqlchecker-0.3.1.dist-info/METADATA +153 -0
- sqlchecker-0.3.1.dist-info/RECORD +11 -0
- sqlchecker-0.3.1.dist-info/WHEEL +4 -0
- sqlchecker-0.3.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1140 @@
|
|
|
1
|
+
'''Detector for syntax errors in SQL queries.'''
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
import difflib
|
|
5
|
+
import re
|
|
6
|
+
import sqlparse
|
|
7
|
+
from sqlglot import exp
|
|
8
|
+
from typing import Callable
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
from sqlerrors import SqlErrors
|
|
11
|
+
from sqlscope import Query
|
|
12
|
+
from sqlscope.query.set_operations.set_operation import SetOperation
|
|
13
|
+
from sqlscope.query.typechecking import get_type, collect_errors
|
|
14
|
+
from sqlscope import util
|
|
15
|
+
|
|
16
|
+
from .base import BaseDetector, DetectedError
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SyntaxErrorDetector(BaseDetector):
|
|
20
|
+
'''Detector for syntax errors in SQL queries.'''
|
|
21
|
+
|
|
22
|
+
def __init__(self,
|
|
23
|
+
*,
|
|
24
|
+
query: Query,
|
|
25
|
+
update_query: Callable[[str, str | None], None],
|
|
26
|
+
solutions: list[Query] = [],
|
|
27
|
+
):
|
|
28
|
+
super().__init__(
|
|
29
|
+
query=query,
|
|
30
|
+
solutions=solutions,
|
|
31
|
+
update_query=update_query,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
def run(self) -> list[DetectedError]:
|
|
35
|
+
'''Run the detector and return a list of detected errors with their descriptions'''
|
|
36
|
+
results: list[DetectedError] = super().run()
|
|
37
|
+
|
|
38
|
+
# 1) fix stray semicolons (to allow ast building for subsequent checks)
|
|
39
|
+
checks = [self.detect_22_38_additional_omitted_semicolons]
|
|
40
|
+
|
|
41
|
+
for check in checks:
|
|
42
|
+
check_result, fixed_query_str = check()
|
|
43
|
+
results.extend(check_result)
|
|
44
|
+
|
|
45
|
+
if fixed_query_str != self.query.sql:
|
|
46
|
+
self.update_query(fixed_query_str, check.__name__)
|
|
47
|
+
|
|
48
|
+
# 2) detect unexisting objects (before corrections, to avoid false positives)
|
|
49
|
+
unexisting_checks = [
|
|
50
|
+
self.detect_2_4_undefined_columns_ambiguous_columns, # ok
|
|
51
|
+
self.detect_2_ambiguous_function, # TODO: implement
|
|
52
|
+
self.detect_5_undefined_functions, # ok
|
|
53
|
+
self.detect_6_undefined_function_parameters, # ok
|
|
54
|
+
self.detect_7_8_undefined_tables, # ok
|
|
55
|
+
self.detect_25_using_an_undefined_correlation_name, # TODO: implement
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
for check in unexisting_checks:
|
|
59
|
+
check_result = check()
|
|
60
|
+
results.extend(check_result)
|
|
61
|
+
|
|
62
|
+
# 3.1) detect fixable errors and apply corrections for improved subsequent checks
|
|
63
|
+
# NOTE: leave in this order!
|
|
64
|
+
misspelling_checks = [
|
|
65
|
+
self.detect_33_omitting_commas, # TODO: implement/refactor
|
|
66
|
+
self.detect_27_confusing_table_names_with_column_names, # TODO: implement
|
|
67
|
+
self.detect_36_nonstandard_operators, # ok
|
|
68
|
+
self.detect_9_misspellings_schemas_tables, # ok
|
|
69
|
+
self.detect_9_misspellings_columns, # ok
|
|
70
|
+
self.detect_10_synonyms, # TODO: implement
|
|
71
|
+
self.detect_11_omitted_quotes, # TODO: implement/refactor
|
|
72
|
+
]
|
|
73
|
+
|
|
74
|
+
# 3.2) apply corrections and re-parse query
|
|
75
|
+
corrected_sql = self.query.sql
|
|
76
|
+
for check in misspelling_checks:
|
|
77
|
+
for error in check():
|
|
78
|
+
results.append(error)
|
|
79
|
+
pattern = r'\b' + re.escape(error.data[0]) + r'\b'
|
|
80
|
+
corrected_sql = re.sub(
|
|
81
|
+
pattern,
|
|
82
|
+
error.data[1],
|
|
83
|
+
corrected_sql,
|
|
84
|
+
# flags=re.IGNORECASE
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
# Use the corrected query from here on (across all detectors)
|
|
88
|
+
if corrected_sql != self.query.sql:
|
|
89
|
+
self.update_query(corrected_sql, check.__name__)
|
|
90
|
+
|
|
91
|
+
# Proceed with all other checks
|
|
92
|
+
checks = [
|
|
93
|
+
self.detect_12_failure_to_specify_column_name_twice, # TODO: implement
|
|
94
|
+
self.detect_13_data_type_mismatch, # ok
|
|
95
|
+
self.detect_14_aggregate_function_outside_select_or_having, # ok
|
|
96
|
+
self.detect_15_aggregate_functions_cannot_be_nested, # ok
|
|
97
|
+
self.detect_16_extraneous_or_omitted_grouping_column, # ok
|
|
98
|
+
self.detect_17_having_without_group_by, # ok
|
|
99
|
+
self.detect_106_missing_quantifier, #TODO: implement
|
|
100
|
+
self.detect_18_confusing_function_with_function_parameter, # TODO: implement
|
|
101
|
+
self.detect_19_using_where_twice, # ok
|
|
102
|
+
self.detect_20_omitted_from_clause, # ok
|
|
103
|
+
self.detect_21_comparison_with_null, # ok
|
|
104
|
+
self.detect_23_date_time_field_overflow, # TODO: implement, needs AST
|
|
105
|
+
self.detect_24_duplicate_clause, # ok
|
|
106
|
+
self.detect_26_too_many_columns_in_subquery, # ok
|
|
107
|
+
self.detect_30_confused_order_of_keywords, # ok
|
|
108
|
+
self.detect_32_confused_syntax_of_keywords, # TODO: check and refactor
|
|
109
|
+
self.detect_107_108_curly_square_or_unmatched_brackets, # ok
|
|
110
|
+
self.detect_35_is_where_not_applicable, # ok
|
|
111
|
+
self.detect_36_nonstandard_keywords_or_standard_keywords_in_wrong_context, #TODO: implement
|
|
112
|
+
self.detect_109_different_tuples_in_set_operation, #TODO: implement
|
|
113
|
+
]
|
|
114
|
+
|
|
115
|
+
for check in checks:
|
|
116
|
+
results.extend(check())
|
|
117
|
+
return results
|
|
118
|
+
|
|
119
|
+
# region 1) Semicolons
|
|
120
|
+
def detect_22_38_additional_omitted_semicolons(self) -> tuple[list[DetectedError], str]:
|
|
121
|
+
'''
|
|
122
|
+
Flags queries that omit the semicolon at the end or have multiple semicolons.
|
|
123
|
+
|
|
124
|
+
Returns:
|
|
125
|
+
- List of DetectedError instances for semicolon issues.
|
|
126
|
+
- The cleaned query string with extra semicolons removed.
|
|
127
|
+
'''
|
|
128
|
+
|
|
129
|
+
results: list[DetectedError] = []
|
|
130
|
+
|
|
131
|
+
all_tokens = []
|
|
132
|
+
for statement in self.query.all_statements:
|
|
133
|
+
all_tokens.extend(list(statement.flatten()))
|
|
134
|
+
|
|
135
|
+
good_tokens = []
|
|
136
|
+
trailing_semicolon_found = False
|
|
137
|
+
non_whitespace_found = False
|
|
138
|
+
|
|
139
|
+
for token in reversed(all_tokens): # start from end to preserve only the last semicolon
|
|
140
|
+
# check for whitespace/newline
|
|
141
|
+
if token.ttype in (sqlparse.tokens.Whitespace, sqlparse.tokens.Newline):
|
|
142
|
+
# keep as is and continue
|
|
143
|
+
good_tokens.append(token.value)
|
|
144
|
+
continue
|
|
145
|
+
|
|
146
|
+
# check for semicolons: the first one before any non-whitespace is kept, others are flagged
|
|
147
|
+
if token.ttype == sqlparse.tokens.Punctuation and token.value == ';':
|
|
148
|
+
if non_whitespace_found:
|
|
149
|
+
# we encountered a semicolon in the middle of the query!
|
|
150
|
+
# we don't care if this is the first one we encounter, it's surely not supposed to be here
|
|
151
|
+
results.append(DetectedError(SqlErrors.ADDITIONAL_SEMICOLON))
|
|
152
|
+
continue
|
|
153
|
+
|
|
154
|
+
if not trailing_semicolon_found:
|
|
155
|
+
# we encountered the trailing semicolon for the first time
|
|
156
|
+
# it's good, keep it
|
|
157
|
+
good_tokens.append(token.value)
|
|
158
|
+
trailing_semicolon_found = True
|
|
159
|
+
continue
|
|
160
|
+
|
|
161
|
+
# else, we have already found the trailing semicolon, so this is an extra one at the end
|
|
162
|
+
results.append(DetectedError(SqlErrors.ADDITIONAL_SEMICOLON))
|
|
163
|
+
continue
|
|
164
|
+
|
|
165
|
+
# any other token
|
|
166
|
+
non_whitespace_found = True
|
|
167
|
+
good_tokens.append(token.value)
|
|
168
|
+
|
|
169
|
+
if not trailing_semicolon_found:
|
|
170
|
+
results.append(DetectedError(SqlErrors.OMITTED_SEMICOLON))
|
|
171
|
+
|
|
172
|
+
return (results, ''.join(reversed(good_tokens)))
|
|
173
|
+
# endregion
|
|
174
|
+
|
|
175
|
+
# region 2) Pre-fixing
|
|
176
|
+
def detect_2_ambiguous_function(self) -> list[DetectedError]:
|
|
177
|
+
return []
|
|
178
|
+
|
|
179
|
+
def detect_7_8_undefined_tables(self) -> list[DetectedError]:
|
|
180
|
+
'''
|
|
181
|
+
Checks for undefined tables in the FROM clause
|
|
182
|
+
'''
|
|
183
|
+
|
|
184
|
+
results: list[DetectedError] = []
|
|
185
|
+
|
|
186
|
+
for select in self.query.selects:
|
|
187
|
+
select = select.strip_subqueries()
|
|
188
|
+
|
|
189
|
+
if select.ast is None:
|
|
190
|
+
continue
|
|
191
|
+
|
|
192
|
+
for table in select.ast.find_all(exp.Table):
|
|
193
|
+
table_name = util.ast.table.get_real_name(table)
|
|
194
|
+
schema_name = util.ast.table.get_schema(table)
|
|
195
|
+
|
|
196
|
+
if schema_name:
|
|
197
|
+
# Fully qualified table (schema.table)
|
|
198
|
+
if not select.catalog.has_schema(schema_name):
|
|
199
|
+
results.append(DetectedError(SqlErrors.INVALID_SCHEMA_NAME, (table.sql(),)))
|
|
200
|
+
continue
|
|
201
|
+
|
|
202
|
+
if not select.catalog.has_table(schema_name, table_name):
|
|
203
|
+
results.append(DetectedError(SqlErrors.UNDEFINED_OBJECT, (table.sql(),)))
|
|
204
|
+
continue
|
|
205
|
+
else:
|
|
206
|
+
# Unqualified table (table)
|
|
207
|
+
# Check if table is a CTE
|
|
208
|
+
if select.catalog.has_table('', table_name):
|
|
209
|
+
continue
|
|
210
|
+
|
|
211
|
+
# Check if table is in the current schema
|
|
212
|
+
if select.catalog.has_table(select.search_path, table_name):
|
|
213
|
+
continue
|
|
214
|
+
|
|
215
|
+
results.append(DetectedError(SqlErrors.UNDEFINED_OBJECT, (table.sql(),)))
|
|
216
|
+
|
|
217
|
+
return results
|
|
218
|
+
|
|
219
|
+
def detect_2_4_undefined_columns_ambiguous_columns(self) -> list[DetectedError]:
|
|
220
|
+
'''
|
|
221
|
+
Checks for undefined and ambiguous columns.
|
|
222
|
+
'''
|
|
223
|
+
|
|
224
|
+
results: list[DetectedError] = []
|
|
225
|
+
|
|
226
|
+
for select in self.query.selects:
|
|
227
|
+
select = select.strip_subqueries()
|
|
228
|
+
|
|
229
|
+
if select.ast is None:
|
|
230
|
+
continue
|
|
231
|
+
|
|
232
|
+
for column in select.ast.find_all(exp.Column):
|
|
233
|
+
# skip `table.*` syntax, we only want to check actual column references
|
|
234
|
+
if isinstance(column.this, exp.Star):
|
|
235
|
+
continue
|
|
236
|
+
|
|
237
|
+
column_name = util.ast.column.get_name(column)
|
|
238
|
+
table_name = util.ast.column.get_table(column)
|
|
239
|
+
|
|
240
|
+
possible_matches = []
|
|
241
|
+
|
|
242
|
+
if table_name:
|
|
243
|
+
# Qualified column (table.column)
|
|
244
|
+
for table in select.referenced_tables:
|
|
245
|
+
if table.name != table_name:
|
|
246
|
+
continue
|
|
247
|
+
|
|
248
|
+
for possible_match in table.columns:
|
|
249
|
+
if possible_match.name == column_name:
|
|
250
|
+
possible_matches.append(f'{table_name}.{column_name}')
|
|
251
|
+
else:
|
|
252
|
+
# Unqualified column (column)
|
|
253
|
+
for table in select.referenced_tables:
|
|
254
|
+
for possible_match in table.columns:
|
|
255
|
+
if possible_match.name == column_name:
|
|
256
|
+
possible_matches.append(f'{table.name}.{column_name}')
|
|
257
|
+
|
|
258
|
+
if len(possible_matches) == 0:
|
|
259
|
+
results.append(DetectedError(SqlErrors.UNDEFINED_COLUMN, (column.sql(),)))
|
|
260
|
+
elif len(possible_matches) > 1:
|
|
261
|
+
results.append(DetectedError(SqlErrors.AMBIGUOUS_COLUMN, (column.sql(), possible_matches)))
|
|
262
|
+
|
|
263
|
+
return results
|
|
264
|
+
|
|
265
|
+
def detect_5_undefined_functions(self) -> list[DetectedError]:
|
|
266
|
+
'''Checks for undefined functions (i.e. invalid names followed by parentheses).'''
|
|
267
|
+
|
|
268
|
+
results: list[DetectedError] = []
|
|
269
|
+
|
|
270
|
+
# standard_functions = {
|
|
271
|
+
known_aggregate_functions = {
|
|
272
|
+
'SUM', 'AVG', 'COUNT', 'MIN', 'MAX',
|
|
273
|
+
'IN', 'EXISTS', 'ANY', 'ALL',
|
|
274
|
+
'COALESCE', 'NULLIF', 'CAST', 'CONVERT',
|
|
275
|
+
'UPPER', 'LOWER', 'LENGTH', 'SUBSTRING',
|
|
276
|
+
'NOW', 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
|
|
277
|
+
}
|
|
278
|
+
user_defined_functions = set() # TODO: self.catalog.functions
|
|
279
|
+
|
|
280
|
+
all_functions = known_aggregate_functions.union(user_defined_functions)
|
|
281
|
+
|
|
282
|
+
for func, clause in self.query.functions:
|
|
283
|
+
func_name = func.get_name()
|
|
284
|
+
|
|
285
|
+
if func_name is None:
|
|
286
|
+
continue
|
|
287
|
+
|
|
288
|
+
if func_name.upper() not in all_functions:
|
|
289
|
+
results.append(DetectedError(SqlErrors.UNDEFINED_FUNCTION, (func_name, clause)))
|
|
290
|
+
|
|
291
|
+
return results
|
|
292
|
+
|
|
293
|
+
def detect_6_undefined_function_parameters(self) -> list[DetectedError]:
|
|
294
|
+
'''Checks for undefined function parameters'''
|
|
295
|
+
|
|
296
|
+
results: list[DetectedError] = []
|
|
297
|
+
|
|
298
|
+
for token, val in self.query.tokens:
|
|
299
|
+
if any(val.startswith(p) for p in (':', '@', '?')):
|
|
300
|
+
results.append(DetectedError(SqlErrors.UNDEFINED_PARAMETER, (val,)))
|
|
301
|
+
|
|
302
|
+
return results
|
|
303
|
+
|
|
304
|
+
def detect_25_using_an_undefined_correlation_name(self) -> list[DetectedError]:
|
|
305
|
+
return []
|
|
306
|
+
# endregion
|
|
307
|
+
|
|
308
|
+
# region 3) Fixable errors
|
|
309
|
+
def detect_9_misspellings_schemas_tables(self) -> list[DetectedError]:
|
|
310
|
+
'''
|
|
311
|
+
Check for misspellings in table names.
|
|
312
|
+
'''
|
|
313
|
+
|
|
314
|
+
results: set[DetectedError] = set() # use a set to avoid applying the same correction multiple times
|
|
315
|
+
|
|
316
|
+
for select in self.query.selects:
|
|
317
|
+
select = select.strip_subqueries()
|
|
318
|
+
|
|
319
|
+
if select.ast is None:
|
|
320
|
+
continue
|
|
321
|
+
|
|
322
|
+
for table in select.ast.find_all(exp.Table):
|
|
323
|
+
table = deepcopy(table) # avoid modifying the original AST until we are sure we want to apply the correction
|
|
324
|
+
table_str = table.sql()
|
|
325
|
+
table_name = util.ast.table.get_real_name(table)
|
|
326
|
+
schema_name = util.ast.table.get_schema(table)
|
|
327
|
+
|
|
328
|
+
if schema_name:
|
|
329
|
+
# Fully qualified table (schema.table)
|
|
330
|
+
if select.catalog.has_table(schema_name, table_name):
|
|
331
|
+
continue
|
|
332
|
+
|
|
333
|
+
# check "schema.table" for more accurate matches in edge cases (i.e. can't determine if the misspelled part is schema or table)
|
|
334
|
+
available_tables = {f'{s}.{t}' for s in select.catalog.schema_names for t in select.catalog[s].table_names}
|
|
335
|
+
match = difflib.get_close_matches(f'{schema_name}.{table_name}', available_tables, n=1, cutoff=0.6)
|
|
336
|
+
if match:
|
|
337
|
+
s, t = match[0].split('.')
|
|
338
|
+
|
|
339
|
+
table.set('db', exp.TableAlias(this=exp.to_identifier(s, quoted=True)))
|
|
340
|
+
table.set('this', exp.to_identifier(t, quoted=True))
|
|
341
|
+
|
|
342
|
+
results.add(DetectedError(SqlErrors.MISSPELLINGS, (table_str, table.sql())))
|
|
343
|
+
continue
|
|
344
|
+
|
|
345
|
+
else:
|
|
346
|
+
# Unqualified table (table)
|
|
347
|
+
# Check if table is a CTE
|
|
348
|
+
if select.catalog.has_table('', table_name):
|
|
349
|
+
continue
|
|
350
|
+
|
|
351
|
+
# Check if table is in the current schema
|
|
352
|
+
if select.catalog.has_table(select.search_path, table_name):
|
|
353
|
+
continue
|
|
354
|
+
|
|
355
|
+
available_tables = {t for s in select.catalog.schema_names for t in select.catalog[s].table_names}
|
|
356
|
+
match = difflib.get_close_matches(table_name, available_tables, n=1, cutoff=0.6)
|
|
357
|
+
if match:
|
|
358
|
+
db = next(s for s in select.catalog.schema_names if select.catalog.has_table(s, match[0]))
|
|
359
|
+
table.set('this', exp.to_identifier(match[0], quoted=True))
|
|
360
|
+
if db != select.search_path:
|
|
361
|
+
table.set('db', exp.TableAlias(this=exp.to_identifier(db, quoted=True)))
|
|
362
|
+
results.add(DetectedError(SqlErrors.MISSPELLINGS, (table_str, table.sql())))
|
|
363
|
+
|
|
364
|
+
return [*results]
|
|
365
|
+
|
|
366
|
+
def detect_9_misspellings_columns(self) -> list[DetectedError]:
|
|
367
|
+
'''
|
|
368
|
+
Check for misspellings in table and column names.
|
|
369
|
+
Performs two passes: first try to match objects to their own type, then try to match to any type.
|
|
370
|
+
'''
|
|
371
|
+
results: set[DetectedError] = set() # use a set to avoid applying the same correction multiple times
|
|
372
|
+
|
|
373
|
+
for select in self.query.selects:
|
|
374
|
+
select = select.strip_subqueries()
|
|
375
|
+
|
|
376
|
+
if select.ast is None:
|
|
377
|
+
continue
|
|
378
|
+
|
|
379
|
+
for column in select.ast.find_all(exp.Column):
|
|
380
|
+
# skip `table.*` syntax, we only want to check actual column references
|
|
381
|
+
if isinstance(column.this, exp.Star):
|
|
382
|
+
continue
|
|
383
|
+
|
|
384
|
+
column = deepcopy(column) # avoid modifying the original AST until we are sure we want to apply the correction
|
|
385
|
+
column_str = column.sql()
|
|
386
|
+
column_name = util.ast.column.get_name(column)
|
|
387
|
+
table_name = util.ast.column.get_table(column)
|
|
388
|
+
|
|
389
|
+
found = False
|
|
390
|
+
|
|
391
|
+
for table in select.referenced_tables:
|
|
392
|
+
if table_name and table.name != table_name:
|
|
393
|
+
# Qualified column (table.column)
|
|
394
|
+
# check if column exists only in the specified table
|
|
395
|
+
continue
|
|
396
|
+
if table.has_column(column_name):
|
|
397
|
+
found = True
|
|
398
|
+
break
|
|
399
|
+
|
|
400
|
+
if found:
|
|
401
|
+
continue
|
|
402
|
+
|
|
403
|
+
if table_name:
|
|
404
|
+
# Qualified column (table.column)
|
|
405
|
+
available_columns = {f'{t.name}.{c.name}' for t in select.referenced_tables for c in t.columns}
|
|
406
|
+
else:
|
|
407
|
+
# Unqualified column (column)
|
|
408
|
+
available_columns = {c.name for t in select.referenced_tables for c in t.columns}
|
|
409
|
+
|
|
410
|
+
match = difflib.get_close_matches(column_name if not table_name else f'{table_name}.{column_name}', available_columns, n=1, cutoff=0.6)
|
|
411
|
+
if match:
|
|
412
|
+
if table_name:
|
|
413
|
+
matched_table, matched_column = match[0].split('.')
|
|
414
|
+
column.set('table', exp.to_identifier(matched_table, quoted=True))
|
|
415
|
+
column.set('this', exp.to_identifier(matched_column, quoted=True))
|
|
416
|
+
else:
|
|
417
|
+
column.set('this', exp.to_identifier(match[0], quoted=True))
|
|
418
|
+
|
|
419
|
+
results.add(DetectedError(SqlErrors.MISSPELLINGS, (column_str, column.sql())))
|
|
420
|
+
|
|
421
|
+
return [*results]
|
|
422
|
+
|
|
423
|
+
def detect_10_synonyms(self) -> list[DetectedError]:
|
|
424
|
+
return []
|
|
425
|
+
|
|
426
|
+
def detect_11_omitted_quotes(self) -> list[DetectedError]:
|
|
427
|
+
'''
|
|
428
|
+
Checks for potential omitting of quotes around character data in WHERE/HAVING clauses.
|
|
429
|
+
|
|
430
|
+
Returns:
|
|
431
|
+
A list of DetectedErrors. data=(offending_value,corrected_value)
|
|
432
|
+
'''
|
|
433
|
+
return []
|
|
434
|
+
|
|
435
|
+
results: list[DetectedError] = []
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
comparisons = self.query.comparisons
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# for comparison in comparisons:
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
return results
|
|
446
|
+
|
|
447
|
+
# # 3. Build sets of ALL known identifiers for the entire query (main + subqueries + CTEs)
|
|
448
|
+
# valid_source_identifiers = set()
|
|
449
|
+
# all_known_columns_lower = set()
|
|
450
|
+
# db_tables = self.catalog.get('table_columns', {})
|
|
451
|
+
|
|
452
|
+
# # -- Main Query --
|
|
453
|
+
# main_query_sources = self._get_referenced_tables()
|
|
454
|
+
# main_alias_map = self.query_map.alias_mapping
|
|
455
|
+
# valid_source_identifiers.update(s.lower() for s in main_query_sources)
|
|
456
|
+
# valid_source_identifiers.update(a.lower() for a in main_alias_map.keys())
|
|
457
|
+
# for source in main_query_sources:
|
|
458
|
+
# actual_base_name = next((k for k in db_tables if k.lower() == source.lower()), None)
|
|
459
|
+
# if actual_base_name:
|
|
460
|
+
# all_known_columns_lower.update(c.lower() for c in db_tables[actual_base_name])
|
|
461
|
+
|
|
462
|
+
# # -- Subqueries --
|
|
463
|
+
# for subq_map in self.subquery_map.values():
|
|
464
|
+
# sub_sources = []
|
|
465
|
+
# sub_from = subq_map.from_value
|
|
466
|
+
# if sub_from:
|
|
467
|
+
# sub_sources.append(sub_from)
|
|
468
|
+
# sub_joins = subq_map.join_value
|
|
469
|
+
# sub_sources.extend(sub_joins)
|
|
470
|
+
# sub_aliases = subq_map.alias_mapping
|
|
471
|
+
# valid_source_identifiers.update(s.lower() for s in sub_sources)
|
|
472
|
+
# valid_source_identifiers.update(a.lower() for a in sub_aliases.keys())
|
|
473
|
+
# for source in sub_sources:
|
|
474
|
+
# actual_base_name = next((k for k in db_tables if k.lower() == source.lower()), None)
|
|
475
|
+
# if actual_base_name:
|
|
476
|
+
# all_known_columns_lower.update(c.lower() for c in db_tables[actual_base_name])
|
|
477
|
+
|
|
478
|
+
# # -- CTEs --
|
|
479
|
+
# if self.cte_map:
|
|
480
|
+
# valid_source_identifiers.update(name.lower() for name in self.cte_map.keys())
|
|
481
|
+
# for cte_name, cte_columns in self.cte_catalog.cte_tables.items():
|
|
482
|
+
# all_known_columns_lower.update(c.lower() for c in cte_columns)
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
# 4. Main Token-based Check
|
|
486
|
+
is_where_or_having = False
|
|
487
|
+
is_rhs_of_comparison = False # nothing prevents an expression to have its sides inverted, although it's unlikely to happen
|
|
488
|
+
comparison_operators = {'=', '<>', '!=', '<', '>', '<=', '>=', 'LIKE', 'NOT LIKE'}
|
|
489
|
+
known_keywords = {'SELECT', 'FROM', 'WHERE', 'JOIN', 'ON', 'GROUP', 'BY', 'ORDER', 'HAVING', 'LIMIT', 'AS', 'DISTINCT'}
|
|
490
|
+
|
|
491
|
+
for i, (tt, val) in enumerate(self.query.tokens):
|
|
492
|
+
if tt == sqlparse.tokens.Keyword and val.upper() in {'WHERE', 'HAVING'}:
|
|
493
|
+
is_where_or_having = True
|
|
494
|
+
if tt == sqlparse.tokens.Error:
|
|
495
|
+
continue
|
|
496
|
+
if val in comparison_operators:
|
|
497
|
+
is_rhs_of_comparison = True
|
|
498
|
+
continue
|
|
499
|
+
if tt in sqlparse.tokens.Literal or tt in (sqlparse.tokens.String.Single, sqlparse.tokens.String.Symbol):
|
|
500
|
+
if is_where_or_having and is_rhs_of_comparison:
|
|
501
|
+
stripped_val = val.strip()
|
|
502
|
+
if stripped_val.startswith('"') and stripped_val.endswith('"'):
|
|
503
|
+
results.append(DetectedError(SqlErrors.SYN_11_OMITTING_QUOTES_AROUND_CHARACTER_DATA, (val,)))
|
|
504
|
+
is_rhs_of_comparison = False
|
|
505
|
+
continue
|
|
506
|
+
if tt is not sqlparse.tokens.Name:
|
|
507
|
+
is_rhs_of_comparison = False
|
|
508
|
+
continue
|
|
509
|
+
if val.upper() in known_keywords:
|
|
510
|
+
is_rhs_of_comparison = False
|
|
511
|
+
continue
|
|
512
|
+
if val.lower() in valid_source_identifiers:
|
|
513
|
+
is_rhs_of_comparison = False
|
|
514
|
+
continue
|
|
515
|
+
if val.lower() in output_aliases_lower:
|
|
516
|
+
continue
|
|
517
|
+
|
|
518
|
+
clean_val = val
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
# if string OP notcol -> error
|
|
522
|
+
# if date OP notcol2 -> error
|
|
523
|
+
# if extract(notstring FROM ...) -> error
|
|
524
|
+
# like notstring -> error
|
|
525
|
+
|
|
526
|
+
# is this the correct approach? col OP notColumn
|
|
527
|
+
# TODO: literal or string.single/string.symbol in RHS of WHERE/HAVING
|
|
528
|
+
if is_where_or_having and is_rhs_of_comparison:
|
|
529
|
+
if clean_val.isalpha() and clean_val.lower() not in all_known_columns_lower:
|
|
530
|
+
results.append(DetectedError(SqlErrors.SYN_11_OMITTING_QUOTES_AROUND_CHARACTER_DATA, (val,)))
|
|
531
|
+
is_rhs_of_comparison = False
|
|
532
|
+
continue
|
|
533
|
+
|
|
534
|
+
return results
|
|
535
|
+
|
|
536
|
+
def detect_27_confusing_table_names_with_column_names(self) -> list[DetectedError]:
|
|
537
|
+
return []
|
|
538
|
+
|
|
539
|
+
def detect_33_omitting_commas(self) -> list[DetectedError]:
|
|
540
|
+
'''
|
|
541
|
+
Flags queries where commas are likely missing between column expressions
|
|
542
|
+
(e.g., SELECT name age FROM ..., GROUP BY x y).
|
|
543
|
+
'''
|
|
544
|
+
return []
|
|
545
|
+
|
|
546
|
+
results = []
|
|
547
|
+
|
|
548
|
+
clause_starters = {
|
|
549
|
+
"SELECT", "FROM", "WHERE", "GROUP BY", "HAVING", "ORDER BY", "LIMIT", "JOIN", "ON"
|
|
550
|
+
}
|
|
551
|
+
comma_required_clauses = {"SELECT", "GROUP BY", "ORDER BY", "VALUES"}
|
|
552
|
+
current_clause = None
|
|
553
|
+
in_clause_block = False
|
|
554
|
+
|
|
555
|
+
tokens = self.tokens
|
|
556
|
+
i = 0
|
|
557
|
+
while i < len(tokens):
|
|
558
|
+
tt, val = tokens[i]
|
|
559
|
+
val_upper = val.upper().strip()
|
|
560
|
+
|
|
561
|
+
# Detect clause start
|
|
562
|
+
if val_upper in {"SELECT", "GROUP BY", "ORDER BY", "VALUES"}:
|
|
563
|
+
current_clause = val_upper
|
|
564
|
+
in_clause_block = True
|
|
565
|
+
elif val_upper in clause_starters:
|
|
566
|
+
current_clause = None
|
|
567
|
+
in_clause_block = False
|
|
568
|
+
|
|
569
|
+
# Check for missing commas inside comma-required clauses
|
|
570
|
+
if in_clause_block and current_clause in comma_required_clauses:
|
|
571
|
+
is_valid_column = (
|
|
572
|
+
tt in sqlparse.tokens.Name or
|
|
573
|
+
(tt is None and val.replace('.', '').isalnum())
|
|
574
|
+
)
|
|
575
|
+
if is_valid_column and val_upper not in clause_starters:
|
|
576
|
+
# Look ahead to the next non-whitespace token
|
|
577
|
+
j = i + 1
|
|
578
|
+
while j < len(tokens) and tokens[j][0] in sqlparse.tokens.Whitespace:
|
|
579
|
+
j += 1
|
|
580
|
+
if j < len(tokens):
|
|
581
|
+
next_tt, next_val = tokens[j]
|
|
582
|
+
next_val_upper = next_val.upper().strip()
|
|
583
|
+
is_next_valid_column = (
|
|
584
|
+
next_tt in sqlparse.tokens.Name or
|
|
585
|
+
(next_tt is None and next_val.replace('.', '').isalnum())
|
|
586
|
+
)
|
|
587
|
+
if (
|
|
588
|
+
is_next_valid_column and
|
|
589
|
+
next_val_upper not in clause_starters and
|
|
590
|
+
next_val != ','
|
|
591
|
+
):
|
|
592
|
+
results.append((
|
|
593
|
+
SqlErrors.SYN_33_OMITTING_COMMAS,
|
|
594
|
+
f"Possible missing comma between '{val}' and '{next_val}' in {current_clause} clause"
|
|
595
|
+
))
|
|
596
|
+
i += 1
|
|
597
|
+
|
|
598
|
+
return results
|
|
599
|
+
|
|
600
|
+
def detect_36_nonstandard_operators(self) -> list[DetectedError]:
|
|
601
|
+
'''
|
|
602
|
+
Flags usage of non-standard or language-specific operators like &&, ||, ==, etc.
|
|
603
|
+
'''
|
|
604
|
+
|
|
605
|
+
results: list[DetectedError] = []
|
|
606
|
+
|
|
607
|
+
# dict {error: correction}
|
|
608
|
+
nonstandard_ops = {
|
|
609
|
+
'==' : '=',
|
|
610
|
+
'===' : '=',
|
|
611
|
+
'!==' : '<>',
|
|
612
|
+
'&&' : ' AND ',
|
|
613
|
+
'||' : ' OR ',
|
|
614
|
+
'!' : ' NOT ',
|
|
615
|
+
# '^' : '',
|
|
616
|
+
# '~' : '',
|
|
617
|
+
'>>' : '>',
|
|
618
|
+
'<<' : '<',
|
|
619
|
+
'≠' : '<>',
|
|
620
|
+
'≥' : '>=',
|
|
621
|
+
'≤' : '<=',
|
|
622
|
+
}
|
|
623
|
+
|
|
624
|
+
for ttype, val in self.query.tokens:
|
|
625
|
+
val_stripped = val.strip()
|
|
626
|
+
if ttype in sqlparse.tokens.Operator or ttype in sqlparse.tokens.Operator.Comparison or ttype == sqlparse.tokens.Error:
|
|
627
|
+
if val_stripped in nonstandard_ops:
|
|
628
|
+
correction = nonstandard_ops[val_stripped]
|
|
629
|
+
results.append(DetectedError(SqlErrors.NONSTANDARD_OPERATORS, (val_stripped, correction)))
|
|
630
|
+
|
|
631
|
+
return results
|
|
632
|
+
# endregion
|
|
633
|
+
|
|
634
|
+
# region 4) Other checks
|
|
635
|
+
def detect_12_failure_to_specify_column_name_twice(self) -> list[DetectedError]:
|
|
636
|
+
return []
|
|
637
|
+
|
|
638
|
+
def detect_13_data_type_mismatch(self) -> list[DetectedError]:
|
|
639
|
+
'''
|
|
640
|
+
Checks for data type mismatches in comparisons within the query.
|
|
641
|
+
'''
|
|
642
|
+
|
|
643
|
+
def parse_set_operation(set_op: 'SetOperation', location: str) -> list[DetectedError]:
|
|
644
|
+
|
|
645
|
+
'''
|
|
646
|
+
Util function to parse a SetOperation and check for data type mismatches among its main selects.
|
|
647
|
+
'''
|
|
648
|
+
errors: list[DetectedError] = []
|
|
649
|
+
expected_output = None # type of the first select's output
|
|
650
|
+
for select in set_op.main_selects:
|
|
651
|
+
|
|
652
|
+
typed_ast = select.typed_ast
|
|
653
|
+
|
|
654
|
+
if typed_ast is None:
|
|
655
|
+
continue
|
|
656
|
+
|
|
657
|
+
columns_type = get_type(typed_ast, select.catalog, select.search_path)
|
|
658
|
+
|
|
659
|
+
# 1st select: set expected output type
|
|
660
|
+
if expected_output is None:
|
|
661
|
+
expected_output = columns_type
|
|
662
|
+
else:
|
|
663
|
+
# compare with expected output type
|
|
664
|
+
if expected_output != columns_type:
|
|
665
|
+
errors.append(DetectedError(SqlErrors.DATA_TYPE_MISMATCH, (location,"setop types inconsistent")))
|
|
666
|
+
|
|
667
|
+
# load found messages
|
|
668
|
+
for message in columns_type.messages:
|
|
669
|
+
errors.append(DetectedError(SqlErrors.DATA_TYPE_MISMATCH, message))
|
|
670
|
+
|
|
671
|
+
return errors
|
|
672
|
+
|
|
673
|
+
results: list[DetectedError] = []
|
|
674
|
+
|
|
675
|
+
# CTEs
|
|
676
|
+
for cte in self.query.ctes:
|
|
677
|
+
results.extend(parse_set_operation(cte, f"CTE {cte.output.name}"))
|
|
678
|
+
|
|
679
|
+
# Main Query
|
|
680
|
+
results.extend(parse_set_operation(self.query.main_query, "Main Query"))
|
|
681
|
+
|
|
682
|
+
return results
|
|
683
|
+
|
|
684
|
+
def detect_14_aggregate_function_outside_select_or_having(self) -> list[DetectedError]:
|
|
685
|
+
'''
|
|
686
|
+
Flags use of aggregate functions (SUM, AVG, COUNT, MIN, MAX) outside SELECT or HAVING clauses,
|
|
687
|
+
respecting subquery scopes.
|
|
688
|
+
'''
|
|
689
|
+
|
|
690
|
+
results: list[DetectedError] = []
|
|
691
|
+
|
|
692
|
+
functions = self.query.functions
|
|
693
|
+
for function, clause in functions:
|
|
694
|
+
function_name = function.get_name()
|
|
695
|
+
if function_name and function_name.upper() in {'SUM', 'AVG', 'COUNT', 'MIN', 'MAX'}:
|
|
696
|
+
if clause not in {'SELECT', 'HAVING'}:
|
|
697
|
+
results.append(DetectedError(SqlErrors.AGGREGATE_FUNCTION_OUTSIDE_SELECT_OR_HAVING, (function_name, clause)))
|
|
698
|
+
|
|
699
|
+
return results
|
|
700
|
+
|
|
701
|
+
def detect_15_aggregate_functions_cannot_be_nested(self) -> list[DetectedError]:
|
|
702
|
+
'''
|
|
703
|
+
Flags cases where aggregate functions are nested within the *same query scope*,
|
|
704
|
+
which mainstream SQL dialects do not allow (e.g., SUM(MAX(x))).
|
|
705
|
+
'''
|
|
706
|
+
results: list[DetectedError] = []
|
|
707
|
+
|
|
708
|
+
for select in self.query.selects:
|
|
709
|
+
stripped = select.strip_subqueries()
|
|
710
|
+
|
|
711
|
+
if stripped.ast is None:
|
|
712
|
+
continue
|
|
713
|
+
|
|
714
|
+
aggregate_functions = stripped.ast.find_all(exp.AggFunc)
|
|
715
|
+
|
|
716
|
+
for outer_agg in aggregate_functions:
|
|
717
|
+
inner = outer_agg.this
|
|
718
|
+
for inner_agg in inner.find_all(exp.AggFunc):
|
|
719
|
+
results.append(DetectedError(
|
|
720
|
+
SqlErrors.AGGREGATE_FUNCTIONS_CANNOT_BE_NESTED,
|
|
721
|
+
(outer_agg.sql(),)
|
|
722
|
+
))
|
|
723
|
+
|
|
724
|
+
return results
|
|
725
|
+
|
|
726
|
+
def detect_16_extraneous_or_omitted_grouping_column(self) -> list[DetectedError]:
    '''
    Flags mismatches between the SELECT list and the GROUP BY clause of every query
    block that has a GROUP BY:

    - a non-aggregated column in SELECT that is not grouped ('ONLY IN SELECT');
    - a grouping column that does not appear in SELECT ('ONLY IN GROUP BY');
    - an aggregate function used as a grouping expression ('AGGREGATED IN GROUP BY').

    Columns that only occur inside an aggregate function (e.g. `sal` in `SUM(sal) * 2`)
    or inside a subquery are not grouping candidates of this block. A positional item
    such as `GROUP BY 2` refers to the 2nd output column of SELECT (after `*` expansion).

    NOTE: the related rule "all non-aggregated columns in HAVING must be included in
    the GROUP BY clause" is not checked yet.
    '''

    @dataclass(frozen=True)
    class ColumnInfo:
        name: str
        alias: str
        is_aggregated: bool = False

    def get_column_name(col: exp.Column | exp.Alias) -> ColumnInfo:
        '''Return normalized column name and alias. If no alias, both are the same.'''
        col_name = util.ast.column.get_real_name(col)
        col_alias = util.ast.column.get_name(col)
        return ColumnInfo(col_name, col_alias)

    def split_expression(node: exp.Expression) -> tuple[list[exp.Column], bool]:
        '''
        Return the columns of `node` that are not nested inside an aggregate function or
        a subquery (in source order), and whether `node` contains an aggregate function
        outside subqueries.
        '''
        columns: list[exp.Column] = []
        has_aggregate = False
        stack = [node]
        while stack:
            current = stack.pop()
            if isinstance(current, exp.AggFunc):
                has_aggregate = True  # its arguments are aggregated: do not descend
            elif isinstance(current, exp.Column):
                columns.append(current)
            elif not isinstance(current, (exp.Subquery, exp.Select)):
                # Subqueries are a different scope and are checked as selects of their own
                stack.extend(reversed(list(current.iter_expressions())))
        return columns, has_aggregate

    results: list[DetectedError] = []

    for select in self.query.selects:
        if select.ast is None:
            continue

        if not select.group_by:
            continue  # no GROUP BY, skip

        # One list per SELECT output column, so that positional GROUP BY maps to the
        # right expression even when an expression references several columns.
        select_items: list[list[ColumnInfo]] = []

        for col in select.ast.expressions:
            if isinstance(col, exp.Star):
                # SELECT * case: every column of every referenced table is an output column
                select_items.extend(
                    [ColumnInfo(table_col.name, table_col.name)]
                    for table in select.referenced_tables
                    for table_col in table.columns
                )
                continue

            columns, has_aggregate = split_expression(col)
            if has_aggregate:
                # (Partially) aggregated: only the columns outside aggregates need grouping
                item = [get_column_name(c) for c in columns]
                item.append(ColumnInfo(col.sql(), col.sql(), is_aggregated=True))
            elif isinstance(col, (exp.Column, exp.Alias)):
                item = [get_column_name(col)]
            else:
                # Non-aggregated expression (e.g. UPPER(name), a + b): every column needs grouping
                item = [get_column_name(c) for c in columns]
            select_items.append(item)

        select_columns = [info for item in select_items for info in item]

        # Gather columns from GROUP BY (a dict is used as an insertion-ordered set)
        group_by_columns: dict[ColumnInfo, None] = {}
        for gb in select.group_by:
            if isinstance(gb, exp.Column):
                group_by_columns[get_column_name(gb)] = None
            elif isinstance(gb, exp.Literal):
                try:
                    position = int(gb.this)
                except ValueError:
                    continue
                # Positional GROUP BY: map to the selected expression (1-based)
                if 1 <= position <= len(select_items):
                    group_by_columns.update(dict.fromkeys(select_items[position - 1]))
            else:
                # Complex expression (or aggregate) in GROUP BY
                columns, has_aggregate = split_expression(gb)
                group_by_columns.update(dict.fromkeys(get_column_name(c) for c in columns))
                if has_aggregate:
                    group_by_columns[ColumnInfo(gb.sql(), gb.sql(), is_aggregated=True)] = None

        # Ensure all non-aggregated columns in SELECT are in GROUP BY
        # (dict.fromkeys removes duplicates while keeping a deterministic report order)
        for select_col in dict.fromkeys(select_columns):
            if select_col.is_aggregated:
                continue  # aggregated, skip
            if any(select_col.name == group_col.name or select_col.alias == group_col.alias for group_col in group_by_columns):
                continue  # valid: in GROUP BY
            results.append(DetectedError(SqlErrors.EXTRANEOUS_OR_OMITTED_GROUPING_COLUMN, (select_col.name, 'ONLY IN SELECT')))

        # Ensure all non-aggregated columns in GROUP BY are in SELECT
        # (aggregate functions are never valid in GROUP BY)
        for group_col in group_by_columns:
            if group_col.is_aggregated:
                results.append(DetectedError(SqlErrors.EXTRANEOUS_OR_OMITTED_GROUPING_COLUMN, (group_col.name, 'AGGREGATED IN GROUP BY')))
                continue
            if any(group_col.name == select_col.name or group_col.alias == select_col.alias for select_col in select_columns):
                continue  # valid: in SELECT
            results.append(DetectedError(SqlErrors.EXTRANEOUS_OR_OMITTED_GROUPING_COLUMN, (group_col.name, 'ONLY IN GROUP BY')))

    return results
|
|
818
|
+
|
|
819
|
+
def detect_17_having_without_group_by(self) -> list[DetectedError]:
    '''
    Report every query block that has a HAVING clause but no GROUP BY clause.

    Returns:
        One HAVING_WITHOUT_GROUP_BY error (no payload) per offending query block.
    '''
    return [
        DetectedError(SqlErrors.HAVING_WITHOUT_GROUP_BY)
        for query_block in self.query.selects
        if query_block.having and not query_block.group_by
    ]
|
|
830
|
+
|
|
831
|
+
def detect_18_confusing_function_with_function_parameter(self) -> list[DetectedError]:
    '''Placeholder detector: not implemented yet, so it never reports an error.'''
    return list()
|
|
833
|
+
|
|
834
|
+
def detect_19_using_where_twice(self) -> list[DetectedError]:
    '''
    Flags multiple WHERE clauses in a single query block (main query, CTEs, subqueries).

    Subqueries are stripped first. Only WHERE keywords at the block's own parenthesis
    level are counted, so a WHERE nested in parentheses (e.g. an aggregate
    `FILTER (WHERE ...)` or a CTE body) is not taken for a second WHERE clause.

    Returns:
        One USING_WHERE_TWICE error per offending block, with payload
        (block SQL, number of WHERE keywords).
    '''

    def top_level_tokens(tokens) -> list[tuple]:
        '''
        Return the (ttype, val) tokens lying at the block's own parenthesis level: the
        shallowest depth at which a DML keyword (SELECT) occurs, or 0 if there is none.
        '''
        depth = 0
        leveled = []
        for ttype, val in tokens:
            if val == ')':
                depth = max(depth - 1, 0)  # tolerate stray ')' in malformed queries
            leveled.append((depth, ttype, val))
            if val == '(':
                depth += 1
        base_depth = min((d for d, ttype, _ in leveled if ttype == sqlparse.tokens.DML), default=0)
        return [(ttype, val) for d, ttype, val in leveled if d == base_depth]

    results: list[DetectedError] = []

    for select in self.query.selects:
        # By removing subqueries, we can check only the top-level WHERE clauses in this select.
        stripped = select.strip_subqueries()

        where_count = sum(
            1
            for ttype, val in top_level_tokens(stripped.tokens)
            if ttype == sqlparse.tokens.Keyword and val.upper() == 'WHERE'
        )

        if where_count > 1:
            results.append(DetectedError(SqlErrors.USING_WHERE_TWICE, (select.sql, where_count)))

    return results
|
|
855
|
+
|
|
856
|
+
def detect_20_omitted_from_clause(self) -> list[DetectedError]:
    '''
    Flags query blocks that have no FROM clause although they select non-constant values.

    Subqueries are stripped before looking for the FROM keyword. A block without FROM
    is accepted when every one of its output columns is constant (e.g. `SELECT 1`).

    Returns:
        One OMITTED_FROM_CLAUSE error per offending block, with the block SQL as payload.
    '''
    results: list[DetectedError] = []

    for query_block in self.query.selects:
        scoped = query_block.strip_subqueries()

        has_from_keyword = any(
            token_type == sqlparse.tokens.Keyword and text.upper() == 'FROM'
            for token_type, text in scoped.tokens
        )
        if has_from_keyword:
            continue  # valid, has FROM clause

        # Without FROM, only constant/literal output columns are acceptable
        if any(not column.is_constant for column in scoped.output.columns):
            results.append(DetectedError(SqlErrors.OMITTED_FROM_CLAUSE, (query_block.sql,)))

    return results
|
|
884
|
+
|
|
885
|
+
def detect_21_comparison_with_null(self) -> list[DetectedError]:
    '''
    Flags comparisons against NULL written with a comparison operator
    (=, <>, <, >, <=, >=) instead of IS NULL / IS NOT NULL.

    Subqueries are replaced by the constant `1` first, so comparisons inside them are
    not reported for the enclosing block.

    Returns:
        One COMPARISON_WITH_NULL error per offending comparison, with its SQL as payload.
    '''
    results: list[DetectedError] = []

    for query_block in self.query.selects:
        scoped = query_block.strip_subqueries(replacement='1')  # avoid false positives from subqueries

        if scoped.ast is None:
            continue

        comparisons = scoped.ast.find_all(exp.EQ, exp.NEQ, exp.LT, exp.GT, exp.LTE, exp.GTE)
        results.extend(
            DetectedError(SqlErrors.COMPARISON_WITH_NULL, (cmp.sql(),))
            for cmp in comparisons
            if isinstance(cmp.left, exp.Null) or isinstance(cmp.right, exp.Null)
        )

    return results
|
|
904
|
+
|
|
905
|
+
def detect_23_date_time_field_overflow(self) -> list[DetectedError]:
    '''Placeholder detector: not implemented yet, so it never reports an error.'''
    return list()
|
|
907
|
+
|
|
908
|
+
def detect_24_duplicate_clause(self) -> list[DetectedError]:
    '''
    Flags query blocks in which a clause keyword (SELECT, FROM, WHERE, GROUP BY, HAVING,
    ORDER BY, LIMIT, OFFSET) occurs more than once, e.g. two WHERE clauses.

    Subqueries are stripped first. Only keywords at the block's own parenthesis level
    are counted, so `ORDER BY` in `RANK() OVER (ORDER BY x)` or `FROM` in
    `EXTRACT(YEAR FROM d)` is not taken for a duplicate clause. Multi-word keywords are
    matched regardless of the whitespace between their words.

    Returns:
        One DUPLICATE_CLAUSE error per repeated clause, with payload (clause, count).
    '''
    clause_keywords = {'SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT', 'OFFSET'}

    def top_level_tokens(tokens) -> list[tuple]:
        '''
        Return the (ttype, val) tokens lying at the block's own parenthesis level: the
        shallowest depth at which a DML keyword (SELECT) occurs, or 0 if there is none.
        '''
        depth = 0
        leveled = []
        for ttype, val in tokens:
            if val == ')':
                depth = max(depth - 1, 0)  # tolerate stray ')' in malformed queries
            leveled.append((depth, ttype, val))
            if val == '(':
                depth += 1
        base_depth = min((d for d, ttype, _ in leveled if ttype == sqlparse.tokens.DML), default=0)
        return [(ttype, val) for d, ttype, val in leveled if d == base_depth]

    results: list[DetectedError] = []

    for select in self.query.selects:
        stripped = select.strip_subqueries()

        clause_count: dict[str, int] = {}
        for ttype, val in top_level_tokens(stripped.tokens):
            # Normalize internal whitespace so that e.g. 'GROUP\n  BY' matches 'GROUP BY'
            keyword = ' '.join(val.upper().split())
            is_select = ttype == sqlparse.tokens.DML and keyword == 'SELECT'
            is_clause = ttype == sqlparse.tokens.Keyword and keyword in clause_keywords
            if is_select or is_clause:
                clause_count[keyword] = clause_count.get(keyword, 0) + 1

        for clause, count in clause_count.items():
            if count > 1:
                results.append(DetectedError(SqlErrors.DUPLICATE_CLAUSE, (clause, count)))

    return results
|
|
932
|
+
|
|
933
|
+
def detect_26_too_many_columns_in_subquery(self) -> list[DetectedError]:
    '''
    Flags subqueries that return more columns than their context allows, e.g.
    `WHERE x IN (SELECT a, b ...)` or a multi-column scalar subquery.

    Subqueries in FROM and EXISTS may return any number of columns and are skipped; in
    every other context exactly one column is expected. A subquery with no resolved
    output columns is not reported, since that is not a "too many columns" problem.

    Returns:
        One TOO_MANY_COLUMNS_IN_SUBQUERY error per offending subquery, with payload
        (subquery SQL, number of surplus columns).
    '''
    results: list[DetectedError] = []
    expected_columns = 1  # default expected columns for every context but FROM/EXISTS

    for select in self.query.selects:
        for subquery, clause, _depth in select.subqueries:
            if clause in ('FROM', 'EXISTS'):
                continue  # FROM/EXISTS subqueries can have any number of columns

            surplus_columns = len(subquery.output.columns) - expected_columns
            if surplus_columns > 0:
                results.append(DetectedError(SqlErrors.TOO_MANY_COLUMNS_IN_SUBQUERY, (subquery.sql, surplus_columns)))

    return results
|
|
953
|
+
|
|
954
|
+
def detect_30_confused_order_of_keywords(self) -> list[DetectedError]:
    '''
    Flags queries where the standard order of SQL clauses is not respected.
    Expected order:
        SELECT → FROM → WHERE → GROUP BY → HAVING → ORDER BY → LIMIT → OFFSET

    Subqueries are stripped first. Only clause keywords at the block's own parenthesis
    level are considered, so `ORDER BY` in `ROW_NUMBER() OVER (ORDER BY x)` or `FROM` in
    `EXTRACT(YEAR FROM d)` is not mistaken for a clause. Multi-word keywords are matched
    regardless of the whitespace between their words.

    Returns:
        At most one CONFUSED_ORDER_OF_KEYWORDS error per block, whose payload is the
        observed list of clauses.
    '''
    expected_order = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT', 'OFFSET']
    clause_keywords = frozenset(expected_order[1:])  # SELECT itself is a DML token

    def top_level_tokens(tokens) -> list[tuple]:
        '''
        Return the (ttype, val) tokens lying at the block's own parenthesis level: the
        shallowest depth at which a DML keyword (SELECT) occurs, or 0 if there is none.
        '''
        depth = 0
        leveled = []
        for ttype, val in tokens:
            if val == ')':
                depth = max(depth - 1, 0)  # tolerate stray ')' in malformed queries
            leveled.append((depth, ttype, val))
            if val == '(':
                depth += 1
        base_depth = min((d for d, ttype, _ in leveled if ttype == sqlparse.tokens.DML), default=0)
        return [(ttype, val) for d, ttype, val in leveled if d == base_depth]

    results: list[DetectedError] = []

    for select in self.query.selects:
        stripped = select.strip_subqueries()

        actual_order: list[str] = []
        for ttype, val in top_level_tokens(stripped.tokens):
            if ttype == sqlparse.tokens.DML:
                actual_order.append('SELECT')
            elif ttype == sqlparse.tokens.Keyword:
                # Normalize internal whitespace so that e.g. 'ORDER\n  BY' matches 'ORDER BY'
                keyword = ' '.join(val.upper().split())
                if keyword in clause_keywords:
                    actual_order.append(keyword)

        # Check the order of clauses: report the first clause that goes backwards
        last_index = -1
        for clause in actual_order:
            current_index = expected_order.index(clause)
            if current_index < last_index:
                results.append(DetectedError(
                    SqlErrors.CONFUSED_ORDER_OF_KEYWORDS,
                    (actual_order,)
                ))
                break
            last_index = current_index

    return results
|
|
1002
|
+
|
|
1003
|
+
# NOTE: disabled on purpose, see the docstring; the previous unreachable draft was removed.
def detect_32_confused_syntax_of_keywords(self) -> list[DetectedError]:
    '''
    Intended to flag SQL keywords such as LIKE, IN, BETWEEN used with function-like
    syntax (e.g. LIKE(...)).

    Not implemented: a token-level "keyword followed by '('" heuristic is unreliable,
    because parenthesised operands are legal SQL (e.g. `x LIKE ('a%')`,
    `x BETWEEN (1) AND (2)`). Until a proper check exists, this detector reports nothing.
    '''
    return []
|
|
1045
|
+
|
|
1046
|
+
def detect_107_108_curly_square_or_unmatched_brackets(self) -> list[DetectedError]:
    '''
    Flags unmatched parentheses or usage of non-standard square or curly brackets in the SQL query.

    Round brackets are reported when their counts differ, or when a ')' appears before
    any matching '(' (e.g. `SELECT a) FROM (t`), which equal counts alone would miss.
    Square and curly brackets are reported whenever present.

    Returns:
        UNMATCHED_BRACKETS with payload ('round', opened, closed), and/or
        CURLY_OR_SQUARE_BRACKETS with payload ('square' | 'curly', opened, closed).
    '''
    results: list[DetectedError] = []

    round_open = 0
    round_close = 0
    round_depth = 0              # currently open '(' not yet closed
    stray_round_close = False    # a ')' appeared with no open '(' before it
    square_open = 0
    square_close = 0
    curly_open = 0
    curly_close = 0

    for ttype, val in self.query.tokens:
        if ttype is sqlparse.tokens.Punctuation:
            if val == '(':
                round_open += 1
                round_depth += 1
            elif val == ')':
                round_close += 1
                if round_depth == 0:
                    stray_round_close = True
                else:
                    round_depth -= 1
            elif val == '[':
                square_open += 1
            elif val == ']':
                square_close += 1
        elif ttype is sqlparse.tokens.Error:
            if val == '{':
                curly_open += 1
            elif val == '}':
                curly_close += 1
        elif ttype is sqlparse.tokens.Name:
            # Bracket-quoted identifiers may be lexed as a single Name token
            if val.startswith('{') or val.endswith('}'):
                curly_open += val.count('{')
                curly_close += val.count('}')
            if val.startswith('[') or val.endswith(']'):
                square_open += val.count('[')
                square_close += val.count(']')

    # Check for imbalance (by count or by order)
    if round_open != round_close or stray_round_close:
        results.append(DetectedError(SqlErrors.UNMATCHED_BRACKETS, ('round', round_open, round_close)))
    if square_open > 0 or square_close > 0:
        results.append(DetectedError(SqlErrors.CURLY_OR_SQUARE_BRACKETS, ('square', square_open, square_close)))
    if curly_open > 0 or curly_close > 0:
        results.append(DetectedError(SqlErrors.CURLY_OR_SQUARE_BRACKETS, ('curly', curly_open, curly_close)))

    return results
|
|
1092
|
+
|
|
1093
|
+
def detect_35_is_where_not_applicable(self) -> list[DetectedError]:
    '''
    Report uses of IS whose right-hand operand is not applicable to IS.

    Every main select of every CTE and of the main query is inspected through its
    typed AST. Type errors found by `collect_errors` on an IS node are reported when
    the expected type is 'boolean|null', i.e. the operand after IS is invalid.

    Returns:
        One IS_WHERE_NOT_APPLICABLE error per such type error, whose payload is the
        error tuple returned by `collect_errors`.
    '''

    def errors_in(set_operation: 'SetOperation') -> list[DetectedError]:
        '''Collect invalid IS usages in all main selects of `set_operation`.'''
        found: list[DetectedError] = []
        for main_select in set_operation.main_selects:
            tree = main_select.typed_ast
            if tree is None:
                continue

            for is_node in tree.find_all(exp.Is):
                found.extend(
                    DetectedError(SqlErrors.IS_WHERE_NOT_APPLICABLE, type_error)
                    for type_error in collect_errors(is_node, main_select.catalog, main_select.search_path)
                    # expected type boolean|null => the part after IS is not valid
                    if type_error[2] == 'boolean|null'
                )
        return found

    results: list[DetectedError] = []

    # CTEs first, then the main query
    for cte in self.query.ctes:
        results.extend(errors_in(cte))
    results.extend(errors_in(self.query.main_query))

    return results
|
|
1130
|
+
|
|
1131
|
+
def detect_36_nonstandard_keywords_or_standard_keywords_in_wrong_context(self) -> list[DetectedError]:
    '''Placeholder detector: not implemented yet, so it never reports an error.'''
    return list()
|
|
1133
|
+
|
|
1134
|
+
def detect_109_different_tuples_in_set_operation(self) -> list[DetectedError]:
    '''Placeholder detector: not implemented yet, so it never reports an error.'''
    return list()
|
|
1136
|
+
|
|
1137
|
+
def detect_106_missing_quantifier(self) -> list[DetectedError]:
    '''Placeholder detector: not implemented yet, so it never reports an error.'''
    return list()
|
|
1139
|
+
# endregion
|
|
1140
|
+
|