sqlscope 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlscope/__init__.py +4 -0
- sqlscope/catalog/__init__.py +12 -0
- sqlscope/catalog/builder/__init__.py +8 -0
- sqlscope/catalog/builder/postgres.py +207 -0
- sqlscope/catalog/builder/sql.py +219 -0
- sqlscope/catalog/catalog.py +147 -0
- sqlscope/catalog/column.py +68 -0
- sqlscope/catalog/constraint.py +83 -0
- sqlscope/catalog/schema.py +60 -0
- sqlscope/catalog/table.py +112 -0
- sqlscope/query/__init__.py +5 -0
- sqlscope/query/extractors.py +118 -0
- sqlscope/query/query.py +191 -0
- sqlscope/query/set_operations/__init__.py +181 -0
- sqlscope/query/set_operations/binary_set_operation.py +162 -0
- sqlscope/query/set_operations/select.py +664 -0
- sqlscope/query/set_operations/set_operation.py +59 -0
- sqlscope/query/smt.py +334 -0
- sqlscope/query/tokenized_sql.py +70 -0
- sqlscope/query/typechecking/__init__.py +21 -0
- sqlscope/query/typechecking/base.py +9 -0
- sqlscope/query/typechecking/binary_ops.py +57 -0
- sqlscope/query/typechecking/functions.py +81 -0
- sqlscope/query/typechecking/predicates.py +123 -0
- sqlscope/query/typechecking/primitives.py +80 -0
- sqlscope/query/typechecking/queries.py +35 -0
- sqlscope/query/typechecking/types.py +59 -0
- sqlscope/query/typechecking/unary_ops.py +51 -0
- sqlscope/query/typechecking/util.py +51 -0
- sqlscope/util/__init__.py +18 -0
- sqlscope/util/ast/__init__.py +55 -0
- sqlscope/util/ast/column.py +55 -0
- sqlscope/util/ast/function.py +10 -0
- sqlscope/util/ast/subquery.py +23 -0
- sqlscope/util/ast/table.py +36 -0
- sqlscope/util/sql.py +27 -0
- sqlscope/util/tokens.py +17 -0
- sqlscope-1.0.0.dist-info/METADATA +52 -0
- sqlscope-1.0.0.dist-info/RECORD +41 -0
- sqlscope-1.0.0.dist-info/WHEEL +4 -0
- sqlscope-1.0.0.dist-info/licenses/LICENSE +21 -0
sqlscope/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
'''Represents a catalog of database schemas, tables, and columns.'''
|
|
2
|
+
|
|
3
|
+
# API exports
|
|
4
|
+
from .constraint import Constraint, ConstraintColumn, ConstraintType
|
|
5
|
+
from .column import Column
|
|
6
|
+
from .table import Table
|
|
7
|
+
from .schema import Schema
|
|
8
|
+
from .catalog import Catalog
|
|
9
|
+
from .builder import CatalogColumnInfo, CatalogUniqueConstraintInfo, build_catalog, build_catalog_from_postgres, load_catalog, build_catalog_from_sql
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from ..catalog import Catalog
|
|
2
|
+
from .postgres import build_catalog, build_catalog_from_postgres, CatalogColumnInfo, CatalogUniqueConstraintInfo
|
|
3
|
+
from .sql import build_catalog_from_sql
|
|
4
|
+
|
|
5
|
+
def load_catalog(path: str) -> Catalog:
    '''Loads a catalog from a JSON file.

    Thin convenience wrapper around `Catalog.load_json`.
    '''
    catalog = Catalog.load_json(path)
    return catalog
|
|
8
|
+
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from ..catalog import Catalog
|
|
2
|
+
from ..constraint import ConstraintType
|
|
3
|
+
|
|
4
|
+
import psycopg2
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
# region Data Classes
|
|
9
|
+
@dataclass(frozen=True)
|
|
10
|
+
class CatalogColumnInfo:
|
|
11
|
+
'''Holds information about a database column.'''
|
|
12
|
+
schema_name: str
|
|
13
|
+
table_name: str
|
|
14
|
+
column_name: str
|
|
15
|
+
column_type: str
|
|
16
|
+
numeric_precision: int | None
|
|
17
|
+
numeric_scale: int | None
|
|
18
|
+
is_nullable: bool
|
|
19
|
+
foreign_key_schema: str | None
|
|
20
|
+
foreign_key_table: str | None
|
|
21
|
+
foreign_key_column: str | None
|
|
22
|
+
|
|
23
|
+
def to_dict(self) -> dict:
|
|
24
|
+
return {
|
|
25
|
+
'schema_name': self.schema_name,
|
|
26
|
+
'table_name': self.table_name,
|
|
27
|
+
'column_name': self.column_name,
|
|
28
|
+
'column_type': self.column_type,
|
|
29
|
+
'numeric_precision': self.numeric_precision,
|
|
30
|
+
'numeric_scale': self.numeric_scale,
|
|
31
|
+
'is_nullable': self.is_nullable,
|
|
32
|
+
'foreign_key_schema': self.foreign_key_schema,
|
|
33
|
+
'foreign_key_table': self.foreign_key_table,
|
|
34
|
+
'foreign_key_column': self.foreign_key_column,
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True)
class CatalogUniqueConstraintInfo:
    '''Immutable record describing a unique constraint or primary key.'''
    schema_name: str
    table_name: str
    constraint_type: str
    columns: str  # Postgres returns this as a string like '{col1,col2,...}'

    def to_dict(self) -> dict:
        '''Returns a plain-dict representation of this record.'''
        keys = ('schema_name', 'table_name', 'constraint_type', 'columns')
        return {key: getattr(self, key) for key in keys}
|
|
52
|
+
# endregion
|
|
53
|
+
|
|
54
|
+
# region Catalog Builder
|
|
55
|
+
def build_catalog(columns_info: list[CatalogColumnInfo], unique_constraints_info: list[CatalogUniqueConstraintInfo]) -> Catalog:
    '''Builds a catalog from the provided column and unique constraint information.'''
    catalog = Catalog()

    # Register every column; schemas and tables are created on demand.
    for info in columns_info:
        catalog.add_column(
            schema_name=info.schema_name,
            table_name=info.table_name,
            column_name=info.column_name,
            column_type=info.column_type,
            numeric_precision=info.numeric_precision,
            numeric_scale=info.numeric_scale,
            is_nullable=info.is_nullable,
            fk_schema=info.foreign_key_schema,
            fk_table=info.foreign_key_table,
            fk_column=info.foreign_key_column,
        )

    for uc in unique_constraints_info:
        # Postgres reports the column list as a literal '{col1,col2,...}' string.
        member_columns = set(uc.columns.strip('{}').split(','))
        kind = (
            ConstraintType.PRIMARY_KEY
            if uc.constraint_type == 'PRIMARY KEY'
            else ConstraintType.UNIQUE
        )
        catalog[uc.schema_name][uc.table_name].add_unique_constraint(member_columns, constraint_type=kind)

    return catalog
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def build_catalog_from_postgres(sql_string: str, *, hostname: str, port: int, user: str, password: str, schema: str | None = None, create_temp_schema: bool = False) -> Catalog:
    '''Builds a catalog by executing the provided SQL string in a temporary PostgreSQL database.

    The DDL runs inside a transaction that is always rolled back, so the
    target database is left unchanged (modulo anything the DDL itself
    commits).

    Args:
        sql_string: DDL (e.g. CREATE TABLE statements) to execute and introspect.
        hostname, port, user, password: PostgreSQL connection parameters.
        schema: schema to introspect; None means all schemas ('%' LIKE pattern),
            or a generated temporary name when `create_temp_schema` is True.
        create_temp_schema: when True, create a dedicated schema for the run
            and put it on the search path before executing `sql_string`.

    Returns:
        A Catalog built from the introspected columns and unique constraints.
    '''
    if sql_string.strip() == '':
        return Catalog()

    conn = psycopg2.connect(host=hostname, port=port, user=user, password=password)
    try:
        cur = conn.cursor()

        # Use a temporary schema with a fixed name
        if create_temp_schema:
            if schema is None:
                schema_name = f'sql_error_categorizer_{time.time_ns()}'
            else:
                schema_name = schema
            # NOTE(review): schema_name is interpolated directly into the DDL.
            # Identifiers cannot be bound as query parameters, so callers must
            # not pass untrusted input as `schema`.
            cur.execute(f'CREATE schema {schema_name};')
            cur.execute(f'SET search_path TO {schema_name};')
        else:
            schema_name = '%' if schema is None else schema  # '%' matches any schema via LIKE

        # Create the tables
        cur.execute(sql_string)

        # Fetch the catalog information
        cur.execute(COLUMNS(schema_name))
        columns_data = [
            CatalogColumnInfo(
                schema_name=row[0],
                table_name=row[1],
                column_name=row[2],
                column_type=row[3],
                numeric_precision=row[4],
                numeric_scale=row[5],
                is_nullable=row[6],
                foreign_key_schema=row[7],
                foreign_key_table=row[8],
                foreign_key_column=row[9],
            )
            for row in cur.fetchall()
        ]

        # Fetch unique constraints (including primary keys)
        cur.execute(UNIQUE_COLUMNS(schema_name))
        unique_constraints_data = [
            CatalogUniqueConstraintInfo(
                schema_name=row[0],
                table_name=row[1],
                constraint_type=row[2],
                columns=row[3],
            )
            for row in cur.fetchall()
        ]

        # Clean up the temporary schema explicitly (covers the case where
        # `sql_string` committed mid-way, which rollback alone would not undo).
        if create_temp_schema:
            cur.execute(f'DROP schema {schema_name} CASCADE;')
    finally:
        # BUG FIX: the connection was previously leaked (and left in an open
        # transaction) whenever any execute raised. Always discard the work
        # and release the connection.
        conn.rollback()  # no need to save anything
        conn.close()

    return build_catalog(columns_data, unique_constraints_data)
|
|
145
|
+
|
|
146
|
+
# region SQL Queries
|
|
147
|
+
def UNIQUE_COLUMNS(schema_name: str = '%') -> str:
    '''Returns a SQL query listing UNIQUE and PRIMARY KEY constraints.

    One result row per constraint: (schema_name, table_name, constraint_type,
    columns), where `columns` is a Postgres array of the constrained column
    names in ordinal order. `schema_name` is matched with LIKE, so the
    default '%' covers all schemas.

    NOTE(review): `schema_name` is interpolated into the SQL text rather than
    bound as a parameter — do not pass untrusted input.
    '''
    return f'''
    SELECT
        kcu.table_schema AS schema_name,
        kcu.table_name,
        tc.constraint_type,
        array_agg(kcu.column_name ORDER BY kcu.ordinal_position) AS columns
    FROM information_schema.table_constraints tc
    JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.constraint_schema = kcu.constraint_schema
    WHERE tc.constraint_type IN ('UNIQUE', 'PRIMARY KEY')
        AND kcu.table_schema LIKE '{schema_name}'
    GROUP BY
        kcu.table_schema,
        kcu.table_name,
        kcu.constraint_name,
        tc.constraint_type;
    '''
|
|
166
|
+
|
|
167
|
+
def COLUMNS(schema_name: str = '%') -> str:
    '''Returns a SQL query listing all columns with their foreign-key targets.

    One result row per column: (schema_name, table_name, column_name,
    column_type, numeric_precision, numeric_scale, is_nullable,
    foreign_key_schema, foreign_key_table, foreign_key_column); the three
    foreign-key fields are NULL for non-FK columns. `schema_name` is matched
    with LIKE, so the default '%' covers all schemas.

    NOTE(review): `schema_name` is interpolated into the SQL text rather than
    bound as a parameter — do not pass untrusted input.
    '''
    return f'''
    SELECT
        cols.table_schema AS schema_name,
        cols.table_name,
        cols.column_name,
        cols.data_type AS column_type,
        cols.numeric_precision,
        cols.numeric_scale,
        (cols.is_nullable = 'YES') AS is_nullable,
        fk.foreign_table_schema AS foreign_key_schema,
        fk.foreign_table_name AS foreign_key_table,
        fk.foreign_column_name AS foreign_key_column
    FROM information_schema.columns AS cols

    -- Foreign Key
    LEFT JOIN (
        SELECT
            kcu.table_schema,
            kcu.table_name,
            kcu.column_name,
            ccu.table_schema AS foreign_table_schema,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.constraint_schema = kcu.constraint_schema
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON tc.constraint_name = ccu.constraint_name
            AND tc.constraint_schema = ccu.constraint_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
    ) fk ON fk.table_schema = cols.table_schema
        AND fk.table_name = cols.table_name
        AND fk.column_name = cols.column_name

    WHERE cols.table_schema LIKE '{schema_name}'
    '''
|
|
207
|
+
# endregion
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
import sqlglot
|
|
2
|
+
from sqlglot import exp
|
|
3
|
+
from ..catalog import Catalog
|
|
4
|
+
from ..constraint import ConstraintType
|
|
5
|
+
|
|
6
|
+
def _get_identifier_name(identifier_exp: exp.Identifier) -> str:
    '''Returns the normalized name from an Identifier expression.

    Quoted identifiers keep their exact case; unquoted ones are lowercased.
    '''
    name = identifier_exp.name
    return name if identifier_exp.quoted else name.lower()
|
|
12
|
+
|
|
13
|
+
def _get_table_name(table_exp: exp.Table) -> str:
    '''Returns the normalized table name from a Table expression.'''
    target = table_exp.this
    if isinstance(target, exp.Identifier):
        return _get_identifier_name(target)
    # Fall back to lowercasing the raw expression text.
    return str(target).lower()
|
|
19
|
+
|
|
20
|
+
def _get_schema_name(table_exp: exp.Table, default_schema: str) -> str:
    '''Returns the normalized schema name from a Table expression, or
    `default_schema` when the table reference is unqualified.'''
    db = table_exp.db
    if not db:
        return default_schema
    if isinstance(db, exp.Identifier):
        return _get_identifier_name(db)
    return str(db).lower()
|
|
28
|
+
|
|
29
|
+
def _get_column_name(column_exp: exp.ColumnDef) -> str:
    '''Returns the normalized column name from a ColumnDef expression.

    NOTE(review): despite the annotation, this is also called with an
    `exp.Column` (table-level PRIMARY KEY handling); both expose `.this` —
    confirm and widen the annotation if intended.
    '''

    if isinstance(column_exp.this, exp.Identifier):
        return _get_identifier_name(column_exp.this)
    return str(column_exp.this).lower()
|
|
35
|
+
|
|
36
|
+
def _extract_datatype(column_exp: exp.ColumnDef) -> tuple[str, int | None, int | None]:
    '''Extracts (type name, numeric precision, numeric scale) from a ColumnDef.

    Precision/scale are only read for parameterized numeric types and only
    when the arguments are integer literals; otherwise they stay None.
    '''

    datatype_exp = column_exp.kind
    assert isinstance(datatype_exp, exp.DataType), 'Expected DataType expression in ColumnDef'

    datatype = datatype_exp.this.value

    precision: int | None = None
    scale: int | None = None

    args = datatype_exp.expressions
    if args and datatype in {'DECIMAL', 'NUMERIC', 'NUMBER', 'FLOAT'}:
        first = args[0] if len(args) >= 1 else None
        if isinstance(first, exp.Literal) and first.is_int:
            precision = int(first.name)
        second = args[1] if len(args) == 2 else None
        if isinstance(second, exp.Literal) and second.is_int:
            scale = int(second.name)

    return datatype, precision, scale
|
|
59
|
+
|
|
60
|
+
def build_catalog_from_sql(sql_string: str, search_path: str = 'public') -> Catalog:
    '''Builds a catalog from the provided SQL string without executing it in a database.

    Only CREATE TABLE statements are considered; everything else in
    `sql_string` is ignored. Unqualified table references are placed in the
    `search_path` schema.
    '''

    statements = sqlglot.parse(sql_string)

    # Filter to only CREATE TABLE statements
    statements = [stmt for stmt in statements if isinstance(stmt, exp.Create) and stmt.kind and stmt.kind.upper() == 'TABLE']

    catalog = Catalog()

    for statement in statements:
        table_exp = statement.find(exp.Table)

        assert table_exp is not None, 'Expected Table expression in CREATE TABLE statement'

        # "CREATE TABLE <schema_name>.<table_name>" handling
        table_name = _get_table_name(table_exp)
        schema_name = _get_schema_name(table_exp, search_path)

        # Column definitions
        column_exps: list[exp.ColumnDef] = list(statement.find_all(exp.ColumnDef))

        # PRIMARY KEY defined at table level, e.g., PRIMARY KEY (col1, col2)
        pk_exp: exp.PrimaryKey | None = statement.find(exp.PrimaryKey)

        # FOREIGN KEY defined at table level, e.g., FOREIGN KEY (col1) REFERENCES other_table(other_col)
        fk_exps: list[exp.ForeignKey] = list(statement.find_all(exp.ForeignKey))

        # Mapping of foreign key column names to (schema, table, column) tuples.
        # NOTE: this needs to be filled in before adding columns to the catalog.
        fks: dict[str, tuple[str, str, str]] = {}

        # UNIQUE constraints found anywhere in the statement; column-level ones
        # are filtered out below (they carry no Schema payload).
        unique_exps: list[exp.UniqueColumnConstraint] = list(statement.find_all(exp.UniqueColumnConstraint))

        # Set to keep track of primary key column names
        pk_col_names: set[str] = set()

        # List to keep track of unique constraint column name sets
        unique_col_names: list[set[str]] = []

        # Process table-level Foreign Key constraints
        for fk_exp in fk_exps:
            fk_column_names = [_get_identifier_name(col_exp) for col_exp in fk_exp.expressions]

            ref_exp = fk_exp.find(exp.Reference)
            assert ref_exp is not None, 'Expected Reference expression in Foreign Key definition'

            ref_schema_exp = ref_exp.this
            assert isinstance(ref_schema_exp, exp.Schema), 'Expected Schema expression in Foreign Key reference'

            ref_table_exp = ref_schema_exp.this
            assert isinstance(ref_table_exp, exp.Table), 'Expected Table expression in Foreign Key reference'

            ref_schema_name = _get_schema_name(ref_table_exp, search_path)
            ref_table_name = _get_table_name(ref_table_exp)

            ref_column_names = [_get_identifier_name(col_exp) for col_exp in ref_schema_exp.expressions]

            # e.g. "FOREIGN KEY (tenant_id, order_id) REFERENCES orders (tenant_id, order_id)"
            for fk_col_name, ref_col_name in zip(fk_column_names, ref_column_names):
                fks[fk_col_name] = (ref_schema_name, ref_table_name, ref_col_name)

        # Process columns
        for column_exp in column_exps:
            column_name = _get_column_name(column_exp)

            # Primary Key handling
            is_pk = any(isinstance(c.kind, exp.PrimaryKeyColumnConstraint) for c in column_exp.constraints)
            if is_pk:
                pk_col_names.add(column_name)

            # Unique handling (column-level "col TYPE UNIQUE")
            is_unique = any(isinstance(c.kind, exp.UniqueColumnConstraint) for c in column_exp.constraints)
            if is_unique:
                unique_col_names.append({column_name})

            # Not Null handling
            is_not_null = any(isinstance(c.kind, exp.NotNullColumnConstraint) for c in column_exp.constraints)

            # Foreign Key handling: column-level REFERENCES wins over a
            # table-level FOREIGN KEY naming the same column.
            fk_constraint = next((c for c in column_exp.constraints if isinstance(c.kind, exp.Reference)), None)

            if fk_constraint:
                fk_reference = fk_constraint.kind
                assert isinstance(fk_reference, exp.Reference), 'Expected Reference expression in Foreign Key constraint'

                fk_schema_exp = fk_reference.this
                assert isinstance(fk_schema_exp, exp.Schema), 'Expected Schema expression in Foreign Key constraint'

                fk_table_exp = fk_schema_exp.this
                assert isinstance(fk_table_exp, exp.Table), 'Expected Table expression in Foreign Key constraint'

                fk_schema_name = _get_schema_name(fk_table_exp, search_path)
                fk_table_name = _get_table_name(fk_table_exp)

                fk_column_exp = fk_schema_exp.expressions[0]
                assert isinstance(fk_column_exp, exp.Identifier), 'Expected Identifier expression in Foreign Key column'
                fk_column_name = _get_identifier_name(fk_column_exp)
            elif column_name in fks:
                fk_schema_name, fk_table_name, fk_column_name = fks[column_name]
            else:
                fk_schema_name = None
                fk_table_name = None
                fk_column_name = None

            # Datatype handling
            column_type, numeric_precision, numeric_scale = _extract_datatype(column_exp)

            # Add column to catalog
            catalog[schema_name][table_name].add_column(
                name=column_name,
                column_type=column_type,
                real_name=column_name,
                numeric_precision=numeric_precision,
                numeric_scale=numeric_scale,
                is_nullable=not is_not_null,
                fk_schema=fk_schema_name,
                fk_table=fk_table_name,
                fk_column=fk_column_name)

        # Process table-level Primary Key constraint
        if pk_exp:
            for ordered_exp in pk_exp.expressions:
                col_exp = ordered_exp.find(exp.Column)
                assert col_exp is not None, 'Expected Column expression in Primary Key definition'
                col_name = _get_column_name(col_exp)
                pk_col_names.add(col_name)

        # Process table-level Unique constraints
        for unique_exp in unique_exps:
            unique_schema_exp = unique_exp.this
            # BUG FIX: column-level UNIQUE constraints are also returned by
            # find_all but carry no Schema payload (`this` is not a Schema);
            # they were already collected in the column loop above. Previously
            # this was an assert, which crashed on any column-level UNIQUE.
            if not isinstance(unique_schema_exp, exp.Schema):
                continue

            unique_column_names = set()
            # NOTE(review): the identifiers are read from unique_exp.expressions;
            # confirm against sqlglot whether table-level UNIQUE stores them
            # there or on the nested Schema's expressions.
            for col_id_exp in unique_exp.expressions:
                col_name = _get_identifier_name(col_id_exp)
                unique_column_names.add(col_name)
            unique_col_names.append(unique_column_names)

        # Add Primary Key constraint to catalog
        # NOTE: needs to be performed after all columns have been added, since PKs can be defined at both column and table level
        # BUG FIX: tables without a primary key are legal SQL; previously an
        # assert crashed here when no PK columns were found.
        if pk_col_names:
            catalog[schema_name][table_name].add_unique_constraint(
                columns=pk_col_names,
                constraint_type=ConstraintType.PRIMARY_KEY
            )

        # Add Unique constraints to catalog
        # NOTE: needs to be performed after all columns have been added, since Unique constraints can be defined at both column and table level
        for unique_col_name_set in unique_col_names:
            catalog[schema_name][table_name].add_unique_constraint(
                columns=unique_col_name_set,
                constraint_type=ConstraintType.UNIQUE
            )

    return catalog
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from .table import Table
|
|
2
|
+
from .schema import Schema
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
import json
|
|
6
|
+
from typing import Self
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
|
|
9
|
+
@dataclass
class Catalog:
    '''A database catalog, with schemas, tables, and columns.

    Schemas are stored in a name-keyed dict and are created lazily on first
    subscript access (see `__getitem__`). Serialization round-trips through
    `to_dict`/`from_dict` (and their JSON/file wrappers).
    '''

    # Backing store: schema name -> Schema. Read directly by from_dict and
    # to_sqlglot_schema below.
    _schemas: dict[str, Schema] = field(default_factory=dict)

    def __getitem__(self, schema_name: str) -> Schema:
        '''Gets a schema from the catalog, creating it if it does not exist.'''

        # NOTE: a plain read therefore mutates the catalog when the schema is
        # missing — use has_schema() for a non-creating existence check.
        if schema_name not in self._schemas:
            self._schemas[schema_name] = Schema(schema_name)
        return self._schemas[schema_name]

    def __setitem__(self, schema_name: str, schema: Schema) -> Schema:
        '''Sets a schema in the catalog, replacing any existing schema with the same name.'''

        self._schemas[schema_name] = schema
        return schema

    def has_schema(self, schema_name: str) -> bool:
        '''Checks if a schema exists in the catalog (without creating it).'''

        return schema_name in self._schemas

    def copy_table(self, schema_name: str, table_name: str, table: Table) -> Table:
        '''Copies a table into the catalog, creating the schema if it does not exist.

        The table is deep-copied, so later mutations of `table` do not affect
        the catalog. Returns the stored copy.
        '''

        new_table = deepcopy(table)
        self[schema_name][table_name] = new_table

        return new_table

    def has_table(self, schema_name: str, table_name: str) -> bool:
        '''
        Checks if a table exists in the specified schema in the catalog.

        Returns False if the schema or table do not exist.
        '''

        # Guard first so the lookup below cannot auto-create the schema.
        if not self.has_schema(schema_name):
            return False
        return self.__getitem__(schema_name).has_table(table_name)

    def add_column(self, schema_name: str, table_name: str, column_name: str,
                   column_type: str, numeric_precision: int | None = None, numeric_scale: int | None = None,
                   is_nullable: bool = True,
                   fk_schema: str | None = None, fk_table: str | None = None, fk_column: str | None = None) -> None:
        '''Adds a column to the catalog, creating the schema and table if they do not exist.'''

        # Delegates to Table.add_column via the auto-creating subscripts.
        self[schema_name][table_name].add_column(name=column_name,
                                                 column_type=column_type, numeric_precision=numeric_precision, numeric_scale=numeric_scale,
                                                 is_nullable=is_nullable,
                                                 fk_schema=fk_schema, fk_table=fk_table, fk_column=fk_column)

    @property
    def schema_names(self) -> set[str]:
        '''Returns all schema names in the catalog.'''
        return set(self._schemas.keys())

    @property
    def table_names(self) -> set[str]:
        '''Returns all table names in the catalog, regardless of schema.

        Names are deduplicated across schemas (set semantics).
        '''

        result = set()
        for schema in self._schemas.values():
            result.update(schema.table_names)
        return result

    def copy(self) -> Self:
        '''Creates a deep copy of the catalog.'''
        return deepcopy(self)

    def __repr__(self) -> str:
        # Schema.__repr__ accepts an indentation level; 1 nests them under
        # the Catalog(...) wrapper.
        schemas = [schema.__repr__(1) for schema in self._schemas.values()]

        result = 'Catalog('
        for schema in schemas:
            result += '\n' + schema
        result += '\n)'

        return result


    def to_dict(self) -> dict:
        '''Converts the Catalog to a dictionary.'''
        return {
            # Format version tag for the serialized representation.
            'version': 1,
            'schemas': {name: sch.to_dict() for name, sch in self._schemas.items()},
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Catalog':
        '''Creates a Catalog from a dictionary.

        The dict keys under 'schemas' are ignored; each schema is re-keyed by
        its own serialized name. Missing/None 'schemas' yields an empty catalog.
        '''
        cat = cls()
        for _, sch_data in (data.get('schemas') or {}).items():
            sch = Schema.from_dict(sch_data)
            cat._schemas[sch.name] = sch
        return cat

    # String-based JSON (handy for DB/blob storage)
    def to_json(self, *, indent: int | None = 2) -> str:
        '''Converts the Catalog to a JSON string.'''
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_json(cls, s: str) -> 'Catalog':
        '''Creates a Catalog from a JSON string.'''
        return cls.from_dict(json.loads(s))

    def to_sqlglot_schema(self) -> dict[str, dict[str, dict[str, str]]]:
        '''Converts to a sqlglot-compatible catalog format.

        Shape: {schema: {table: {column: type}}}. Tables with no columns are
        skipped, and schemas left empty as a result are dropped entirely.
        '''

        result: dict[str, dict[str, dict[str, str]]] = {}

        for sch_name, sch in self._schemas.items():
            result[sch_name] = {}
            for tbl_name, tbl in sch._tables.items():
                if not tbl.columns:
                    continue
                result[sch_name][tbl_name] = {}
                for col in tbl.columns:
                    result[sch_name][tbl_name][col.name] = col.column_type
            if not result[sch_name]:
                del result[sch_name]

        return result

    # Convenience file helpers
    def save_json(self, path: str, *, indent: int | None = 2) -> None:
        '''Saves the Catalog to a JSON file.'''
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=indent)

    @classmethod
    def load_json(cls, path: str) -> 'Catalog':
        '''Loads a Catalog from a JSON file.'''
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
|
|
3
|
+
@dataclass
|
|
4
|
+
class Column:
|
|
5
|
+
'''A database table column, with type and constraints.'''
|
|
6
|
+
|
|
7
|
+
name: str
|
|
8
|
+
real_name: str = field(init=False)
|
|
9
|
+
table_idx: int | None = None
|
|
10
|
+
'''Index of the table in `referenced_tables`. If None, the column is not associated with a specific table in `referenced_tables`.'''
|
|
11
|
+
|
|
12
|
+
column_type: str = 'UNKNOWN'
|
|
13
|
+
numeric_precision: int | None = None
|
|
14
|
+
numeric_scale: int | None = None
|
|
15
|
+
is_nullable: bool = True
|
|
16
|
+
is_constant: bool = False
|
|
17
|
+
fk_schema: str | None = None
|
|
18
|
+
fk_table: str | None = None
|
|
19
|
+
fk_column: str | None = None
|
|
20
|
+
|
|
21
|
+
def __post_init__(self):
|
|
22
|
+
self.real_name = self.name
|
|
23
|
+
|
|
24
|
+
@property
|
|
25
|
+
def is_fk(self) -> bool:
|
|
26
|
+
'''Returns True if the column is a foreign key.'''
|
|
27
|
+
return all([self.fk_schema, self.fk_table, self.fk_column])
|
|
28
|
+
|
|
29
|
+
def __repr__(self, level: int = 0, max_col_len: int = 20) -> str:
|
|
30
|
+
indent = ' ' * level
|
|
31
|
+
|
|
32
|
+
idx_str = f'table_idx={self.table_idx}, ' if self.table_idx is not None else ''
|
|
33
|
+
return f'{indent}Column(' \
|
|
34
|
+
f'name=\'{self.name}\',{" " * (max_col_len - len(self.name))} ' \
|
|
35
|
+
f'real_name=\'{self.real_name}\',{" " * (max_col_len - len(self.real_name))} ' \
|
|
36
|
+
f'{idx_str}' \
|
|
37
|
+
f'is_fk={self.is_fk}, ' \
|
|
38
|
+
f'is_nullable={self.is_nullable}, ' \
|
|
39
|
+
f'is_constant={self.is_constant}, ' \
|
|
40
|
+
f'type=\'{self.column_type}\'' \
|
|
41
|
+
f')'
|
|
42
|
+
|
|
43
|
+
def to_dict(self) -> dict:
|
|
44
|
+
'''Converts the Column to a dictionary.'''
|
|
45
|
+
return {
|
|
46
|
+
'name': self.name,
|
|
47
|
+
'column_type': self.column_type,
|
|
48
|
+
'numeric_precision': self.numeric_precision,
|
|
49
|
+
'numeric_scale': self.numeric_scale,
|
|
50
|
+
'is_nullable': self.is_nullable,
|
|
51
|
+
'fk_schema': self.fk_schema,
|
|
52
|
+
'fk_table': self.fk_table,
|
|
53
|
+
'fk_column': self.fk_column,
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
@classmethod
|
|
57
|
+
def from_dict(cls, data: dict) -> 'Column':
|
|
58
|
+
'''Creates a Column from a dictionary.'''
|
|
59
|
+
return cls(
|
|
60
|
+
name=data['name'],
|
|
61
|
+
column_type=data['column_type'],
|
|
62
|
+
numeric_precision=data.get('numeric_precision'),
|
|
63
|
+
numeric_scale=data.get('numeric_scale'),
|
|
64
|
+
is_nullable=data.get('is_nullable', True),
|
|
65
|
+
fk_schema=(data.get('fk_schema') or None),
|
|
66
|
+
fk_table=(data.get('fk_table') or None),
|
|
67
|
+
fk_column=(data.get('fk_column') or None),
|
|
68
|
+
)
|