sql-error-categorizer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_error_categorizer/__init__.py +56 -0
- sql_error_categorizer/catalog/__init__.py +73 -0
- sql_error_categorizer/catalog/catalog.py +328 -0
- sql_error_categorizer/catalog/queries.py +60 -0
- sql_error_categorizer/detectors/__init__.py +88 -0
- sql_error_categorizer/detectors/base.py +39 -0
- sql_error_categorizer/detectors/complications.py +393 -0
- sql_error_categorizer/detectors/logical.py +708 -0
- sql_error_categorizer/detectors/semantic.py +493 -0
- sql_error_categorizer/detectors/syntax.py +1278 -0
- sql_error_categorizer/query/__init__.py +4 -0
- sql_error_categorizer/query/extractors.py +134 -0
- sql_error_categorizer/query/query.py +98 -0
- sql_error_categorizer/query/set_operations/__init__.py +150 -0
- sql_error_categorizer/query/set_operations/binary_set_operation.py +89 -0
- sql_error_categorizer/query/set_operations/select.py +361 -0
- sql_error_categorizer/query/set_operations/set_operation.py +45 -0
- sql_error_categorizer/query/smt.py +206 -0
- sql_error_categorizer/query/tokenized_sql.py +68 -0
- sql_error_categorizer/query/typechecking.py +242 -0
- sql_error_categorizer/query/util.py +27 -0
- sql_error_categorizer/sql_errors.py +112 -0
- sql_error_categorizer/util.py +101 -0
- sql_error_categorizer-0.1.0.dist-info/METADATA +149 -0
- sql_error_categorizer-0.1.0.dist-info/RECORD +27 -0
- sql_error_categorizer-0.1.0.dist-info/WHEEL +4 -0
- sql_error_categorizer-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# Hidden, internal use only
|
|
2
|
+
from .detectors import BaseDetector as _BaseDetector, Detector as _Detector
|
|
3
|
+
|
|
4
|
+
# Public API
|
|
5
|
+
from .sql_errors import SqlErrors
|
|
6
|
+
from .catalog import Catalog, build_catalog, load_json as load_catalog
|
|
7
|
+
from .detectors import SyntaxErrorDetector, SemanticErrorDetector, LogicalErrorDetector, ComplicationDetector, DetectedError
|
|
8
|
+
|
|
9
|
+
def get_errors(query_str: str,
               solutions: list[str] | None = None,
               catalog: Catalog | None = None,
               search_path: str = 'public',
               solution_search_path: str = 'public',
               detectors: list[type[_BaseDetector]] | None = None,
               debug: bool = False) -> list[DetectedError]:
    '''Detect SQL errors in the given query string.

    Args:
        query_str: The SQL query to analyze.
        solutions: Reference (correct) queries to compare against. Defaults to none.
        catalog: Catalog describing known schemas/tables/columns. A fresh empty
            ``Catalog`` is created per call when omitted.
        search_path: Schema search path used when parsing ``query_str``.
        solution_search_path: Schema search path used for the solution queries.
        detectors: Detector classes to run; defaults to all four built-in detectors.
        debug: When True, print diagnostic output while detecting.

    Returns:
        All detected errors, in detector order (duplicates are possible).
    '''
    # NOTE: the previous defaults were mutable objects ([] / Catalog() / [...])
    # evaluated once at import time and shared across calls; in particular
    # Catalog.__getitem__ auto-creates schemas, so detectors could mutate the
    # shared default catalog between calls. None sentinels avoid that.
    if solutions is None:
        solutions = []
    if catalog is None:
        catalog = Catalog()
    if detectors is None:
        detectors = [
            SyntaxErrorDetector,
            SemanticErrorDetector,
            LogicalErrorDetector,
            ComplicationDetector,
        ]

    det = _Detector(query_str,
                    solutions=solutions,
                    catalog=catalog,
                    search_path=search_path,
                    solution_search_path=solution_search_path,
                    debug=debug)

    for detector in detectors:
        det.add_detector(detector)

    return det.run()
|
|
33
|
+
|
|
34
|
+
def get_error_types(query_str: str,
                    solutions: list[str] | None = None,
                    catalog: Catalog | None = None,
                    search_path: str = 'public',
                    solution_search_path: str = 'public',
                    detectors: list[type[_BaseDetector]] | None = None,
                    debug: bool = False) -> set[SqlErrors]:
    '''Detect SQL error types in the given query string.

    Same parameters as :func:`get_errors`, but the result is deduplicated to
    the set of distinct ``SqlErrors`` values (extra per-error data is dropped).
    '''
    # Resolve the None sentinels here rather than relying on get_errors'
    # defaults, so no mutable default object is ever shared across calls.
    if solutions is None:
        solutions = []
    if catalog is None:
        catalog = Catalog()
    if detectors is None:
        detectors = [
            SyntaxErrorDetector,
            SemanticErrorDetector,
            LogicalErrorDetector,
            ComplicationDetector,
        ]

    detected_errors = get_errors(query_str,
                                 solutions=solutions,
                                 catalog=catalog,
                                 search_path=search_path,
                                 solution_search_path=solution_search_path,
                                 detectors=detectors,
                                 debug=debug)

    return {error.error for error in detected_errors}
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
'''Builds a catalog of existing schemas/tables/columns by executing the provided SQL string in a temporary PostgreSQL database.'''
|
|
2
|
+
|
|
3
|
+
from .catalog import Catalog, Schema, Table, Column, UniqueConstraintType
|
|
4
|
+
import psycopg2
|
|
5
|
+
import time
|
|
6
|
+
from . import queries
|
|
7
|
+
|
|
8
|
+
def build_catalog(sql_string: str, *, hostname: str, port: int, user: str, password: str, use_temp_schema: bool = True) -> Catalog:
    '''Builds a catalog by executing the provided SQL string in a temporary PostgreSQL database.

    Args:
        sql_string: DDL to execute (e.g. CREATE TABLE statements). An empty or
            whitespace-only string yields an empty catalog without connecting.
        hostname / port / user / password: PostgreSQL connection parameters.
        use_temp_schema: When True, run the DDL inside a uniquely-named scratch
            schema; when False, introspect every schema ('%' LIKE pattern).

    Returns:
        A Catalog populated with the columns and unique constraints found.

    Raises:
        ValueError: If an unexpected constraint type is returned by the database.

    Nothing is persisted: the transaction is rolled back before returning, which
    also undoes the scratch schema and any tables created from ``sql_string``.
    '''
    # Changes from the previous revision:
    # - removed a leftover debug dependency on the non-stdlib `dav_tools` package;
    # - the scratch schema name was stored in `schema_name`, which the row-unpacking
    #   loops below shadowed, so the final DROP targeted the wrong schema — the
    #   scratch name now lives in `work_schema` and cleanup relies on rollback;
    # - the connection is now always closed (try/finally).
    result = Catalog()

    if sql_string.strip() == '':
        return result

    conn = psycopg2.connect(host=hostname, port=port, user=user, password=password)
    try:
        cur = conn.cursor()

        # Use a scratch schema with a unique name so concurrent runs don't collide
        if use_temp_schema:
            work_schema = f'sql_error_categorizer_{time.time_ns()}'
            cur.execute(f'CREATE schema {work_schema};')
            cur.execute(f'SET search_path TO {work_schema};')
        else:
            work_schema = '%'  # LIKE wildcard: match all schemas. TODO: it's a bit hackish, find a more stable solution

        # Create the tables
        cur.execute(sql_string)

        # Fetch the catalog information
        cur.execute(queries.COLUMNS(work_schema))
        for row in cur.fetchall():
            (schema_name, table_name, column_name, column_type,
             numeric_precision, numeric_scale, is_nullable,
             fk_schema, fk_table, fk_column) = row

            result.add_column(schema_name, table_name, column_name,
                              column_type, numeric_precision, numeric_scale,
                              is_nullable,
                              fk_schema, fk_table, fk_column)

        # Fetch unique constraints (including primary keys)
        cur.execute(queries.UNIQUE_COLUMNS(work_schema))
        for schema_name, table_name, constraint_type, columns in cur.fetchall():
            # NOTE(review): `columns` is treated as the textual form '{col1,col2,...}';
            # confirm the driver doesn't already return a Python list here.
            column_set = set(columns.strip('{}').split(','))

            if constraint_type == 'PRIMARY KEY':
                ctype = UniqueConstraintType.PRIMARY_KEY
            elif constraint_type == 'UNIQUE':
                ctype = UniqueConstraintType.UNIQUE
            else:
                raise ValueError(f'Unknown constraint type: {constraint_type}')

            result[schema_name][table_name].add_unique_constraint(column_set, ctype)
    finally:
        # Rollback (not commit): discards the scratch schema and everything
        # created from sql_string, then release the connection.
        conn.rollback()
        conn.close()

    return result
|
|
69
|
+
|
|
70
|
+
def load_json(path: str) -> Catalog:
    '''Load a previously saved catalog from the JSON file at ``path``.'''
    return Catalog.load_json(path)
|
|
73
|
+
|
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
import json
|
|
3
|
+
from typing import Self
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
|
|
7
|
+
# region UniqueConstraint
|
|
8
|
+
class UniqueConstraintType(Enum):
    '''Kinds of uniqueness constraints a table can declare.'''
    PRIMARY_KEY = 'PRIMARY KEY'
    UNIQUE = 'UNIQUE'

class UniqueConstraint:
    '''A uniqueness constraint (PRIMARY KEY or UNIQUE) over a set of column names.'''

    def __init__(self, columns: set[str], constraint_type: UniqueConstraintType) -> None:
        self.columns = columns
        self.constraint_type = constraint_type

    def __repr__(self, level: int = 0) -> str:
        pad = '  ' * level
        return f'{pad}UniqueConstraint({self.constraint_type.value}: {self.columns})'

    def to_dict(self) -> dict:
        '''JSON-friendly representation: columns as a list, constraint type as its string value.'''
        return {
            'columns': list(self.columns),
            'constraint_type': self.constraint_type.value,
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'UniqueConstraint':
        '''Inverse of to_dict; column names are normalized to lowercase.'''
        lowered = {name.lower() for name in data['columns']}
        return cls(columns=lowered,
                   constraint_type=UniqueConstraintType(data['constraint_type']))
|
|
31
|
+
# endregion
|
|
32
|
+
|
|
33
|
+
# region Column
|
|
34
|
+
@dataclass
|
|
35
|
+
class Column:
|
|
36
|
+
name: str
|
|
37
|
+
column_type: str = 'UNKNOWN'
|
|
38
|
+
numeric_precision: int | None = None
|
|
39
|
+
numeric_scale: int | None = None
|
|
40
|
+
is_nullable: bool = True
|
|
41
|
+
is_constant: bool = False
|
|
42
|
+
fk_schema: str | None = None
|
|
43
|
+
fk_table: str | None = None
|
|
44
|
+
fk_column: str | None = None
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
def is_fk(self) -> bool:
|
|
48
|
+
return all([self.fk_schema, self.fk_table, self.fk_column])
|
|
49
|
+
|
|
50
|
+
def __repr__(self, level: int = 0) -> str:
|
|
51
|
+
indent = ' ' * level
|
|
52
|
+
return f'{indent}Column(name=\'{self.name}\', type=\'{self.column_type}\', is_fk={self.is_fk}, is_nullable={self.is_nullable}, is_constant={self.is_constant})'
|
|
53
|
+
|
|
54
|
+
def to_dict(self) -> dict:
|
|
55
|
+
return {
|
|
56
|
+
'name': self.name,
|
|
57
|
+
'column_type': self.column_type,
|
|
58
|
+
'numeric_precision': self.numeric_precision,
|
|
59
|
+
'numeric_scale': self.numeric_scale,
|
|
60
|
+
'is_nullable': self.is_nullable,
|
|
61
|
+
'fk_schema': self.fk_schema,
|
|
62
|
+
'fk_table': self.fk_table,
|
|
63
|
+
'fk_column': self.fk_column,
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
@classmethod
|
|
67
|
+
def from_dict(cls, data: dict) -> 'Column':
|
|
68
|
+
return cls(
|
|
69
|
+
name=data['name'],
|
|
70
|
+
column_type=data['column_type'],
|
|
71
|
+
numeric_precision=data.get('numeric_precision'),
|
|
72
|
+
numeric_scale=data.get('numeric_scale'),
|
|
73
|
+
is_nullable=data.get('is_nullable', True),
|
|
74
|
+
fk_schema=(data.get('fk_schema') or None),
|
|
75
|
+
fk_table=(data.get('fk_table') or None),
|
|
76
|
+
fk_column=(data.get('fk_column') or None),
|
|
77
|
+
)
|
|
78
|
+
# endregion
|
|
79
|
+
|
|
80
|
+
# region Table
|
|
81
|
+
@dataclass
class Table:
    '''A database table, with columns and unique constraints. Supports multiple columns with the same name (e.g. from joins).'''
    name: str
    unique_constraints: list[UniqueConstraint] = field(default_factory=list)
    columns: list[Column] = field(default_factory=list)

    def add_unique_constraint(self, columns: set[str], constraint_type: UniqueConstraintType) -> None:
        '''Record a new uniqueness constraint over the given column names.'''
        self.unique_constraints.append(UniqueConstraint(columns, constraint_type))

    def add_column(self,
                   name: str,
                   column_type: str,
                   numeric_precision: int | None = None,
                   numeric_scale: int | None = None,
                   is_nullable: bool = True,
                   is_constant: bool = False,
                   fk_schema: str | None = None,
                   fk_table: str | None = None,
                   fk_column: str | None = None) -> Column:
        '''Append a new column (duplicate names are allowed) and return it.'''
        new_column = Column(name=name,
                            column_type=column_type,
                            numeric_precision=numeric_precision,
                            numeric_scale=numeric_scale,
                            is_nullable=is_nullable,
                            is_constant=is_constant,
                            fk_schema=fk_schema,
                            fk_table=fk_table,
                            fk_column=fk_column)
        self.columns.append(new_column)
        return new_column

    def has_column(self, column_name: str) -> bool:
        '''Checks if a column exists in the table.'''
        return column_name in (col.name for col in self.columns)

    def __getitem__(self, column_name: str) -> Column:
        '''Gets a column from the table, creating it if it does not exist.'''
        existing = next((col for col in self.columns if col.name == column_name), None)
        if existing is not None:
            return existing
        created = Column(name=column_name)
        self.columns.append(created)
        return created

    def __repr__(self, level: int = 0) -> str:
        pad = '  ' * level

        col_block = '\n'.join(col.__repr__(level + 1) for col in self.columns)
        if col_block:
            col_block = '\n' + col_block + '\n' + pad

        # One constraint fits inline; several are laid out one per line.
        if len(self.unique_constraints) < 2:
            uc_block = ', '.join(uc.__repr__(0) for uc in self.unique_constraints)
        else:
            uc_block = ('\n'
                        + '\n'.join(uc.__repr__(level + 1) for uc in self.unique_constraints)
                        + '\n' + pad)

        return f'{pad}Table(name=\'{self.name}\', columns=[{col_block}], unique_constraints=[{uc_block}])'

    def to_dict(self) -> dict:
        return {
            'name': self.name,
            'unique_constraints': [uc.to_dict() for uc in self.unique_constraints],
            'columns': [col.to_dict() for col in self.columns],
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Table':
        table = cls(name=data['name'])
        # Unique constraints first, then columns.
        for uc_payload in data.get('unique_constraints', []):
            table.unique_constraints.append(UniqueConstraint.from_dict(uc_payload))
        for col_payload in (data.get('columns') or []):
            table.columns.append(Column.from_dict(col_payload))
        return table
|
|
161
|
+
# endregion
|
|
162
|
+
|
|
163
|
+
# region Schema
|
|
164
|
+
@dataclass
class Schema:
    '''A named database schema: a mapping of table name -> Table, plus known function names.'''
    name: str
    _tables: dict[str, Table] = field(default_factory=dict)
    functions: set[str] = field(default_factory=set)

    def __getitem__(self, table_name: str) -> Table:
        '''Gets a table from the schema, creating it if it does not exist.'''
        return self._tables.setdefault(table_name, Table(table_name))

    def __setitem__(self, table_name: str, table: Table) -> None:
        '''Sets a table in the schema, replacing any existing table with the same name.'''
        self._tables[table_name] = table

    def has_table(self, table_name: str) -> bool:
        '''Checks if a table exists in the schema.'''
        return table_name in self._tables

    def has_column(self, table_name: str, column_name: str) -> bool:
        '''Checks if a column exists in the schema (False when the table is absent).'''
        return self.has_table(table_name) and self[table_name].has_column(column_name)

    @property
    def table_names(self) -> set[str]:
        '''Returns all table names in the schema.'''
        return set(self._tables)

    def __repr__(self, level: int = 0) -> str:
        pad = '  ' * level
        body = '\n'.join(tbl.__repr__(level + 1) for tbl in self._tables.values())
        return f'{pad}Schema(name=\'{self.name}\', tables=[\n{body}\n{pad}])'

    def to_dict(self) -> dict:
        return {
            'name': self.name,
            'tables': {name: tbl.to_dict() for name, tbl in self._tables.items()},
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Schema':
        schema = cls(name=data['name'])
        for payload in (data.get('tables') or {}).values():
            restored = Table.from_dict(payload)
            schema._tables[restored.name] = restored
        return schema
|
|
213
|
+
# endregion
|
|
214
|
+
|
|
215
|
+
# region Catalog
|
|
216
|
+
@dataclass
class Catalog:
    '''Top-level container mapping schema name -> Schema, with JSON (de)serialization helpers.'''
    _schemas: dict[str, Schema] = field(default_factory=dict)

    def __getitem__(self, schema_name: str) -> Schema:
        '''Gets a schema from the catalog, creating it if it does not exist.'''
        return self._schemas.setdefault(schema_name, Schema(schema_name))

    def __setitem__(self, schema_name: str, schema: Schema) -> Schema:
        '''Sets a schema in the catalog, replacing any existing schema with the same name.'''
        self._schemas[schema_name] = schema
        return schema

    def has_schema(self, schema_name: str) -> bool:
        '''Checks if a schema exists in the catalog.'''
        return schema_name in self._schemas

    def copy_table(self, schema_name: str, table_name: str, table: Table) -> Table:
        '''Deep-copies a table into the catalog (creating the schema if needed) and returns the copy.'''
        duplicate = deepcopy(table)
        self[schema_name][table_name] = duplicate
        return duplicate

    def has_table(self, schema_name: str, table_name: str) -> bool:
        '''
        Checks if a table exists in the specified schema in the catalog.

        Returns False if the schema or table do not exist.
        '''
        return self.has_schema(schema_name) and self[schema_name].has_table(table_name)

    def add_column(self, schema_name: str, table_name: str, column_name: str,
                   column_type: str, numeric_precision: int | None = None, numeric_scale: int | None = None,
                   is_nullable: bool = True,
                   fk_schema: str | None = None, fk_table: str | None = None, fk_column: str | None = None) -> None:
        '''Adds a column to the catalog, creating the schema and table if they do not exist.'''
        target = self[schema_name][table_name]
        target.add_column(name=column_name,
                          column_type=column_type,
                          numeric_precision=numeric_precision,
                          numeric_scale=numeric_scale,
                          is_nullable=is_nullable,
                          fk_schema=fk_schema,
                          fk_table=fk_table,
                          fk_column=fk_column)

    @property
    def schema_names(self) -> set[str]:
        '''Returns all schema names in the catalog.'''
        return set(self._schemas)

    @property
    def table_names(self) -> set[str]:
        '''Returns all table names in the catalog, regardless of schema.'''
        names: set[str] = set()
        for schema in self._schemas.values():
            names |= schema.table_names
        return names

    def copy(self) -> Self:
        '''Creates a deep copy of the catalog.'''
        return deepcopy(self)

    def __repr__(self) -> str:
        body = ''.join('\n' + schema.__repr__(1) for schema in self._schemas.values())
        return 'Catalog(' + body + '\n)'

    def to_dict(self) -> dict:
        # 'version' allows the on-disk format to evolve later.
        return {
            'version': 1,
            'schemas': {name: sch.to_dict() for name, sch in self._schemas.items()},
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Catalog':
        restored = cls()
        for payload in (data.get('schemas') or {}).values():
            schema = Schema.from_dict(payload)
            restored._schemas[schema.name] = schema
        return restored

    # String-based JSON (handy for DB/blob storage)
    def to_json(self, *, indent: int | None = 2) -> str:
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_json(cls, s: str) -> 'Catalog':
        return cls.from_dict(json.loads(s))

    # Convenience file helpers
    def save_json(self, path: str, *, indent: int | None = 2) -> None:
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=indent)

    @classmethod
    def load_json(cls, path: str) -> 'Catalog':
        with open(path, 'r', encoding='utf-8') as f:
            return cls.from_dict(json.load(f))
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
def UNIQUE_COLUMNS(schema_name: str = '%') -> str:
    '''Return SQL listing UNIQUE / PRIMARY KEY constraints with their columns.

    Args:
        schema_name: LIKE pattern for the schemas to inspect ('%' = all schemas).

    Returns:
        A SQL string producing rows (schema_name, table_name, constraint_type, columns).
    '''
    # Double any single quotes so an embedded quote in schema_name cannot
    # break out of the SQL string literal (information_schema names cannot
    # be bound as parameters here since callers pass the query text around).
    safe_schema = schema_name.replace("'", "''")
    return f'''
    SELECT
        kcu.table_schema AS schema_name,
        kcu.table_name,
        tc.constraint_type,
        array_agg(kcu.column_name ORDER BY kcu.ordinal_position) AS columns
    FROM information_schema.table_constraints tc
    JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.constraint_schema = kcu.constraint_schema
    WHERE tc.constraint_type IN ('UNIQUE', 'PRIMARY KEY')
        AND kcu.table_schema LIKE '{safe_schema}'
    GROUP BY
        kcu.table_schema,
        kcu.table_name,
        kcu.constraint_name,
        tc.constraint_type;
    '''
|
|
20
|
+
|
|
21
|
+
def COLUMNS(schema_name: str = '%') -> str:
    '''Return SQL listing all columns with type, nullability and foreign-key target.

    Args:
        schema_name: LIKE pattern for the schemas to inspect ('%' = all schemas).

    Returns:
        A SQL string producing one row per column:
        (schema, table, column, type, precision, scale, is_nullable,
         fk_schema, fk_table, fk_column) — the fk_* fields are NULL for
        columns that are not foreign keys.
    '''
    # Double any single quotes so an embedded quote in schema_name cannot
    # break out of the SQL string literal.
    safe_schema = schema_name.replace("'", "''")
    return f'''
    SELECT
        cols.table_schema AS schema_name,
        cols.table_name,
        cols.column_name,
        cols.data_type AS column_type,
        cols.numeric_precision,
        cols.numeric_scale,
        (cols.is_nullable = 'YES') AS is_nullable,
        fk.foreign_table_schema AS foreign_key_schema,
        fk.foreign_table_name AS foreign_key_table,
        fk.foreign_column_name AS foreign_key_column
    FROM information_schema.columns AS cols

    -- Foreign Key
    LEFT JOIN (
        SELECT
            kcu.table_schema,
            kcu.table_name,
            kcu.column_name,
            ccu.table_schema AS foreign_table_schema,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.constraint_schema = kcu.constraint_schema
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON tc.constraint_name = ccu.constraint_name
            AND tc.constraint_schema = ccu.constraint_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
    ) fk ON fk.table_schema = cols.table_schema
        AND fk.table_name = cols.table_name
        AND fk.column_name = cols.column_name

    WHERE cols.table_schema LIKE '{safe_schema}'
    '''
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from .. import catalog
|
|
2
|
+
from ..query import Query
|
|
3
|
+
from ..sql_errors import SqlErrors
|
|
4
|
+
from .base import BaseDetector, DetectedError
|
|
5
|
+
from .syntax import SyntaxErrorDetector
|
|
6
|
+
from .semantic import SemanticErrorDetector
|
|
7
|
+
from .logical import LogicalErrorDetector
|
|
8
|
+
from .complications import ComplicationDetector
|
|
9
|
+
|
|
10
|
+
class Detector:
    '''Orchestrates a set of error detectors over a single SQL query.

    Holds the parsed query, parsed reference solutions and the catalog, keeps
    every registered detector in sync when the query text is replaced, and
    aggregates their results via :meth:`run`.
    '''

    def __init__(self,
                 query: str,
                 *,
                 search_path: str = 'public',
                 solution_search_path: str = 'public',
                 solutions: list[str] | None = None,
                 catalog: catalog.Catalog | None = None,
                 detectors: list[type[BaseDetector]] | None = None,
                 debug: bool = False):
        '''
        Args:
            query: SQL text to analyze.
            search_path: schema search path for the query.
            solution_search_path: schema search path for the solutions.
            solutions: reference (correct) SQL queries.
            catalog: known schemas/tables/columns; a fresh Catalog per instance when omitted.
            detectors: detector classes to register immediately.
            debug: print diagnostics on query updates and runs.
        '''
        # NOTE: the previous defaults were mutable objects ([] and a single
        # module-level catalog.Catalog() shared by every Detector); since
        # Catalog.__getitem__ auto-creates schemas, that shared instance could
        # be mutated across unrelated Detector instances. None sentinels fix it.
        if solutions is None:
            solutions = []
        if detectors is None:
            detectors = []
        if catalog is None:
            # Local import: the parameter name shadows the `catalog` module here.
            from ..catalog import Catalog as _Catalog
            catalog = _Catalog()

        # Context data: they don't need to be parsed again if the query changes
        self.search_path = search_path
        self.solution_search_path = solution_search_path
        self.catalog = catalog
        self.solutions = [Query(sol, catalog=self.catalog, search_path=self.solution_search_path) for sol in solutions]
        self.detectors: list[BaseDetector] = []
        self.debug = debug

        self.set_query(query)

        # NOTE: Add detectors after setting the query to ensure they are correctly initialized
        for detector_cls in detectors:
            self.add_detector(detector_cls)

    def set_query(self, query: str, reason: str | None = None) -> None:
        '''Set a new query, re-parse it, and update all detectors. Doesn't affect detected errors.'''
        if self.debug:
            print('=' * 20)
            if reason:
                print(f'Updating query ({reason}):\n{query}')
            else:
                print(f'Updating query:\n{query}')
            print('=' * 20)

        self.query = Query(query, catalog=self.catalog, search_path=self.search_path)

        # Update all detectors with the new query and parse results
        for detector in self.detectors:
            detector.query = self.query
            detector.update_query = lambda new_query, reason=None: self.set_query(new_query, reason)

    def add_detector(self, detector_cls: type[BaseDetector]) -> None:
        '''Instantiate detector_cls bound to the current query and register it.'''
        # Make copies to avoid possible modifications during detection
        # TODO: check if it's needed
        detector = detector_cls(
            query=self.query,
            solutions=self.solutions,
            update_query=lambda new_query, reason=None: self.set_query(new_query, reason),
        )

        self.detectors.append(detector)

    def run(self) -> list[DetectedError]:
        '''
        Run all detectors and return a list of detected errors.
        This function can return duplicate errors, as well as additional information on the detected errors.
        '''
        if self.debug:
            print('===== Query =====')
            print(self.query.sql)

        results: list[DetectedError] = []

        for detector in self.detectors:
            errors = detector.run()

            if self.debug:
                print(f'===== Detected errors from {detector.__class__.__name__} =====')
                for error in errors:
                    print(error)

            results.extend(errors)

        return results
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
from ..sql_errors import SqlErrors
|
|
6
|
+
from ..query import Query
|
|
7
|
+
from ..catalog import Catalog
|
|
8
|
+
|
|
9
|
+
@dataclass(repr=False)
class DetectedError:
    '''A single detected error: the error kind plus optional extra context data.'''
    error: SqlErrors
    data: tuple[Any, ...] = field(default_factory=tuple)

    def __repr__(self):
        return f"DetectedError({self.error.value} - {self.error.name}: {self.data})"

    def __str__(self) -> str:
        base = f'[{self.error.value:3}] {self.error.name}'
        # Only show the payload when there is one.
        return f'{base}: {self.data}' if self.data else base

    def __hash__(self) -> int:
        # NOTE(review): assumes every element of `data` is hashable — verify callers.
        return hash((self.error, self.data))
|
|
24
|
+
|
|
25
|
+
class BaseDetector(ABC):
    '''Abstract base for all error detectors.

    Subclasses implement :meth:`run`; the orchestrating Detector may rebind
    ``query`` and ``update_query`` when the query text changes.
    '''

    def __init__(self, *,
                 query: Query,
                 solutions: list[Query] | None = None,
                 update_query: Callable[[str, str | None], None],
                 ):
        '''
        Args:
            query: the parsed query under analysis.
            solutions: parsed reference solutions (defaults to none).
            update_query: callback (new_query, reason) to replace the query text.
        '''
        self.query = query
        # `solutions` previously defaulted to a mutable [] shared by every
        # detector instance; a None sentinel gives each instance its own list.
        self.solutions = [] if solutions is None else solutions
        self.update_query = update_query

    @abstractmethod
    def run(self) -> list[DetectedError]:
        '''Run the detector and return a list of detected errors with their descriptions'''
        return []
|
|
39
|
+
|