sql-error-categorizer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,56 @@
1
+ # Hidden, internal use only
2
+ from .detectors import BaseDetector as _BaseDetector, Detector as _Detector
3
+
4
+ # Public API
5
+ from .sql_errors import SqlErrors
6
+ from .catalog import Catalog, build_catalog, load_json as load_catalog
7
+ from .detectors import SyntaxErrorDetector, SemanticErrorDetector, LogicalErrorDetector, ComplicationDetector, DetectedError
8
+
9
def get_errors(query_str: str,
               solutions: list[str] | None = None,
               catalog: Catalog | None = None,
               search_path: str = 'public',
               solution_search_path: str = 'public',
               detectors: list[type[_BaseDetector]] | None = None,
               debug: bool = False) -> list[DetectedError]:
    '''Detect SQL errors in the given query string.

    Args:
        query_str: The SQL query to analyze.
        solutions: Reference (correct) queries to compare against; defaults to none.
        catalog: Database catalog describing the available schemas/tables/columns.
            A fresh, empty catalog is used when omitted.
        search_path: Schema search path used to resolve names in the query.
        solution_search_path: Schema search path used for the solution queries.
        detectors: Detector classes to run; defaults to all built-in detectors.
        debug: When True, the underlying detector prints diagnostic output.

    Returns:
        A list of detected errors (may contain duplicates).
    '''
    # Avoid mutable default arguments: the previous `[]` / `Catalog()` defaults
    # were shared across every call (and a Catalog is mutated on item access),
    # so state could leak from one invocation to the next.
    if solutions is None:
        solutions = []
    if catalog is None:
        catalog = Catalog()
    if detectors is None:
        detectors = [
            SyntaxErrorDetector,
            SemanticErrorDetector,
            LogicalErrorDetector,
            ComplicationDetector,
        ]

    det = _Detector(query_str,
                    solutions=solutions,
                    catalog=catalog,
                    search_path=search_path,
                    solution_search_path=solution_search_path,
                    debug=debug)

    for detector in detectors:
        det.add_detector(detector)

    return det.run()
33
+
34
def get_error_types(query_str: str,
                    solutions: list[str] | None = None,
                    catalog: Catalog | None = None,
                    search_path: str = 'public',
                    solution_search_path: str = 'public',
                    detectors: list[type[_BaseDetector]] | None = None,
                    debug: bool = False) -> set[SqlErrors]:
    '''Detect SQL error types in the given query string.

    Thin wrapper around `get_errors` that collapses the detected errors into
    the set of distinct `SqlErrors` values.

    Args:
        query_str: The SQL query to analyze.
        solutions: Reference (correct) queries to compare against; defaults to none.
        catalog: Database catalog; a fresh, empty catalog is used when omitted.
        search_path: Schema search path used to resolve names in the query.
        solution_search_path: Schema search path used for the solution queries.
        detectors: Detector classes to run; defaults to all built-in detectors.
        debug: When True, the underlying detector prints diagnostic output.

    Returns:
        The set of distinct error types detected.
    '''
    # Resolve the sentinels here (rather than relying on `get_errors`) so the
    # mutable-default fix is self-contained in this function.
    if solutions is None:
        solutions = []
    if catalog is None:
        catalog = Catalog()
    if detectors is None:
        detectors = [
            SyntaxErrorDetector,
            SemanticErrorDetector,
            LogicalErrorDetector,
            ComplicationDetector,
        ]

    detected_errors = get_errors(query_str,
                                 solutions=solutions,
                                 catalog=catalog,
                                 search_path=search_path,
                                 solution_search_path=solution_search_path,
                                 detectors=detectors,
                                 debug=debug)

    return {error.error for error in detected_errors}
@@ -0,0 +1,73 @@
1
+ '''Builds a catalog of existing schemas/tables/columns by executing the provided SQL string in a temporary PostgreSQL database.'''
2
+
3
+ from .catalog import Catalog, Schema, Table, Column, UniqueConstraintType
4
+ import psycopg2
5
+ import time
6
+ from . import queries
7
+
8
def build_catalog(sql_string: str, *, hostname: str, port: int, user: str, password: str, use_temp_schema: bool = True) -> Catalog:
    '''Builds a catalog by executing the provided SQL string in a temporary PostgreSQL database.

    The DDL is executed inside a single transaction that is always rolled back,
    so the target database is left unchanged.

    Args:
        sql_string: DDL statements (CREATE TABLE, ...) describing the schema.
        hostname: PostgreSQL server host.
        port: PostgreSQL server port.
        user: Connection user name.
        password: Connection password.
        use_temp_schema: When True, run the DDL inside a uniquely named schema
            so concurrent runs and pre-existing objects cannot interfere.

    Returns:
        A Catalog populated with the columns and unique constraints found.

    Raises:
        ValueError: If an unexpected constraint type is returned by the server.
    '''
    result = Catalog()

    if sql_string.strip() == '':
        return result

    conn = psycopg2.connect(host=hostname, port=port, user=user, password=password)
    try:
        cur = conn.cursor()

        # Use a uniquely named temporary schema so parallel runs don't clash.
        # NOTE: kept in its own variable — the original code reused `schema_name`
        # as a loop variable below, which made the later introspection queries
        # (and the schema DROP) target whatever schema the last fetched row
        # happened to belong to.
        if use_temp_schema:
            temp_schema = f'sql_error_categorizer_{time.time_ns()}'
            cur.execute(f'CREATE schema {temp_schema};')
            cur.execute(f'SET search_path TO {temp_schema};')
        else:
            temp_schema = '%'  # LIKE pattern matching every schema. TODO: it's a bit hackish, find a more stable solution

        # Create the tables
        cur.execute(sql_string)

        # Fetch the column information.
        # (Leftover `dav_tools.messages` debug prints — and their mid-function
        # third-party import — were removed here.)
        cur.execute(queries.COLUMNS(temp_schema))
        columns_info = cur.fetchall()

        for column in columns_info:
            (schema_name, table_name, column_name, column_type,
             numeric_precision, numeric_scale, is_nullable,
             fk_schema, fk_table, fk_column) = column

            result.add_column(schema_name, table_name, column_name,
                              column_type, numeric_precision, numeric_scale,
                              is_nullable,
                              fk_schema, fk_table, fk_column)

        # Fetch unique constraints (including primary keys)
        cur.execute(queries.UNIQUE_COLUMNS(temp_schema))
        for constraint in cur.fetchall():
            schema_name, table_name, constraint_type, columns = constraint
            columns = set(columns.strip('{}').split(','))  # Postgres returns {col1,col2,...}

            if constraint_type == 'PRIMARY KEY':
                constraint_type = UniqueConstraintType.PRIMARY_KEY
            elif constraint_type == 'UNIQUE':
                constraint_type = UniqueConstraintType.UNIQUE
            else:
                raise ValueError(f'Unknown constraint type: {constraint_type}')

            result[schema_name][table_name].add_unique_constraint(columns, constraint_type)

        # Drop the temporary schema. The rollback below would undo it anyway
        # (nothing is committed), but an explicit DROP keeps things clean if
        # autocommit is ever enabled on the connection.
        if use_temp_schema:
            cur.execute(f'DROP schema {temp_schema} CASCADE;')
    finally:
        # Nothing needs to be persisted; always release the connection,
        # even if one of the queries above raised.
        conn.rollback()
        conn.close()

    return result
69
+
70
def load_json(path: str) -> Catalog:
    '''Load a previously serialized catalog from the JSON file at *path*.'''
    return Catalog.load_json(path)
73
+
@@ -0,0 +1,328 @@
1
+ from dataclasses import dataclass, field
2
+ import json
3
+ from typing import Self
4
+ from enum import Enum
5
+ from copy import deepcopy
6
+
7
+ # region UniqueConstraint
8
class UniqueConstraintType(Enum):
    '''Kind of uniqueness guarantee a constraint provides.'''

    PRIMARY_KEY = 'PRIMARY KEY'
    UNIQUE = 'UNIQUE'
11
+
12
class UniqueConstraint:
    '''A uniqueness constraint (PRIMARY KEY or UNIQUE) over a set of columns.'''

    def __init__(self, columns: set[str], constraint_type: UniqueConstraintType) -> None:
        self.columns = columns
        self.constraint_type = constraint_type

    def __repr__(self, level: int = 0) -> str:
        # `level` controls indentation so nested reprs (Table/Schema) line up.
        indent = '    ' * level
        return f'{indent}UniqueConstraint({self.constraint_type.value}: {self.columns})'

    def to_dict(self) -> dict:
        '''JSON-friendly representation (columns as a list).'''
        return {
            'columns': list(self.columns),
            'constraint_type': self.constraint_type.value,
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'UniqueConstraint':
        '''Rebuild a constraint from `to_dict()` output; column names are lowercased.'''
        lowered = {name.lower() for name in data['columns']}
        return cls(columns=lowered,
                   constraint_type=UniqueConstraintType(data['constraint_type']))
31
+ # endregion
32
+
33
+ # region Column
34
+ @dataclass
35
+ class Column:
36
+ name: str
37
+ column_type: str = 'UNKNOWN'
38
+ numeric_precision: int | None = None
39
+ numeric_scale: int | None = None
40
+ is_nullable: bool = True
41
+ is_constant: bool = False
42
+ fk_schema: str | None = None
43
+ fk_table: str | None = None
44
+ fk_column: str | None = None
45
+
46
+ @property
47
+ def is_fk(self) -> bool:
48
+ return all([self.fk_schema, self.fk_table, self.fk_column])
49
+
50
+ def __repr__(self, level: int = 0) -> str:
51
+ indent = ' ' * level
52
+ return f'{indent}Column(name=\'{self.name}\', type=\'{self.column_type}\', is_fk={self.is_fk}, is_nullable={self.is_nullable}, is_constant={self.is_constant})'
53
+
54
+ def to_dict(self) -> dict:
55
+ return {
56
+ 'name': self.name,
57
+ 'column_type': self.column_type,
58
+ 'numeric_precision': self.numeric_precision,
59
+ 'numeric_scale': self.numeric_scale,
60
+ 'is_nullable': self.is_nullable,
61
+ 'fk_schema': self.fk_schema,
62
+ 'fk_table': self.fk_table,
63
+ 'fk_column': self.fk_column,
64
+ }
65
+
66
+ @classmethod
67
+ def from_dict(cls, data: dict) -> 'Column':
68
+ return cls(
69
+ name=data['name'],
70
+ column_type=data['column_type'],
71
+ numeric_precision=data.get('numeric_precision'),
72
+ numeric_scale=data.get('numeric_scale'),
73
+ is_nullable=data.get('is_nullable', True),
74
+ fk_schema=(data.get('fk_schema') or None),
75
+ fk_table=(data.get('fk_table') or None),
76
+ fk_column=(data.get('fk_column') or None),
77
+ )
78
+ # endregion
79
+
80
+ # region Table
81
@dataclass
class Table:
    '''A database table, with columns and unique constraints. Supports multiple columns with the same name (e.g. from joins).'''

    name: str
    unique_constraints: list[UniqueConstraint] = field(default_factory=list)
    columns: list[Column] = field(default_factory=list)

    def add_unique_constraint(self, columns: set[str], constraint_type: UniqueConstraintType) -> None:
        '''Append a uniqueness constraint over the given column names.'''
        self.unique_constraints.append(UniqueConstraint(columns, constraint_type))

    def add_column(self,
                   name: str,
                   column_type: str,
                   numeric_precision: int | None = None,
                   numeric_scale: int | None = None,
                   is_nullable: bool = True,
                   is_constant: bool = False,
                   fk_schema: str | None = None,
                   fk_table: str | None = None,
                   fk_column: str | None = None) -> Column:
        '''Create a Column from the given attributes, append it, and return it.'''
        new_column = Column(name=name,
                            column_type=column_type,
                            numeric_precision=numeric_precision,
                            numeric_scale=numeric_scale,
                            is_nullable=is_nullable,
                            is_constant=is_constant,
                            fk_schema=fk_schema,
                            fk_table=fk_table,
                            fk_column=fk_column)
        self.columns.append(new_column)
        return new_column

    def has_column(self, column_name: str) -> bool:
        '''Checks if a column exists in the table.'''
        for col in self.columns:
            if col.name == column_name:
                return True
        return False

    def __getitem__(self, column_name: str) -> Column:
        '''Gets a column from the table, creating it if it does not exist.'''
        existing = next((col for col in self.columns if col.name == column_name), None)
        if existing is not None:
            return existing

        created = Column(name=column_name)
        self.columns.append(created)
        return created

    def __repr__(self, level: int = 0) -> str:
        indent = '    ' * level

        column_block = '\n'.join(col.__repr__(level + 1) for col in self.columns)
        if len(self.unique_constraints) < 2:
            # Zero or one constraint: render inline.
            uc_block = ', '.join(uc.__repr__(0) for uc in self.unique_constraints)
        else:
            # Several constraints: one per line, indented one level deeper.
            uc_block = '\n' + '\n'.join(uc.__repr__(level + 1) for uc in self.unique_constraints) + '\n' + indent

        if len(self.columns) > 0:
            column_block = '\n' + column_block + '\n' + indent

        return f"{indent}Table(name='{self.name}', columns=[{column_block}], unique_constraints=[{uc_block}])"

    def to_dict(self) -> dict:
        '''JSON-friendly representation of the table.'''
        return {
            'name': self.name,
            'unique_constraints': [uc.to_dict() for uc in self.unique_constraints],
            'columns': [col.to_dict() for col in self.columns],
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Table':
        '''Rebuild a Table from `to_dict()` output.'''
        table = cls(name=data['name'])
        # Restore constraints first so column-level lookups work immediately.
        for uc_data in data.get('unique_constraints', []):
            table.unique_constraints.append(UniqueConstraint.from_dict(uc_data))
        for col_data in (data.get('columns') or []):
            table.columns.append(Column.from_dict(col_data))
        return table
161
+ # endregion
162
+
163
+ # region Schema
164
@dataclass
class Schema:
    '''A named database schema holding tables and known function names.'''

    name: str
    _tables: dict[str, Table] = field(default_factory=dict)
    functions: set[str] = field(default_factory=set)

    def __getitem__(self, table_name: str) -> Table:
        '''Gets a table from the schema, creating it if it does not exist.'''
        try:
            return self._tables[table_name]
        except KeyError:
            table = self._tables[table_name] = Table(table_name)
            return table

    def __setitem__(self, table_name: str, table: Table) -> None:
        '''Sets a table in the schema, replacing any existing table with the same name.'''
        self._tables[table_name] = table

    def has_table(self, table_name: str) -> bool:
        '''Checks if a table exists in the schema.'''
        return table_name in self._tables

    def has_column(self, table_name: str, column_name: str) -> bool:
        '''Checks if a column exists in the schema.'''
        return self.has_table(table_name) and self[table_name].has_column(column_name)

    @property
    def table_names(self) -> set[str]:
        '''Returns all table names in the schema.'''
        return set(self._tables)

    def __repr__(self, level: int = 0) -> str:
        indent = '    ' * level
        tables = '\n'.join(table.__repr__(level + 1) for table in self._tables.values())
        return f"{indent}Schema(name='{self.name}', tables=[\n{tables}\n{indent}])"

    def to_dict(self) -> dict:
        '''JSON-friendly representation (the `functions` set is not serialized).'''
        return {
            'name': self.name,
            'tables': {name: tbl.to_dict() for name, tbl in self._tables.items()},
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Schema':
        '''Rebuild a Schema from `to_dict()` output.'''
        schema = cls(name=data['name'])
        for tbl_data in (data.get('tables') or {}).values():
            tbl = Table.from_dict(tbl_data)
            schema._tables[tbl.name] = tbl
        return schema
213
+ # endregion
214
+
215
+ # region Catalog
216
@dataclass
class Catalog:
    '''Top-level container mapping schema names to Schema objects, with JSON (de)serialization.'''

    _schemas: dict[str, Schema] = field(default_factory=dict)

    def __getitem__(self, schema_name: str) -> Schema:
        '''Gets a schema from the catalog, creating it if it does not exist.'''
        try:
            return self._schemas[schema_name]
        except KeyError:
            schema = self._schemas[schema_name] = Schema(schema_name)
            return schema

    def __setitem__(self, schema_name: str, schema: Schema) -> Schema:
        '''Sets a schema in the catalog, replacing any existing schema with the same name.'''
        # NOTE: Python ignores return values in subscript-assignment syntax,
        # but the returned schema is kept for interface stability.
        self._schemas[schema_name] = schema
        return schema

    def has_schema(self, schema_name: str) -> bool:
        '''Checks if a schema exists in the catalog.'''
        return schema_name in self._schemas

    def copy_table(self, schema_name: str, table_name: str, table: Table) -> Table:
        '''Copies a table into the catalog, creating the schema if it does not exist.'''
        duplicate = deepcopy(table)
        self[schema_name][table_name] = duplicate
        return duplicate

    def has_table(self, schema_name: str, table_name: str) -> bool:
        '''
        Checks if a table exists in the specified schema in the catalog.

        Returns False if the schema or table do not exist.
        '''
        return self.has_schema(schema_name) and self[schema_name].has_table(table_name)

    def add_column(self, schema_name: str, table_name: str, column_name: str,
                   column_type: str, numeric_precision: int | None = None, numeric_scale: int | None = None,
                   is_nullable: bool = True,
                   fk_schema: str | None = None, fk_table: str | None = None, fk_column: str | None = None) -> None:
        '''Adds a column to the catalog, creating the schema and table if they do not exist.'''
        target = self[schema_name][table_name]
        target.add_column(name=column_name,
                          column_type=column_type,
                          numeric_precision=numeric_precision,
                          numeric_scale=numeric_scale,
                          is_nullable=is_nullable,
                          fk_schema=fk_schema,
                          fk_table=fk_table,
                          fk_column=fk_column)

    @property
    def schema_names(self) -> set[str]:
        '''Returns all schema names in the catalog.'''
        return set(self._schemas)

    @property
    def table_names(self) -> set[str]:
        '''Returns all table names in the catalog, regardless of schema.'''
        names: set[str] = set()
        for schema in self._schemas.values():
            names |= schema.table_names
        return names

    def copy(self) -> Self:
        '''Creates a deep copy of the catalog.'''
        return deepcopy(self)

    def __repr__(self) -> str:
        rendered = [schema.__repr__(1) for schema in self._schemas.values()]
        return 'Catalog(' + ''.join('\n' + part for part in rendered) + '\n)'

    def to_dict(self) -> dict:
        '''JSON-friendly representation, tagged with a format version.'''
        return {
            'version': 1,
            'schemas': {name: sch.to_dict() for name, sch in self._schemas.items()},
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Catalog':
        '''Rebuild a Catalog from `to_dict()` output.'''
        cat = cls()
        for sch_data in (data.get('schemas') or {}).values():
            sch = Schema.from_dict(sch_data)
            cat._schemas[sch.name] = sch
        return cat

    # String-based JSON (handy for DB/blob storage)
    def to_json(self, *, indent: int | None = 2) -> str:
        '''Serialize the catalog to a JSON string.'''
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_json(cls, s: str) -> 'Catalog':
        '''Deserialize a catalog from a JSON string.'''
        return cls.from_dict(json.loads(s))

    # Convenience file helpers
    def save_json(self, path: str, *, indent: int | None = 2) -> None:
        '''Write the catalog as JSON to the given file path.'''
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=indent)

    @classmethod
    def load_json(cls, path: str) -> 'Catalog':
        '''Read a catalog from a JSON file at the given path.'''
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)
@@ -0,0 +1,60 @@
1
def UNIQUE_COLUMNS(schema_name: str = '%') -> str:
    '''Build a query listing UNIQUE and PRIMARY KEY constraints with their columns.

    Args:
        schema_name: SQL LIKE pattern selecting the schemas to inspect;
            the default '%' matches every schema.

    Returns:
        SQL yielding rows of (schema_name, table_name, constraint_type, columns).
    '''
    # The pattern is embedded in a SQL string literal, so double any single
    # quotes: a quote in the schema name must not break out of the literal.
    pattern = schema_name.replace("'", "''")
    return f'''
    SELECT
        kcu.table_schema AS schema_name,
        kcu.table_name,
        tc.constraint_type,
        array_agg(kcu.column_name ORDER BY kcu.ordinal_position) AS columns
    FROM information_schema.table_constraints tc
    JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.constraint_schema = kcu.constraint_schema
    WHERE tc.constraint_type IN ('UNIQUE', 'PRIMARY KEY')
        AND kcu.table_schema LIKE '{pattern}'
    GROUP BY
        kcu.table_schema,
        kcu.table_name,
        kcu.constraint_name,
        tc.constraint_type;
    '''
20
+
21
def COLUMNS(schema_name: str = '%') -> str:
    '''Build a query listing columns (with type, nullability, and FK target).

    Args:
        schema_name: SQL LIKE pattern selecting the schemas to inspect;
            the default '%' matches every schema.

    Returns:
        SQL yielding rows of (schema_name, table_name, column_name, column_type,
        numeric_precision, numeric_scale, is_nullable,
        foreign_key_schema, foreign_key_table, foreign_key_column).
    '''
    # The pattern is embedded in a SQL string literal, so double any single
    # quotes: a quote in the schema name must not break out of the literal.
    pattern = schema_name.replace("'", "''")
    return f'''
    SELECT
        cols.table_schema AS schema_name,
        cols.table_name,
        cols.column_name,
        cols.data_type AS column_type,
        cols.numeric_precision,
        cols.numeric_scale,
        (cols.is_nullable = 'YES') AS is_nullable,
        fk.foreign_table_schema AS foreign_key_schema,
        fk.foreign_table_name AS foreign_key_table,
        fk.foreign_column_name AS foreign_key_column
    FROM information_schema.columns AS cols

    -- Foreign Key
    LEFT JOIN (
        SELECT
            kcu.table_schema,
            kcu.table_name,
            kcu.column_name,
            ccu.table_schema AS foreign_table_schema,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.constraint_schema = kcu.constraint_schema
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON tc.constraint_name = ccu.constraint_name
            AND tc.constraint_schema = ccu.constraint_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
    ) fk ON fk.table_schema = cols.table_schema
        AND fk.table_name = cols.table_name
        AND fk.column_name = cols.column_name

    WHERE cols.table_schema LIKE '{pattern}'
    '''
@@ -0,0 +1,88 @@
1
+ from .. import catalog
2
+ from ..query import Query
3
+ from ..sql_errors import SqlErrors
4
+ from .base import BaseDetector, DetectedError
5
+ from .syntax import SyntaxErrorDetector
6
+ from .semantic import SemanticErrorDetector
7
+ from .logical import LogicalErrorDetector
8
+ from .complications import ComplicationDetector
9
+
10
class Detector:
    '''Parses a SQL query once and runs a configurable set of error detectors over it.'''

    def __init__(self,
                 query: str,
                 *,
                 search_path: str = 'public',
                 solution_search_path: str = 'public',
                 solutions: list[str] | None = None,
                 catalog: catalog.Catalog | None = None,
                 detectors: list[type[BaseDetector]] | None = None,
                 debug: bool = False):
        '''
        Args:
            query: The SQL query to analyze.
            search_path: Schema search path for resolving names in the query.
            solution_search_path: Schema search path for the solution queries.
            solutions: Reference (correct) SQL queries, if any.
            catalog: Database catalog; a fresh, empty one is used when omitted.
            detectors: Detector classes to register immediately.
            debug: When True, print diagnostic output while running.
        '''
        # Avoid mutable default arguments: the previous `[]` / `Catalog()`
        # defaults shared one list / Catalog instance across every Detector.
        if solutions is None:
            solutions = []
        if catalog is None:
            # Local import: the `catalog` parameter shadows the module name here.
            from ..catalog import Catalog as _Catalog
            catalog = _Catalog()
        if detectors is None:
            detectors = []

        # Context data: they don't need to be parsed again if the query changes
        self.search_path = search_path
        self.solution_search_path = solution_search_path
        self.catalog = catalog
        self.solutions = [Query(sol, catalog=self.catalog, search_path=self.solution_search_path) for sol in solutions]
        self.detectors: list[BaseDetector] = []
        self.debug = debug

        self.set_query(query)

        # NOTE: Add detectors after setting the query to ensure they are correctly initialized
        for detector_cls in detectors:
            self.add_detector(detector_cls)

    def set_query(self, query: str, reason: str | None = None) -> None:
        '''Set a new query, re-parse it, and update all detectors. Doesn't affect detected errors.'''

        if self.debug:
            print('=' * 20)
            if reason:
                print(f'Updating query ({reason}):\n{query}')
            else:
                print(f'Updating query:\n{query}')
            print('=' * 20)

        self.query = Query(query, catalog=self.catalog, search_path=self.search_path)

        # Update all detectors with the new query and parse results
        for detector in self.detectors:
            detector.query = self.query
            detector.update_query = lambda new_query, reason=None: self.set_query(new_query, reason)

    def add_detector(self, detector_cls: type[BaseDetector]) -> None:
        '''Add a detector instance to the list of detectors'''

        # Make copies to avoid possible modifications during detection
        # TODO: check if it's needed
        detector = detector_cls(
            query=self.query,
            solutions=self.solutions,
            update_query=lambda new_query, reason=None: self.set_query(new_query, reason),
        )

        self.detectors.append(detector)

    def run(self) -> list[DetectedError]:
        '''
        Run all detectors and return a list of detected errors.
        This function can return duplicate errors, as well as additional information on the detected errors.
        '''

        if self.debug:
            print('===== Query =====')
            print(self.query.sql)

        results: list[DetectedError] = []

        for detector in self.detectors:
            errors = detector.run()

            if self.debug:
                print(f'===== Detected errors from {detector.__class__.__name__} =====')
                for error in errors:
                    print(error)

            results.extend(errors)

        return results
@@ -0,0 +1,39 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Callable
4
+
5
+ from ..sql_errors import SqlErrors
6
+ from ..query import Query
7
+ from ..catalog import Catalog
8
+
9
@dataclass(repr=False)
class DetectedError:
    '''A single detected SQL error plus optional supporting data.'''

    error: SqlErrors
    data: tuple[Any, ...] = field(default_factory=tuple)

    def __repr__(self):
        return f"DetectedError({self.error.value} - {self.error.name}: {self.data})"

    def __str__(self) -> str:
        # Error code padded to width 3, then the error name; data only if present.
        base = f'[{self.error.value:3}] {self.error.name}'
        return f'{base}: {self.data}' if self.data else base

    def __hash__(self) -> int:
        # Explicit __hash__: @dataclass with eq=True would otherwise set it to None.
        return hash((self.error, self.data))
24
+
25
class BaseDetector(ABC):
    '''Abstract base class for all error detectors.'''

    def __init__(self, *,
                 query: Query,
                 solutions: list[Query] | None = None,
                 update_query: Callable[[str, str | None], None],
                 ):
        '''
        Args:
            query: Parsed query under analysis.
            solutions: Parsed reference solutions; defaults to none.
            update_query: Callback replacing the query (new SQL, optional reason).
        '''
        self.query = query
        # Avoid a mutable default argument: `[]` would be shared across
        # every detector instance ever constructed.
        self.solutions = solutions if solutions is not None else []
        self.update_query = update_query

    @abstractmethod
    def run(self) -> list[DetectedError]:
        '''Run the detector and return a list of detected errors with their descriptions'''
        return []
39
+