sqlchecker 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1140 @@
1
+ '''Detector for syntax errors in SQL queries.'''
2
+
3
+ from dataclasses import dataclass
4
+ import difflib
5
+ import re
6
+ import sqlparse
7
+ from sqlglot import exp
8
+ from typing import Callable
9
+ from copy import deepcopy
10
+ from sqlerrors import SqlErrors
11
+ from sqlscope import Query
12
+ from sqlscope.query.set_operations.set_operation import SetOperation
13
+ from sqlscope.query.typechecking import get_type, collect_errors
14
+ from sqlscope import util
15
+
16
+ from .base import BaseDetector, DetectedError
17
+
18
+
19
+ class SyntaxErrorDetector(BaseDetector):
20
+ '''Detector for syntax errors in SQL queries.'''
21
+
22
def __init__(self,
             *,
             query: Query,
             update_query: Callable[[str, str | None], None],
             solutions: list[Query] | None = None,
             ):
    '''
    Create a syntax error detector for a single query.

    Parameters:
    - query: the query to analyse.
    - update_query: callback called with the rewritten SQL text and the name
      of the check that produced it whenever a check fixes the query.
    - solutions: optional reference solutions. When omitted, a fresh empty
      list is used.
    '''
    super().__init__(
        query=query,
        # A literal `[]` default would be one list object shared by every
        # detector instance; build a new list per call instead.
        solutions=[] if solutions is None else solutions,
        update_query=update_query,
    )
33
+
34
def run(self) -> list[DetectedError]:
    '''
    Run the detector and return a list of detected errors with their descriptions.

    The checks run in four stages:
    1. stray/omitted semicolons; the cleaned query is pushed through
       `update_query` to allow AST building for the subsequent checks;
    2. references to unexisting objects, evaluated before any correction to
       avoid false positives;
    3. fixable errors, whose `data` is read as `(original_text, corrected_text)`
       and applied to the query text;
    4. all remaining checks.
    '''

    def edge_guard(char: str, *, lookbehind: bool) -> str:
        '''Return a lookaround that keeps a match from starting/ending inside a longer token.'''
        opener = '(?<!' if lookbehind else '(?!'
        if char.isalnum() or char == '_':
            # identifier edge: same effect as the `\b` anchor
            return opener + r'\w)'
        if char in '=!<>&|':
            # operator edge: do not match `==` inside `===`, or `!` inside `!=`
            return opener + '[=!<>&|])'
        # quotes and other symbols need no guard
        return ''

    def apply_correction(sql: str, original: str, replacement: str) -> str:
        '''Replace every stand-alone occurrence of `original` in `sql` with `replacement`.'''
        if not original:
            return sql
        pattern = (
            edge_guard(original[0], lookbehind=True)
            + re.escape(original)
            + edge_guard(original[-1], lookbehind=False)
        )
        # A callable replacement is inserted verbatim; a plain string would be
        # parsed as a template (backslashes, group references).
        return re.sub(pattern, lambda _match: replacement, sql)

    results: list[DetectedError] = super().run()

    # 1) fix stray semicolons (to allow ast building for subsequent checks)
    checks = [self.detect_22_38_additional_omitted_semicolons]

    for check in checks:
        check_result, fixed_query_str = check()
        results.extend(check_result)

        if fixed_query_str != self.query.sql:
            self.update_query(fixed_query_str, check.__name__)

    # 2) detect unexisting objects (before corrections, to avoid false positives)
    unexisting_checks = [
        self.detect_2_4_undefined_columns_ambiguous_columns,    # ok
        self.detect_2_ambiguous_function,                       # TODO: implement
        self.detect_5_undefined_functions,                      # ok
        self.detect_6_undefined_function_parameters,            # ok
        self.detect_7_8_undefined_tables,                       # ok
        self.detect_25_using_an_undefined_correlation_name,     # TODO: implement
    ]

    for check in unexisting_checks:
        results.extend(check())

    # 3.1) detect fixable errors and apply corrections for improved subsequent checks
    # NOTE: leave in this order!
    misspelling_checks = [
        self.detect_33_omitting_commas,                             # TODO: implement/refactor
        self.detect_27_confusing_table_names_with_column_names,     # TODO: implement
        self.detect_36_nonstandard_operators,                       # ok
        self.detect_9_misspellings_schemas_tables,                  # ok
        self.detect_9_misspellings_columns,                         # ok
        self.detect_10_synonyms,                                    # TODO: implement
        self.detect_11_omitted_quotes,                              # TODO: implement/refactor
    ]

    # 3.2) apply corrections and re-parse query
    corrected_sql = self.query.sql
    # Name of the last check whose correction actually changed the text.
    # Reading the loop variable after the loop would always name the final
    # check in the list, whether or not it fixed anything.
    last_fixing_check: str | None = None
    for check in misspelling_checks:
        for error in check():
            results.append(error)
            updated_sql = apply_correction(corrected_sql, error.data[0], error.data[1])
            if updated_sql != corrected_sql:
                corrected_sql = updated_sql
                last_fixing_check = check.__name__

    # Use the corrected query from here on (across all detectors)
    if corrected_sql != self.query.sql:
        self.update_query(corrected_sql, last_fixing_check)

    # Proceed with all other checks
    checks = [
        self.detect_12_failure_to_specify_column_name_twice,        # TODO: implement
        self.detect_13_data_type_mismatch,                          # ok
        self.detect_14_aggregate_function_outside_select_or_having, # ok
        self.detect_15_aggregate_functions_cannot_be_nested,        # ok
        self.detect_16_extraneous_or_omitted_grouping_column,       # ok
        self.detect_17_having_without_group_by,                     # ok
        self.detect_106_missing_quantifier,                         # TODO: implement
        self.detect_18_confusing_function_with_function_parameter,  # TODO: implement
        self.detect_19_using_where_twice,                           # ok
        self.detect_20_omitted_from_clause,                         # ok
        self.detect_21_comparison_with_null,                        # ok
        self.detect_23_date_time_field_overflow,                    # TODO: implement, needs AST
        self.detect_24_duplicate_clause,                            # ok
        self.detect_26_too_many_columns_in_subquery,                # ok
        self.detect_30_confused_order_of_keywords,                  # ok
        self.detect_32_confused_syntax_of_keywords,                 # TODO: check and refactor
        self.detect_107_108_curly_square_or_unmatched_brackets,     # ok
        self.detect_35_is_where_not_applicable,                     # ok
        self.detect_36_nonstandard_keywords_or_standard_keywords_in_wrong_context,  # TODO: implement
        self.detect_109_different_tuples_in_set_operation,          # TODO: implement
    ]

    for check in checks:
        results.extend(check())
    return results
118
+
119
+ # region 1) Semicolons
120
def detect_22_38_additional_omitted_semicolons(self) -> tuple[list[DetectedError], str]:
    '''
    Flags queries that omit the semicolon at the end or have multiple semicolons.

    Tokens are scanned from the end of the query. Whitespace and comments are
    kept and ignored, so a comment placed after the final semicolon does not
    turn that semicolon into a "mid-query" one.

    Returns:
    - List of DetectedError instances for semicolon issues.
    - The cleaned query string with extra semicolons removed.
    '''

    results: list[DetectedError] = []

    all_tokens = [
        token
        for statement in self.query.all_statements
        for token in statement.flatten()
    ]

    kept_values: list[str] = []     # collected back-to-front, reversed on return
    trailing_semicolon_found = False
    content_found = False

    for token in reversed(all_tokens):  # start from end to preserve only the last semicolon
        ttype = token.ttype

        # whitespace, newlines and comments carry no query content: keep as is
        is_blank = ttype in (sqlparse.tokens.Whitespace, sqlparse.tokens.Newline)
        if is_blank or ttype in sqlparse.tokens.Comment:
            kept_values.append(token.value)
            continue

        if ttype == sqlparse.tokens.Punctuation and token.value == ';':
            if content_found or trailing_semicolon_found:
                # either a semicolon in the middle of the query, or a repeated
                # one at the end: drop it and report it
                results.append(DetectedError(SqlErrors.ADDITIONAL_SEMICOLON))
                continue

            # first semicolon seen before any content: the legitimate terminator
            kept_values.append(token.value)
            trailing_semicolon_found = True
            continue

        # any other token
        content_found = True
        kept_values.append(token.value)

    if not trailing_semicolon_found:
        results.append(DetectedError(SqlErrors.OMITTED_SEMICOLON))

    return (results, ''.join(reversed(kept_values)))
173
+ # endregion
174
+
175
+ # region 2) Pre-fixing
176
def detect_2_ambiguous_function(self) -> list[DetectedError]:
    '''Placeholder for the ambiguous-function check: not implemented, never reports an error.'''
    no_errors: list[DetectedError] = []
    return no_errors
178
+
179
def detect_7_8_undefined_tables(self) -> list[DetectedError]:
    '''
    Report table references that cannot be resolved against the catalog.

    Schema-qualified names yield INVALID_SCHEMA_NAME when the schema is unknown
    and UNDEFINED_OBJECT when only the table is unknown. Unqualified names are
    looked up under the '' schema (CTEs) and then under the search path.
    '''

    errors: list[DetectedError] = []

    for query_select in self.query.selects:
        flat = query_select.strip_subqueries()

        if flat.ast is None:
            continue

        for table_node in flat.ast.find_all(exp.Table):
            name = util.ast.table.get_real_name(table_node)
            schema = util.ast.table.get_schema(table_node)

            if not schema:
                # Unqualified table: a CTE first, then the current schema
                resolved = (
                    flat.catalog.has_table('', name)
                    or flat.catalog.has_table(flat.search_path, name)
                )
                if not resolved:
                    errors.append(DetectedError(SqlErrors.UNDEFINED_OBJECT, (table_node.sql(),)))
            elif not flat.catalog.has_schema(schema):
                # Fully qualified table whose schema does not exist
                errors.append(DetectedError(SqlErrors.INVALID_SCHEMA_NAME, (table_node.sql(),)))
            elif not flat.catalog.has_table(schema, name):
                # Schema exists but the table does not
                errors.append(DetectedError(SqlErrors.UNDEFINED_OBJECT, (table_node.sql(),)))

    return errors
218
+
219
def detect_2_4_undefined_columns_ambiguous_columns(self) -> list[DetectedError]:
    '''
    Report column references that match no column (UNDEFINED_COLUMN) or more
    than one column (AMBIGUOUS_COLUMN) of the tables referenced by their SELECT.
    '''

    errors: list[DetectedError] = []

    for query_select in self.query.selects:
        flat = query_select.strip_subqueries()

        if flat.ast is None:
            continue

        for col_ref in flat.ast.find_all(exp.Column):
            # `table.*` is not an actual column reference
            if isinstance(col_ref.this, exp.Star):
                continue

            wanted_column = util.ast.column.get_name(col_ref)
            qualifier = util.ast.column.get_table(col_ref)

            # With a qualifier only the named table is searched; without one,
            # every referenced table contributes candidates.
            candidates = [
                f'{qualifier or source.name}.{wanted_column}'
                for source in flat.referenced_tables
                if not qualifier or source.name == qualifier
                for source_column in source.columns
                if source_column.name == wanted_column
            ]

            if not candidates:
                errors.append(DetectedError(SqlErrors.UNDEFINED_COLUMN, (col_ref.sql(),)))
            elif len(candidates) > 1:
                errors.append(DetectedError(SqlErrors.AMBIGUOUS_COLUMN, (col_ref.sql(), candidates)))

    return errors
264
+
265
def detect_5_undefined_functions(self) -> list[DetectedError]:
    '''
    Report function calls whose upper-cased name is not in the hard-coded
    list of known names (user-defined functions are not consulted yet).
    '''

    builtin_names = (
        'SUM', 'AVG', 'COUNT', 'MIN', 'MAX',
        'IN', 'EXISTS', 'ANY', 'ALL',
        'COALESCE', 'NULLIF', 'CAST', 'CONVERT',
        'UPPER', 'LOWER', 'LENGTH', 'SUBSTRING',
        'NOW', 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
    )
    custom_names: set[str] = set()  # TODO: self.catalog.functions
    recognised = set(builtin_names) | custom_names

    errors: list[DetectedError] = []

    for call, clause in self.query.functions:
        name = call.get_name()
        if name is not None and name.upper() not in recognised:
            errors.append(DetectedError(SqlErrors.UNDEFINED_FUNCTION, (name, clause)))

    return errors
292
+
293
def detect_6_undefined_function_parameters(self) -> list[DetectedError]:
    '''
    Checks for undefined function parameters (placeholders left in the query).

    A token value is reported when it is a `?` placeholder, or a `:` / `@`
    prefix immediately followed by an identifier character (`:name`, `@var`).
    Pure punctuation or operators that merely start with those characters
    (for example `::`, `:=`, `@>`) are not reported.
    '''

    def looks_like_parameter(value: str) -> bool:
        if value.startswith('?'):
            return True
        if len(value) < 2 or value[0] not in ':@':
            return False
        # the prefix must introduce a name, not continue an operator
        return value[1].isalnum() or value[1] == '_'

    return [
        DetectedError(SqlErrors.UNDEFINED_PARAMETER, (val,))
        for _ttype, val in self.query.tokens
        if looks_like_parameter(val)
    ]
303
+
304
def detect_25_using_an_undefined_correlation_name(self) -> list[DetectedError]:
    '''Placeholder for the undefined-correlation-name check: not implemented, never reports an error.'''
    no_errors: list[DetectedError] = []
    return no_errors
306
+ # endregion
307
+
308
+ # region 3) Fixable errors
309
def detect_9_misspellings_schemas_tables(self) -> list[DetectedError]:
    '''
    Check for misspellings in schema and table names.

    Every table reference that does not resolve against the catalog is
    compared (difflib, cutoff 0.6) with the known tables; the closest match,
    if any, is proposed as a correction.

    Returns:
        A list of DetectedErrors. data=(original_reference, corrected_reference)
    '''

    results: set[DetectedError] = set()     # use a set to avoid applying the same correction multiple times

    for select in self.query.selects:
        select = select.strip_subqueries()

        if select.ast is None:
            continue

        for table in select.ast.find_all(exp.Table):
            table = deepcopy(table)  # avoid modifying the original AST until we are sure we want to apply the correction
            table_str = table.sql()
            table_name = util.ast.table.get_real_name(table)
            schema_name = util.ast.table.get_schema(table)

            if schema_name:
                # Fully qualified table (schema.table)
                if select.catalog.has_table(schema_name, table_name):
                    continue

                # check "schema.table" for more accurate matches in edge cases (i.e. can't determine if the misspelled part is schema or table)
                # Keep the (schema, table) pair next to its joined form: splitting
                # the joined string on '.' fails as soon as a name contains a dot.
                qualified_tables = {
                    f'{s}.{t}': (s, t)
                    for s in select.catalog.schema_names
                    for t in select.catalog[s].table_names
                }
                match = difflib.get_close_matches(f'{schema_name}.{table_name}', list(qualified_tables), n=1, cutoff=0.6)
                if match:
                    s, t = qualified_tables[match[0]]

                    # NOTE(review): 'db' is wrapped in a TableAlias rather than set to a
                    # plain identifier; kept as is because only `table.sql()` is read below.
                    table.set('db', exp.TableAlias(this=exp.to_identifier(s, quoted=True)))
                    table.set('this', exp.to_identifier(t, quoted=True))

                    results.add(DetectedError(SqlErrors.MISSPELLINGS, (table_str, table.sql())))
                continue

            # Unqualified table (table)
            # Check if table is a CTE
            if select.catalog.has_table('', table_name):
                continue

            # Check if table is in the current schema
            if select.catalog.has_table(select.search_path, table_name):
                continue

            available_tables = {t for s in select.catalog.schema_names for t in select.catalog[s].table_names}
            match = difflib.get_close_matches(table_name, available_tables, n=1, cutoff=0.6)
            if not match:
                continue

            # Prefer the current schema when several schemas own the matched
            # table, so the correction is not qualified needlessly.
            if select.catalog.has_table(select.search_path, match[0]):
                db = select.search_path
            else:
                db = next((s for s in select.catalog.schema_names if select.catalog.has_table(s, match[0])), None)

            table.set('this', exp.to_identifier(match[0], quoted=True))
            # An empty schema name (presumably the CTE pseudo-schema looked up
            # with '' above) cannot be written as a qualifier.
            if db and db != select.search_path:
                table.set('db', exp.TableAlias(this=exp.to_identifier(db, quoted=True)))
            results.add(DetectedError(SqlErrors.MISSPELLINGS, (table_str, table.sql())))

    return [*results]
365
+
366
def detect_9_misspellings_columns(self) -> list[DetectedError]:
    '''
    Check for misspellings in column names.

    A column reference that exists in none of the candidate tables (only the
    named table when the reference is qualified) is compared (difflib,
    cutoff 0.6) with the columns of all referenced tables; the closest match,
    if any, is proposed as a correction.

    Returns:
        A list of DetectedErrors. data=(original_reference, corrected_reference)
    '''
    results: set[DetectedError] = set()     # use a set to avoid applying the same correction multiple times

    for select in self.query.selects:
        select = select.strip_subqueries()

        if select.ast is None:
            continue

        for column in select.ast.find_all(exp.Column):
            # skip `table.*` syntax, we only want to check actual column references
            if isinstance(column.this, exp.Star):
                continue

            column = deepcopy(column)  # avoid modifying the original AST until we are sure we want to apply the correction
            column_str = column.sql()
            column_name = util.ast.column.get_name(column)
            table_name = util.ast.column.get_table(column)

            # Qualified column: check only the specified table; otherwise any table
            found = any(
                table.has_column(column_name)
                for table in select.referenced_tables
                if not table_name or table.name == table_name
            )
            if found:
                continue

            if table_name:
                # Qualified column (table.column)
                # Keep the (table, column) pair next to its joined form: splitting
                # the joined string on '.' fails as soon as a name contains a dot.
                qualified_columns = {
                    f'{t.name}.{c.name}': (t.name, c.name)
                    for t in select.referenced_tables
                    for c in t.columns
                }
                match = difflib.get_close_matches(f'{table_name}.{column_name}', list(qualified_columns), n=1, cutoff=0.6)
                if not match:
                    continue

                matched_table, matched_column = qualified_columns[match[0]]
                column.set('table', exp.to_identifier(matched_table, quoted=True))
                column.set('this', exp.to_identifier(matched_column, quoted=True))
            else:
                # Unqualified column (column)
                available_columns = {c.name for t in select.referenced_tables for c in t.columns}
                match = difflib.get_close_matches(column_name, available_columns, n=1, cutoff=0.6)
                if not match:
                    continue

                column.set('this', exp.to_identifier(match[0], quoted=True))

            results.add(DetectedError(SqlErrors.MISSPELLINGS, (column_str, column.sql())))

    return [*results]
422
+
423
def detect_10_synonyms(self) -> list[DetectedError]:
    '''Placeholder for the synonyms check: not implemented, never reports an error.'''
    no_errors: list[DetectedError] = []
    return no_errors
425
+
426
def detect_11_omitted_quotes(self) -> list[DetectedError]:
    '''
    Checks for potential omitting of quotes around character data in WHERE/HAVING clauses.

    Not implemented yet: the check is disabled and always returns an empty list.

    Returns:
        A list of DetectedErrors. data=(offending_value,corrected_value)
    '''
    # TODO: implement/refactor.
    # The earlier draft that followed this return was unreachable and relied on
    # names that are never defined (`valid_source_identifiers`,
    # `all_known_columns_lower`, `output_aliases_lower`). Its idea, for reference:
    #   - walk the query tokens, tracking whether we are inside WHERE/HAVING and
    #     on the right-hand side of a comparison (=, <>, !=, <, >, <=, >=, LIKE, NOT LIKE);
    #   - there, flag double-quoted values, and bare alphabetic names that are
    #     neither a keyword, a table/alias, an output alias nor a known column.
    return []
535
+
536
def detect_27_confusing_table_names_with_column_names(self) -> list[DetectedError]:
    '''Placeholder for the table-name/column-name confusion check: not implemented, never reports an error.'''
    no_errors: list[DetectedError] = []
    return no_errors
538
+
539
def detect_33_omitting_commas(self) -> list[DetectedError]:
    '''
    Flags queries where commas are likely missing between column expressions
    (e.g., SELECT name age FROM ..., GROUP BY x y).

    Not implemented yet: the check is disabled and always returns an empty list.
    '''
    # TODO: implement/refactor.
    # The earlier draft that followed this return was unreachable; it read
    # `self.tokens` (the other checks use `self.query.tokens`) and appended
    # plain tuples instead of DetectedError instances. Its idea, for reference:
    #   - track the current clause; inside SELECT / GROUP BY / ORDER BY / VALUES,
    #     two name-like tokens separated only by whitespace suggest a missing comma.
    # `run` reads `error.data[0]` / `error.data[1]` of the errors returned here
    # as (original_text, corrected_text), so an implementation must provide both.
    return []
599
+
600
def detect_36_nonstandard_operators(self) -> list[DetectedError]:
    '''
    Report operators borrowed from other languages (&&, ||, ==, ...).

    Returns:
        A list of DetectedErrors. data=(operator_found, suggested_replacement)
    '''

    # operator found -> suggested replacement
    replacements = {
        '=='  : '=',
        '===' : '=',
        '!==' : '<>',
        '&&'  : ' AND ',
        '||'  : ' OR ',
        '!'   : ' NOT ',
        # '^' : '',
        # '~' : '',
        '>>'  : '>',
        '<<'  : '<',
        '≠'   : '<>',
        '≥'   : '>=',
        '≤'   : '<=',
    }
    operator_families = (sqlparse.tokens.Operator, sqlparse.tokens.Operator.Comparison)

    found: list[DetectedError] = []

    for ttype, raw_value in self.query.tokens:
        symbol = raw_value.strip()

        # only operator tokens and tokens the lexer could not classify are candidates
        is_candidate = (
            any(ttype in family for family in operator_families)
            or ttype == sqlparse.tokens.Error
        )
        if not is_candidate:
            continue

        suggestion = replacements.get(symbol)
        if suggestion is not None:
            found.append(DetectedError(SqlErrors.NONSTANDARD_OPERATORS, (symbol, suggestion)))

    return found
632
+ # endregion
633
+
634
+ # region 4) Other checks
635
def detect_12_failure_to_specify_column_name_twice(self) -> list[DetectedError]:
    '''Placeholder for the "column name not specified twice" check: not implemented, never reports an error.'''
    no_errors: list[DetectedError] = []
    return no_errors
637
+
638
def detect_13_data_type_mismatch(self) -> list[DetectedError]:
    '''
    Report data type mismatches found by the type checker in every CTE and in
    the main query, including set-operation branches whose output types differ.
    '''

    def check_set_operation(set_op: 'SetOperation', location: str) -> list[DetectedError]:
        '''Type-check each main select of `set_op`; `location` labels the inconsistency error.'''
        found: list[DetectedError] = []
        reference_type = None   # output type of the first typed select

        for branch in set_op.main_selects:
            tree = branch.typed_ast
            if tree is None:
                continue

            branch_type = get_type(tree, branch.catalog, branch.search_path)

            if reference_type is None:
                # first typed select defines what the others must produce
                reference_type = branch_type
            elif reference_type != branch_type:
                found.append(DetectedError(SqlErrors.DATA_TYPE_MISMATCH, (location,"setop types inconsistent")))

            # forward the messages recorded on the computed type
            found.extend(
                DetectedError(SqlErrors.DATA_TYPE_MISMATCH, message)
                for message in branch_type.messages
            )

        return found

    collected: list[DetectedError] = []

    # CTEs first, then the main query
    for cte in self.query.ctes:
        collected += check_set_operation(cte, f"CTE {cte.output.name}")

    collected += check_set_operation(self.query.main_query, "Main Query")

    return collected
683
+
684
def detect_14_aggregate_function_outside_select_or_having(self) -> list[DetectedError]:
    '''
    Report SUM, AVG, COUNT, MIN and MAX calls located in a clause other than
    SELECT or HAVING.

    Returns:
        A list of DetectedErrors. data=(function_name, clause)
    '''

    aggregate_names = {'SUM', 'AVG', 'COUNT', 'MIN', 'MAX'}
    allowed_clauses = {'SELECT', 'HAVING'}

    return [
        DetectedError(SqlErrors.AGGREGATE_FUNCTION_OUTSIDE_SELECT_OR_HAVING, (name, clause))
        for call, clause in self.query.functions
        if (name := call.get_name())
        and name.upper() in aggregate_names
        and clause not in allowed_clauses
    ]
700
+
701
def detect_15_aggregate_functions_cannot_be_nested(self) -> list[DetectedError]:
    '''
    Flags cases where aggregate functions are nested within the *same query scope*,
    which mainstream SQL dialects do not allow (e.g., SUM(MAX(x))).

    One error is reported per aggregate found inside the argument of another.

    Returns:
        A list of DetectedErrors. data=(outer_aggregate_sql,)
    '''
    results: list[DetectedError] = []

    for select in self.query.selects:
        # subqueries are their own scope: strip them before searching
        stripped = select.strip_subqueries()

        if stripped.ast is None:
            continue

        for outer_agg in stripped.ast.find_all(exp.AggFunc):
            inner = outer_agg.this
            if not isinstance(inner, exp.Expression):
                # no expression argument to search (presumably an aggregate
                # written without an argument); calling find_all would fail
                continue

            for _inner_agg in inner.find_all(exp.AggFunc):
                results.append(DetectedError(
                    SqlErrors.AGGREGATE_FUNCTIONS_CANNOT_BE_NESTED,
                    (outer_agg.sql(),)
                ))

    return results
725
+
726
def detect_16_extraneous_or_omitted_grouping_column(self) -> list[DetectedError]:
    '''
    All columns in SELECT must be either included in the GROUP BY clause or aggregated,
    and every GROUP BY entry must appear in SELECT and must not be an aggregate.

    Only SELECTs that have a GROUP BY are inspected. HAVING is not checked yet.

    Returns:
        A list of DetectedErrors. data=(column_name, reason), where reason is
        'ONLY IN SELECT', 'ONLY IN GROUP BY' or 'AGGREGATED IN GROUP BY'.
    '''

    @dataclass(frozen=True)
    class ColumnInfo:
        name: str
        alias: str
        is_aggregated: bool = False

    def get_column_name(col: exp.Column | exp.Alias) -> ColumnInfo:
        '''Return normalized column name and alias. If no alias, both are the same.'''
        col_name = util.ast.column.get_real_name(col)
        col_alias = util.ast.column.get_name(col)
        return ColumnInfo(col_name, col_alias)

    results: list[DetectedError] = []

    for select in self.query.selects:
        if select.ast is None:
            continue

        if not select.group_by:
            continue    # no GROUP BY, skip

        select_columns: list[ColumnInfo] = []   # we need a list for positional GROUP BY handling

        # Gather columns from SELECT
        for col in select.ast.expressions:
            if isinstance(col, exp.Star):
                # SELECT * case: expand to all columns from all referenced tables
                for table in select.referenced_tables:
                    for table_col in table.columns:
                        select_columns.append(ColumnInfo(table_col.name, table_col.name))
            elif isinstance(col, exp.Alias) and isinstance(col.this, exp.Func):
                # aliased function call (e.g. COUNT(*) AS n): classify it like the
                # unaliased form below instead of as a plain column
                select_columns.append(ColumnInfo(col.this.sql(), col.alias, is_aggregated=True))
            elif isinstance(col, (exp.Column, exp.Alias)):
                select_columns.append(get_column_name(col))
            elif isinstance(col, exp.Func):
                # function call: treated as aggregated, add it but skip it later
                select_columns.append(ColumnInfo(col.sql(), col.sql(), is_aggregated=True))
            else:
                # Complex expression: try to extract columns
                for c in col.find_all(exp.Column):
                    select_columns.append(get_column_name(c))

        # Gather columns from GROUP BY
        group_by_columns: set[ColumnInfo] = set()
        for gb in select.group_by:
            if isinstance(gb, exp.Column):
                group_by_columns.add(get_column_name(gb))
            elif isinstance(gb, exp.Literal):
                try:
                    position = int(gb.this)
                except ValueError:
                    continue    # not a positional reference
                # Positional GROUP BY: map to selected columns
                if 1 <= position <= len(select_columns):
                    group_by_columns.add(select_columns[position - 1])
            elif isinstance(gb, exp.AggFunc):
                group_by_columns.add(ColumnInfo(gb.sql(), gb.sql(), is_aggregated=True))
            else:
                # Complex expression in GROUP BY: try to extract columns
                for c in gb.find_all(exp.Column):
                    group_by_columns.add(get_column_name(c))

        # Ensure all non-aggregated columns in SELECT are in GROUP BY
        for select_col in set(select_columns):  # convert to set to avoid outputting the same error multiple times
            if select_col.is_aggregated:
                continue    # aggregated, skip
            if any(select_col.name == group_col.name or select_col.alias == group_col.alias for group_col in group_by_columns):
                continue    # valid: in GROUP BY
            results.append(DetectedError(SqlErrors.EXTRANEOUS_OR_OMITTED_GROUPING_COLUMN, (select_col.name, 'ONLY IN SELECT')))

        # Ensure all non-aggregated columns in GROUP BY are in SELECT
        # (Note: aggregated columns in GROUP BY are invalid)
        for group_col in group_by_columns:
            if group_col.is_aggregated:
                results.append(DetectedError(SqlErrors.EXTRANEOUS_OR_OMITTED_GROUPING_COLUMN, (group_col.name, 'AGGREGATED IN GROUP BY')))
                continue
            if any(group_col.name == select_col.name or group_col.alias == select_col.alias for select_col in select_columns):
                continue    # valid: in SELECT
            results.append(DetectedError(SqlErrors.EXTRANEOUS_OR_OMITTED_GROUPING_COLUMN, (group_col.name, 'ONLY IN GROUP BY')))

        # TODO: ensure all non-aggregated columns in HAVING are in GROUP BY

    return results
818
+
819
def detect_17_having_without_group_by(self) -> list[DetectedError]:
    '''
    Report one error for every SELECT that has a HAVING clause but no GROUP BY.
    '''
    return [
        DetectedError(SqlErrors.HAVING_WITHOUT_GROUP_BY)
        for query_select in self.query.selects
        if query_select.having and not query_select.group_by
    ]
830
+
831
def detect_18_confusing_function_with_function_parameter(self) -> list[DetectedError]:
    '''Placeholder for the function/function-parameter confusion check: not implemented, never reports an error.'''
    no_errors: list[DetectedError] = []
    return no_errors
833
+
834
def detect_19_using_where_twice(self) -> list[DetectedError]:
    '''
    Report every SELECT (subqueries stripped) that contains more than one
    WHERE keyword.

    Returns:
        A list of DetectedErrors. data=(select_sql, number_of_where_keywords)
    '''

    errors: list[DetectedError] = []

    for query_select in self.query.selects:
        # Without its subqueries, only this select's own WHERE keywords remain.
        flat = query_select.strip_subqueries()

        occurrences = sum(
            1
            for ttype, val in flat.tokens
            if ttype == sqlparse.tokens.Keyword and val.upper() == 'WHERE'
        )

        if occurrences > 1:
            errors.append(DetectedError(SqlErrors.USING_WHERE_TWICE, (query_select.sql, occurrences)))

    return errors
855
+
856
+ def detect_20_omitted_from_clause(self) -> list[DetectedError]:
857
+ '''
858
+ Flags queries that omit the FROM clause entirely when it's required.
859
+ A FROM clause is not required if:
860
+ - The query selects only constants/literals
861
+ - The query uses CTEs and references them implicitly
862
+ '''
863
+ results: list[DetectedError] = []
864
+
865
+ for select in self.query.selects:
866
+ stripped = select.strip_subqueries()
867
+
868
+ from_found = False
869
+ for ttype, val in stripped.tokens:
870
+ if ttype == sqlparse.tokens.Keyword and val.upper() == 'FROM':
871
+ from_found = True
872
+ break
873
+
874
+ if from_found:
875
+ continue # valid, has FROM clause
876
+
877
+ # Check if selecting only constants/literals
878
+ for col in stripped.output.columns:
879
+ if not col.is_constant:
880
+ results.append(DetectedError(SqlErrors.OMITTED_FROM_CLAUSE, (select.sql,)))
881
+ break
882
+
883
+ return results
884
+
885
+ def detect_21_comparison_with_null(self) -> list[DetectedError]:
886
+ '''
887
+ Flags SQL comparisons using = NULL, <> NULL, etc. instead of IS NULL / IS NOT NULL.
888
+ '''
889
+ results: list[DetectedError] = []
890
+
891
+ for select in self.query.selects:
892
+ select = select.strip_subqueries(replacement='1') # avoid false positives from subqueries
893
+
894
+ if select.ast is None:
895
+ continue
896
+
897
+ for comparison in select.ast.find_all(exp.EQ, exp.NEQ, exp.LT, exp.GT, exp.LTE, exp.GTE):
898
+ left = comparison.left
899
+ right = comparison.right
900
+ if (isinstance(left, exp.Null) or isinstance(right, exp.Null)):
901
+ results.append(DetectedError(SqlErrors.COMPARISON_WITH_NULL, (comparison.sql(),)))
902
+
903
+ return results
904
+
905
+ def detect_23_date_time_field_overflow(self) -> list[DetectedError]:
906
+ return []
907
+
908
+ def detect_24_duplicate_clause(self) -> list[DetectedError]:
909
+ '''
910
+ Flags queries that contain duplicate clauses (e.g., two WHERE clauses).
911
+ '''
912
+ results: list[DetectedError] = []
913
+
914
+ clause_keywords = {'SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT', 'OFFSET'}
915
+
916
+ for select in self.query.selects:
917
+ stripped = select.strip_subqueries()
918
+
919
+ clause_count = {}
920
+ for ttype, val in stripped.tokens:
921
+ val_upper = val.upper()
922
+ if ttype == sqlparse.tokens.DML and val_upper == 'SELECT':
923
+ clause_count[val_upper] = clause_count.get(val_upper, 0) + 1
924
+ if ttype == sqlparse.tokens.Keyword and val_upper in clause_keywords:
925
+ clause_count[val_upper] = clause_count.get(val_upper, 0) + 1
926
+
927
+ for clause, count in clause_count.items():
928
+ if count > 1:
929
+ results.append(DetectedError(SqlErrors.DUPLICATE_CLAUSE, (clause, count)))
930
+
931
+ return results
932
+
933
+ def detect_26_too_many_columns_in_subquery(self) -> list[DetectedError]:
934
+ '''
935
+ Flags subqueries that return more columns than expected in contexts like WHERE IN (subquery).
936
+ '''
937
+
938
+ results: list[DetectedError] = []
939
+
940
+ for select in self.query.selects:
941
+ for subquery, clause, depth in select.subqueries:
942
+ if clause in ('FROM', 'EXISTS'):
943
+ continue # FROM/EXISTS subqueries can have any number of columns
944
+
945
+ output_columns = len(subquery.output.columns)
946
+ expected_columns = 1 # Default expected columns for most contexts
947
+
948
+ col_difference = output_columns - expected_columns
949
+ if col_difference != 0:
950
+ results.append(DetectedError(SqlErrors.TOO_MANY_COLUMNS_IN_SUBQUERY, (subquery.sql, col_difference)))
951
+
952
+ return results
953
+
954
+ def detect_30_confused_order_of_keywords(self) -> list[DetectedError]:
955
+ '''
956
+ Flags queries where the standard order of SQL clauses is not respected.
957
+ Expected order:
958
+ SELECT → FROM → WHERE → GROUP BY → HAVING → ORDER BY → LIMIT → OFFSET
959
+ '''
960
+ results: list[DetectedError] = []
961
+
962
+ for select in self.query.selects:
963
+ stripped = select.strip_subqueries()
964
+
965
+ expected_order = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT', 'OFFSET']
966
+ actual_order: list[str] = []
967
+
968
+ for ttype, val in stripped.tokens:
969
+ if ttype == sqlparse.tokens.DML:
970
+ actual_order.append('SELECT')
971
+ elif ttype == sqlparse.tokens.Keyword:
972
+ val_upper = val.upper()
973
+ if val_upper == 'FROM':
974
+ actual_order.append('FROM')
975
+ elif val_upper == 'WHERE':
976
+ actual_order.append('WHERE')
977
+ elif val_upper == 'GROUP BY':
978
+ actual_order.append('GROUP BY')
979
+ elif val_upper == 'HAVING':
980
+ actual_order.append('HAVING')
981
+ elif val_upper == 'ORDER BY':
982
+ actual_order.append('ORDER BY')
983
+ elif val_upper == 'LIMIT':
984
+ actual_order.append('LIMIT')
985
+ elif val_upper == 'OFFSET':
986
+ actual_order.append('OFFSET')
987
+
988
+ # Check the order of clauses
989
+ last_index = -1
990
+ for clause in actual_order:
991
+ if clause in expected_order:
992
+ current_index = expected_order.index(clause)
993
+ if current_index < last_index:
994
+ results.append(DetectedError(
995
+ SqlErrors.CONFUSED_ORDER_OF_KEYWORDS,
996
+ (actual_order,)
997
+ ))
998
+ break
999
+ last_index = current_index
1000
+
1001
+ return results
1002
+
1003
+ # NOTE: is this implementation actually coherent with the error description?
1004
+ def detect_32_confused_syntax_of_keywords(self) -> list[DetectedError]:
1005
+ '''
1006
+ Flags use of SQL keywords like LIKE, IN, BETWEEN, etc. with incorrect function-like syntax (e.g., LIKE(...)).
1007
+ '''
1008
+ return []
1009
+
1010
+ results = []
1011
+ tokens = self.tokens
1012
+ keywords = {"LIKE", "BETWEEN", "IS", "IS NOT"}
1013
+
1014
+ i = 0
1015
+ while i < len(tokens):
1016
+ tt, val = tokens[i]
1017
+ val_upper = val.upper()
1018
+
1019
+ # Handle two-word operators like NOT IN and IS NOT
1020
+ if val_upper == "NOT" and i + 1 < len(tokens) and tokens[i + 1][1].upper() == "IN":
1021
+ keyword = "NOT IN"
1022
+ next_index = i + 2
1023
+ elif val_upper == "IS" and i + 1 < len(tokens) and tokens[i + 1][1].upper() == "NOT":
1024
+ keyword = "IS NOT"
1025
+ next_index = i + 2
1026
+ elif val_upper in keywords:
1027
+ keyword = val_upper
1028
+ next_index = i + 1
1029
+ else:
1030
+ i += 1
1031
+ continue
1032
+
1033
+ # Look for '(' after the keyword → indicates function misuse
1034
+ if next_index < len(tokens):
1035
+ next_val = tokens[next_index][1].strip()
1036
+ if next_val == "(":
1037
+ results.append((
1038
+ SqlErrors.SYN_32_CONFUSING_THE_SYNTAX_OF_KEYWORDS,
1039
+ f"Misuse of keyword '{keyword}' as a function with parentheses"
1040
+ ))
1041
+ i = next_index # Skip ahead to avoid duplicate flag
1042
+ i += 1
1043
+
1044
+ return results
1045
+
1046
+ def detect_107_108_curly_square_or_unmatched_brackets(self) -> list[DetectedError]:
1047
+ '''
1048
+ Flags unmatched parentheses or usage of non-standard square or curly brackets in the SQL query.
1049
+ '''
1050
+
1051
+ results: list[DetectedError] = []
1052
+
1053
+ round_open = 0
1054
+ round_close = 0
1055
+ square_open = 0
1056
+ square_close = 0
1057
+ curly_open = 0
1058
+ curly_close = 0
1059
+
1060
+ for ttype, val in self.query.tokens:
1061
+ if ttype is sqlparse.tokens.Punctuation:
1062
+ if val == '(':
1063
+ round_open += 1
1064
+ elif val == ')':
1065
+ round_close += 1
1066
+ elif val == '[':
1067
+ square_open += 1
1068
+ elif val == ']':
1069
+ square_close += 1
1070
+ elif ttype is sqlparse.tokens.Error:
1071
+ if val == '{':
1072
+ curly_open += 1
1073
+ elif val == '}':
1074
+ curly_close += 1
1075
+ elif ttype is sqlparse.tokens.Name:
1076
+ if val.startswith('{') or val.endswith('}'):
1077
+ curly_open += val.count('{')
1078
+ curly_close += val.count('}')
1079
+ if val.startswith('[') or val.endswith(']'):
1080
+ square_open += val.count('[')
1081
+ square_close += val.count(']')
1082
+
1083
+ # Check for imbalance
1084
+ if round_open != round_close:
1085
+ results.append(DetectedError(SqlErrors.UNMATCHED_BRACKETS, ('round', round_open, round_close)))
1086
+ if square_open > 0 or square_close > 0:
1087
+ results.append(DetectedError(SqlErrors.CURLY_OR_SQUARE_BRACKETS, ('square', square_open, square_close)))
1088
+ if curly_open > 0 or curly_close > 0:
1089
+ results.append(DetectedError(SqlErrors.CURLY_OR_SQUARE_BRACKETS, ('curly', curly_open, curly_close)))
1090
+
1091
+ return results
1092
+
1093
+ def detect_35_is_where_not_applicable(self) -> list[DetectedError]:
1094
+ '''
1095
+ Find all erroneous usages of IS where it is not applicable
1096
+ '''
1097
+
1098
+ def parse_set_operation(set_operation: 'SetOperation') -> list[DetectedError]:
1099
+ '''
1100
+ Util function to parse a SetOperation and check for invalid usage of IS in all its main selects.
1101
+ '''
1102
+
1103
+ errors: list[DetectedError] = []
1104
+ for select in set_operation.main_selects:
1105
+
1106
+ typed_ast = select.typed_ast
1107
+
1108
+ if typed_ast is None:
1109
+ continue
1110
+
1111
+ for is_expr in typed_ast.find_all(exp.Is):
1112
+ for error in collect_errors(is_expr, select.catalog, select.search_path):
1113
+
1114
+ # if the expected type is boolean|null, it means that the part after IS is not valid
1115
+ if error[2] == 'boolean|null':
1116
+ errors.append(DetectedError(SqlErrors.IS_WHERE_NOT_APPLICABLE, error))
1117
+
1118
+ return errors
1119
+
1120
+ results: list[DetectedError] = []
1121
+
1122
+ # CTEs
1123
+ for cte in self.query.ctes:
1124
+ results.extend(parse_set_operation(cte))
1125
+
1126
+ # Main Query
1127
+ results.extend(parse_set_operation(self.query.main_query))
1128
+
1129
+ return results
1130
+
1131
+ def detect_36_nonstandard_keywords_or_standard_keywords_in_wrong_context(self) -> list[DetectedError]:
1132
+ return []
1133
+
1134
+ def detect_109_different_tuples_in_set_operation(self) -> list[DetectedError]:
1135
+ return []
1136
+
1137
+ def detect_106_missing_quantifier(self) -> list[DetectedError]:
1138
+ return []
1139
+ # endregion
1140
+