sqlscope 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. sqlscope/__init__.py +4 -0
  2. sqlscope/catalog/__init__.py +12 -0
  3. sqlscope/catalog/builder/__init__.py +8 -0
  4. sqlscope/catalog/builder/postgres.py +207 -0
  5. sqlscope/catalog/builder/sql.py +219 -0
  6. sqlscope/catalog/catalog.py +147 -0
  7. sqlscope/catalog/column.py +68 -0
  8. sqlscope/catalog/constraint.py +83 -0
  9. sqlscope/catalog/schema.py +60 -0
  10. sqlscope/catalog/table.py +112 -0
  11. sqlscope/query/__init__.py +5 -0
  12. sqlscope/query/extractors.py +118 -0
  13. sqlscope/query/query.py +191 -0
  14. sqlscope/query/set_operations/__init__.py +181 -0
  15. sqlscope/query/set_operations/binary_set_operation.py +162 -0
  16. sqlscope/query/set_operations/select.py +664 -0
  17. sqlscope/query/set_operations/set_operation.py +59 -0
  18. sqlscope/query/smt.py +334 -0
  19. sqlscope/query/tokenized_sql.py +70 -0
  20. sqlscope/query/typechecking/__init__.py +21 -0
  21. sqlscope/query/typechecking/base.py +9 -0
  22. sqlscope/query/typechecking/binary_ops.py +57 -0
  23. sqlscope/query/typechecking/functions.py +81 -0
  24. sqlscope/query/typechecking/predicates.py +123 -0
  25. sqlscope/query/typechecking/primitives.py +80 -0
  26. sqlscope/query/typechecking/queries.py +35 -0
  27. sqlscope/query/typechecking/types.py +59 -0
  28. sqlscope/query/typechecking/unary_ops.py +51 -0
  29. sqlscope/query/typechecking/util.py +51 -0
  30. sqlscope/util/__init__.py +18 -0
  31. sqlscope/util/ast/__init__.py +55 -0
  32. sqlscope/util/ast/column.py +55 -0
  33. sqlscope/util/ast/function.py +10 -0
  34. sqlscope/util/ast/subquery.py +23 -0
  35. sqlscope/util/ast/table.py +36 -0
  36. sqlscope/util/sql.py +27 -0
  37. sqlscope/util/tokens.py +17 -0
  38. sqlscope-1.0.0.dist-info/METADATA +52 -0
  39. sqlscope-1.0.0.dist-info/RECORD +41 -0
  40. sqlscope-1.0.0.dist-info/WHEEL +4 -0
  41. sqlscope-1.0.0.dist-info/licenses/LICENSE +21 -0
sqlscope/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ '''Extracts catalog and query metadata from a SQL query'''
2
+
3
+ from .catalog import Catalog, build_catalog, load_catalog, build_catalog_from_postgres, build_catalog_from_sql
4
+ from .query import Query
@@ -0,0 +1,12 @@
1
+ '''Represents a catalog of database schemas, tables, and columns.'''
2
+
3
+ # API exports
4
+ from .constraint import Constraint, ConstraintColumn, ConstraintType
5
+ from .column import Column
6
+ from .table import Table
7
+ from .schema import Schema
8
+ from .catalog import Catalog
9
+ from .builder import CatalogColumnInfo, CatalogUniqueConstraintInfo, build_catalog, build_catalog_from_postgres, load_catalog, build_catalog_from_sql
10
+
11
+
12
+
@@ -0,0 +1,8 @@
1
+ from ..catalog import Catalog
2
+ from .postgres import build_catalog, build_catalog_from_postgres, CatalogColumnInfo, CatalogUniqueConstraintInfo
3
+ from .sql import build_catalog_from_sql
4
+
5
def load_catalog(path: str) -> Catalog:
    '''Loads a catalog from a JSON file.

    Args:
        path: Filesystem path to a JSON file previously produced by
            `Catalog.save_json` / `Catalog.to_json`.

    Returns:
        The deserialized `Catalog`.
    '''
    # Thin convenience wrapper; all parsing/validation lives in Catalog.load_json.
    return Catalog.load_json(path)
8
+
@@ -0,0 +1,207 @@
1
+ from ..catalog import Catalog
2
+ from ..constraint import ConstraintType
3
+
4
+ import psycopg2
5
+ import time
6
+ from dataclasses import dataclass
7
+
8
+ # region Data Classes
9
+ @dataclass(frozen=True)
10
+ class CatalogColumnInfo:
11
+ '''Holds information about a database column.'''
12
+ schema_name: str
13
+ table_name: str
14
+ column_name: str
15
+ column_type: str
16
+ numeric_precision: int | None
17
+ numeric_scale: int | None
18
+ is_nullable: bool
19
+ foreign_key_schema: str | None
20
+ foreign_key_table: str | None
21
+ foreign_key_column: str | None
22
+
23
+ def to_dict(self) -> dict:
24
+ return {
25
+ 'schema_name': self.schema_name,
26
+ 'table_name': self.table_name,
27
+ 'column_name': self.column_name,
28
+ 'column_type': self.column_type,
29
+ 'numeric_precision': self.numeric_precision,
30
+ 'numeric_scale': self.numeric_scale,
31
+ 'is_nullable': self.is_nullable,
32
+ 'foreign_key_schema': self.foreign_key_schema,
33
+ 'foreign_key_table': self.foreign_key_table,
34
+ 'foreign_key_column': self.foreign_key_column,
35
+ }
36
+
37
@dataclass(frozen=True)
class CatalogUniqueConstraintInfo:
    '''Immutable record describing a UNIQUE constraint or primary key.'''
    schema_name: str
    table_name: str
    constraint_type: str
    columns: str  # Postgres returns this as a string like '{col1,col2,...}'

    def to_dict(self) -> dict:
        '''Returns a plain-dict view of this record; keys mirror the field names.'''
        return {
            name: getattr(self, name)
            for name in ('schema_name', 'table_name', 'constraint_type', 'columns')
        }
52
+ # endregion
53
+
54
+ # region Catalog Builder
55
def build_catalog(columns_info: list[CatalogColumnInfo], unique_constraints_info: list[CatalogUniqueConstraintInfo]) -> Catalog:
    '''Builds a catalog from the provided column and unique constraint information.'''
    catalog = Catalog()

    # Register every column; schemas and tables are created on demand.
    for col in columns_info:
        catalog.add_column(
            schema_name=col.schema_name,
            table_name=col.table_name,
            column_name=col.column_name,
            column_type=col.column_type,
            numeric_precision=col.numeric_precision,
            numeric_scale=col.numeric_scale,
            is_nullable=col.is_nullable,
            fk_schema=col.foreign_key_schema,
            fk_table=col.foreign_key_table,
            fk_column=col.foreign_key_column,
        )

    # Attach UNIQUE / PRIMARY KEY constraints to their tables.
    for uc in unique_constraints_info:
        # Postgres serializes the column array as '{col1,col2,...}'.
        column_set = set(uc.columns.strip('{}').split(','))
        kind = (
            ConstraintType.PRIMARY_KEY
            if uc.constraint_type == 'PRIMARY KEY'
            else ConstraintType.UNIQUE
        )
        catalog[uc.schema_name][uc.table_name].add_unique_constraint(column_set, constraint_type=kind)

    return catalog
80
+
81
+
82
def build_catalog_from_postgres(sql_string: str, *, hostname: str, port: int, user: str, password: str, schema: str | None = None, create_temp_schema: bool = False) -> Catalog:
    '''Builds a catalog by executing the provided SQL string in a temporary PostgreSQL database.

    All work happens in a single transaction that is always rolled back, so
    the target database is left unchanged.

    Args:
        sql_string: DDL (e.g. CREATE TABLE statements) to execute and introspect.
        hostname, port, user, password: PostgreSQL connection parameters.
        schema: Schema to introspect (or the temp-schema name when
            `create_temp_schema` is True). None means all schemas ('%').
        create_temp_schema: When True, runs the DDL inside a throwaway schema.

    Returns:
        A `Catalog` describing the columns and unique/primary-key constraints
        created by `sql_string`.
    '''
    if sql_string.strip() == '':
        return Catalog()

    conn = psycopg2.connect(host=hostname, port=port, user=user, password=password)
    try:
        cur = conn.cursor()

        # Use a temporary schema with a unique name so concurrent runs don't collide
        if create_temp_schema:
            schema_name = f'sql_error_categorizer_{time.time_ns()}' if schema is None else schema
            # NOTE(review): schema_name is interpolated into DDL (identifiers
            # cannot be bound as parameters) — callers must not pass untrusted
            # input here.
            cur.execute(f'CREATE schema {schema_name};')
            cur.execute(f'SET search_path TO {schema_name};')
        else:
            schema_name = '%' if schema is None else schema

        # Create the tables
        cur.execute(sql_string)

        # Fetch the catalog information
        cur.execute(COLUMNS(schema_name))
        columns_data = [
            CatalogColumnInfo(
                schema_name=row[0],
                table_name=row[1],
                column_name=row[2],
                column_type=row[3],
                numeric_precision=row[4],
                numeric_scale=row[5],
                is_nullable=row[6],
                foreign_key_schema=row[7],
                foreign_key_table=row[8],
                foreign_key_column=row[9],
            )
            for row in cur.fetchall()
        ]

        # Fetch unique constraints (including primary keys)
        cur.execute(UNIQUE_COLUMNS(schema_name))
        unique_constraints_data = [
            CatalogUniqueConstraintInfo(
                schema_name=row[0],
                table_name=row[1],
                constraint_type=row[2],
                columns=row[3],
            )
            for row in cur.fetchall()
        ]
    finally:
        # BUGFIX: the original never closed the connection and only rolled
        # back on the success path, leaking connections whenever the DDL or an
        # introspection query raised. Nothing is ever committed, so rolling
        # back also undoes the temporary CREATE SCHEMA — no explicit DROP is
        # needed (psycopg2 runs everything in one implicit transaction).
        conn.rollback()
        conn.close()

    return build_catalog(columns_data, unique_constraints_data)
144
+ # endregion
145
+
146
+ # region SQL Queries
147
def UNIQUE_COLUMNS(schema_name: str = '%') -> str:
    '''Returns SQL listing UNIQUE / PRIMARY KEY constraints with their ordered
    column arrays, for all schemas matching `schema_name`.

    `schema_name` is a SQL LIKE pattern ('%' matches every schema).
    One row per constraint; `columns` is a Postgres array aggregated in
    ordinal position order, e.g. '{col1,col2}'.

    NOTE(review): the pattern is f-string-interpolated because this function
    returns a complete SQL string — do not pass untrusted input.
    '''
    return f'''
    SELECT
        kcu.table_schema AS schema_name,
        kcu.table_name,
        tc.constraint_type,
        array_agg(kcu.column_name ORDER BY kcu.ordinal_position) AS columns
    FROM information_schema.table_constraints tc
    JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.constraint_schema = kcu.constraint_schema
    WHERE tc.constraint_type IN ('UNIQUE', 'PRIMARY KEY')
        AND kcu.table_schema LIKE '{schema_name}'
    GROUP BY
        kcu.table_schema,
        kcu.table_name,
        kcu.constraint_name,
        tc.constraint_type;
    '''
166
+
167
def COLUMNS(schema_name: str = '%') -> str:
    '''Returns SQL listing every column (one row per column) in schemas
    matching `schema_name`, with type, precision/scale, nullability, and —
    when the column participates in a FOREIGN KEY — the referenced
    schema/table/column (NULLs otherwise, via the LEFT JOIN).

    `schema_name` is a SQL LIKE pattern ('%' matches every schema).

    NOTE(review): the pattern is f-string-interpolated because this function
    returns a complete SQL string — do not pass untrusted input.
    '''
    return f'''
    SELECT
        cols.table_schema AS schema_name,
        cols.table_name,
        cols.column_name,
        cols.data_type AS column_type,
        cols.numeric_precision,
        cols.numeric_scale,
        (cols.is_nullable = 'YES') AS is_nullable,
        fk.foreign_table_schema AS foreign_key_schema,
        fk.foreign_table_name AS foreign_key_table,
        fk.foreign_column_name AS foreign_key_column
    FROM information_schema.columns AS cols

    -- Foreign Key
    LEFT JOIN (
        SELECT
            kcu.table_schema,
            kcu.table_name,
            kcu.column_name,
            ccu.table_schema AS foreign_table_schema,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.constraint_schema = kcu.constraint_schema
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON tc.constraint_name = ccu.constraint_name
            AND tc.constraint_schema = ccu.constraint_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
    ) fk ON fk.table_schema = cols.table_schema
        AND fk.table_name = cols.table_name
        AND fk.column_name = cols.column_name

    WHERE cols.table_schema LIKE '{schema_name}'
    '''
207
+ # endregion
@@ -0,0 +1,219 @@
1
+ import sqlglot
2
+ from sqlglot import exp
3
+ from ..catalog import Catalog
4
+ from ..constraint import ConstraintType
5
+
6
+ def _get_identifier_name(identifier_exp: exp.Identifier) -> str:
7
+ '''Returns the normalized name from an Identifier expression.'''
8
+
9
+ if identifier_exp.quoted:
10
+ return identifier_exp.name
11
+ return identifier_exp.name.lower()
12
+
13
def _get_table_name(table_exp: exp.Table) -> str:
    '''Returns the normalized table name from a Table expression.'''
    node = table_exp.this
    # Identifiers carry quoting info and are normalized accordingly;
    # anything else falls back to its lower-cased string form.
    return _get_identifier_name(node) if isinstance(node, exp.Identifier) else str(node).lower()
19
+
20
def _get_schema_name(table_exp: exp.Table, default_schema: str) -> str:
    '''Returns the normalized schema name from a Table expression, or the
    default schema if the table reference carries no schema qualifier.'''
    # BUGFIX: sqlglot's Table.db is a *text* property (always a str), so the
    # original `isinstance(table_exp.db, exp.Identifier)` check could never
    # match and quoted schema names were incorrectly lower-cased. Read the
    # raw AST argument instead to preserve quoting semantics.
    db_exp = table_exp.args.get('db')
    if not db_exp:
        return default_schema
    if isinstance(db_exp, exp.Identifier):
        return _get_identifier_name(db_exp)
    return str(db_exp).lower()
28
+
29
def _get_column_name(column_exp: exp.ColumnDef) -> str:
    '''Returns the normalized column name from a ColumnDef expression.'''
    node = column_exp.this
    # Identifiers are normalized with quoting rules; other node kinds fall
    # back to their lower-cased string form.
    return _get_identifier_name(node) if isinstance(node, exp.Identifier) else str(node).lower()
35
+
36
def _extract_datatype(column_exp: exp.ColumnDef) -> tuple[str, int | None, int | None]:
    '''Extracts (type name, numeric precision, numeric scale) from a ColumnDef.

    Precision/scale are only read for numeric-ish types; they stay None for
    everything else or when the literal is not a plain integer.
    '''
    datatype_exp = column_exp.kind
    assert isinstance(datatype_exp, exp.DataType), 'Expected DataType expression in ColumnDef'

    datatype = datatype_exp.this.value
    precision: int | None = None
    scale: int | None = None

    args = datatype_exp.expressions
    if args and datatype in {'DECIMAL', 'NUMERIC', 'NUMBER', 'FLOAT'}:
        def _as_int(literal) -> int | None:
            # Only accept genuine integer literals, e.g. DECIMAL(10, 2).
            if isinstance(literal, exp.Literal) and literal.is_int:
                return int(literal.name)
            return None

        precision = _as_int(args[0])
        if len(args) == 2:
            scale = _as_int(args[1])

    return datatype, precision, scale
59
+
60
def build_catalog_from_sql(sql_string: str, search_path: str = 'public') -> Catalog:
    '''Builds a catalog from the provided SQL string without executing it in a database.

    Only CREATE TABLE statements are considered; every other statement is
    ignored. Unquoted identifiers are normalized to lowercase.

    Args:
        sql_string: One or more SQL statements (parsed with sqlglot).
        search_path: Default schema for unqualified table references.

    Returns:
        A `Catalog` describing the tables, columns, and constraints found.
    '''
    statements = sqlglot.parse(sql_string)

    # Filter to only CREATE TABLE statements
    statements = [stmt for stmt in statements if isinstance(stmt, exp.Create) and stmt.kind and stmt.kind.upper() == 'TABLE']

    catalog = Catalog()

    for statement in statements:
        table_exp = statement.find(exp.Table)
        assert table_exp is not None, 'Expected Table expression in CREATE TABLE statement'

        # "CREATE TABLE <schema_name>.<table_name>" handling
        table_name = _get_table_name(table_exp)
        schema_name = _get_schema_name(table_exp, search_path)

        # Column definitions
        column_exps: list[exp.ColumnDef] = list(statement.find_all(exp.ColumnDef))

        # PRIMARY KEY defined at table level, e.g., PRIMARY KEY (col1, col2)
        pk_exp: exp.PrimaryKey | None = statement.find(exp.PrimaryKey)

        # FOREIGN KEY defined at table level, e.g., FOREIGN KEY (col1) REFERENCES other_table(other_col)
        fk_exps: list[exp.ForeignKey] = list(statement.find_all(exp.ForeignKey))

        # Mapping of foreign key column names to (schema, table, column) tuples.
        # NOTE: this needs to be filled in before adding columns to the catalog.
        fks: dict[str, tuple[str, str, str]] = {}

        # All UNIQUE constraint nodes. find_all() returns BOTH table-level
        # "UNIQUE (col1, col2)" constraints and column-level "<col> UNIQUE"
        # markers; the column-level ones carry no Schema payload and are
        # handled in the per-column loop below.
        unique_exps: list[exp.UniqueColumnConstraint] = list(statement.find_all(exp.UniqueColumnConstraint))

        pk_col_names: set[str] = set()          # primary key column names (column- and table-level)
        unique_col_names: list[set[str]] = []   # one column-name set per UNIQUE constraint

        # Process table-level Foreign Key constraints
        for fk_exp in fk_exps:
            fk_id_exps = fk_exp.expressions
            fk_column_names = [_get_identifier_name(col_exp) for col_exp in fk_id_exps]

            ref_exp = fk_exp.find(exp.Reference)
            assert ref_exp is not None, 'Expected Reference expression in Foreign Key definition'

            ref_schema_exp = ref_exp.this
            assert isinstance(ref_schema_exp, exp.Schema), 'Expected Schema expression in Foreign Key reference'

            ref_table_exp = ref_schema_exp.this
            assert isinstance(ref_table_exp, exp.Table), 'Expected Table expression in Foreign Key reference'

            ref_schema_name = _get_schema_name(ref_table_exp, search_path)
            ref_table_name = _get_table_name(ref_table_exp)

            ref_id_exps = ref_schema_exp.expressions
            ref_column_names = [_get_identifier_name(col_exp) for col_exp in ref_id_exps]

            # e.g. "FOREIGN KEY (tenant_id, order_id) REFERENCES orders (tenant_id, order_id)"
            for fk_col_name, ref_col_name in zip(fk_column_names, ref_column_names):
                fks[fk_col_name] = (ref_schema_name, ref_table_name, ref_col_name)

        # Process columns
        for column_exp in column_exps:
            column_name = _get_column_name(column_exp)

            # Primary Key handling
            is_pk = any(isinstance(c.kind, exp.PrimaryKeyColumnConstraint) for c in column_exp.constraints)
            if is_pk:
                pk_col_names.add(column_name)

            # Unique handling
            is_unique = any(isinstance(c.kind, exp.UniqueColumnConstraint) for c in column_exp.constraints)
            if is_unique:
                unique_col_names.append({column_name})

            # Not Null handling
            is_not_null = any(isinstance(c.kind, exp.NotNullColumnConstraint) for c in column_exp.constraints)

            # Foreign Key handling: a column-level REFERENCES wins over a
            # table-level FOREIGN KEY naming the same column.
            fk_constraint = next((c for c in column_exp.constraints if isinstance(c.kind, exp.Reference)), None)

            if fk_constraint:
                fk_reference = fk_constraint.kind
                assert isinstance(fk_reference, exp.Reference), 'Expected Reference expression in Foreign Key constraint'

                fk_schema_exp = fk_reference.this
                assert isinstance(fk_schema_exp, exp.Schema), 'Expected Schema expression in Foreign Key constraint'

                fk_table_exp = fk_schema_exp.this
                assert isinstance(fk_table_exp, exp.Table), 'Expected Table expression in Foreign Key constraint'

                fk_schema_name = _get_schema_name(fk_table_exp, search_path)
                fk_table_name = _get_table_name(fk_table_exp)

                fk_column_exp = fk_schema_exp.expressions[0]
                assert isinstance(fk_column_exp, exp.Identifier), 'Expected Identifier expression in Foreign Key column'
                fk_column_name = _get_identifier_name(fk_column_exp)
            elif column_name in fks:
                fk_schema_name, fk_table_name, fk_column_name = fks[column_name]
            else:
                fk_schema_name = None
                fk_table_name = None
                fk_column_name = None

            # Datatype handling
            column_type, numeric_precision, numeric_scale = _extract_datatype(column_exp)

            # Add column to catalog
            catalog[schema_name][table_name].add_column(
                name=column_name,
                column_type=column_type,
                real_name=column_name,
                numeric_precision=numeric_precision,
                numeric_scale=numeric_scale,
                is_nullable=not is_not_null,
                fk_schema=fk_schema_name,
                fk_table=fk_table_name,
                fk_column=fk_column_name)

        # Process table-level Primary Key constraint
        if pk_exp:
            for ordered_exp in pk_exp.expressions:
                col_exp = ordered_exp.find(exp.Column)
                assert col_exp is not None, 'Expected Column expression in Primary Key definition'
                # _get_column_name also works for exp.Column: it reads the
                # Identifier from .this just like for exp.ColumnDef.
                col_name = _get_column_name(col_exp)
                pk_col_names.add(col_name)

        # Process table-level Unique constraints
        for unique_exp in unique_exps:
            unique_schema_exp = unique_exp.this
            # BUGFIX: column-level "<col> UNIQUE" markers have no Schema
            # payload (this is None); the original asserted isinstance here
            # and crashed on them. They were already collected in the
            # per-column loop above, so skip them.
            if not isinstance(unique_schema_exp, exp.Schema):
                continue
            # BUGFIX: the column identifiers live on the wrapped Schema node;
            # the original iterated unique_exp.expressions, which is empty for
            # table-level constraints, producing empty column sets.
            # NOTE(review): assumes the Schema holds Identifier nodes — confirm
            # against the sqlglot version in use.
            unique_column_names = set()
            for col_id_exp in unique_schema_exp.expressions:
                unique_column_names.add(_get_identifier_name(col_id_exp))
            unique_col_names.append(unique_column_names)

        # Add Primary Key constraint to catalog.
        # NOTE: needs to be performed after all columns have been added, since
        # PKs can be defined at both column and table level.
        # BUGFIX: a table without a primary key is legal SQL; the original
        # asserted len(pk_col_names) > 0 and crashed on PK-less tables.
        if pk_col_names:
            catalog[schema_name][table_name].add_unique_constraint(
                columns=pk_col_names,
                constraint_type=ConstraintType.PRIMARY_KEY
            )

        # Add Unique constraints to catalog.
        # NOTE: performed after all columns have been added, since UNIQUE
        # constraints can be defined at both column and table level.
        for unique_col_name_set in unique_col_names:
            catalog[schema_name][table_name].add_unique_constraint(
                columns=unique_col_name_set,
                constraint_type=ConstraintType.UNIQUE
            )

    return catalog
@@ -0,0 +1,147 @@
1
+ from .table import Table
2
+ from .schema import Schema
3
+
4
+ from dataclasses import dataclass, field
5
+ import json
6
+ from typing import Self
7
+ from copy import deepcopy
8
+
9
@dataclass
class Catalog:
    '''A database catalog, with schemas, tables, and columns.

    Schemas are created lazily on first access through `__getitem__`, so
    `catalog['public']['users']` always succeeds.
    '''

    # Maps schema name -> Schema. Lazily populated; see __getitem__.
    _schemas: dict[str, Schema] = field(default_factory=dict)

    def __getitem__(self, schema_name: str) -> Schema:
        '''Gets a schema from the catalog, creating it if it does not exist.'''

        if schema_name not in self._schemas:
            self._schemas[schema_name] = Schema(schema_name)
        return self._schemas[schema_name]

    def __setitem__(self, schema_name: str, schema: Schema) -> Schema:
        '''Sets a schema in the catalog, replacing any existing schema with the same name.

        Returns the stored schema (unusual for __setitem__; kept for chaining).
        '''

        self._schemas[schema_name] = schema
        return schema

    def has_schema(self, schema_name: str) -> bool:
        '''Checks if a schema exists in the catalog (without creating it).'''

        return schema_name in self._schemas

    def copy_table(self, schema_name: str, table_name: str, table: Table) -> Table:
        '''Copies a table into the catalog, creating the schema if it does not exist.

        The table is deep-copied, so later mutations of `table` do not affect
        the stored copy (and vice versa). Returns the stored copy.
        '''

        new_table = deepcopy(table)
        self[schema_name][table_name] = new_table

        return new_table

    def has_table(self, schema_name: str, table_name: str) -> bool:
        '''
        Checks if a table exists in the specified schema in the catalog.

        Returns False if the schema or table do not exist.
        '''

        # Guard first so the lookup below cannot create the schema as a side effect.
        if not self.has_schema(schema_name):
            return False
        return self.__getitem__(schema_name).has_table(table_name)

    def add_column(self, schema_name: str, table_name: str, column_name: str,
                   column_type: str, numeric_precision: int | None = None, numeric_scale: int | None = None,
                   is_nullable: bool = True,
                   fk_schema: str | None = None, fk_table: str | None = None, fk_column: str | None = None) -> None:
        '''Adds a column to the catalog, creating the schema and table if they do not exist.'''

        self[schema_name][table_name].add_column(name=column_name,
                                                 column_type=column_type, numeric_precision=numeric_precision, numeric_scale=numeric_scale,
                                                 is_nullable=is_nullable,
                                                 fk_schema=fk_schema, fk_table=fk_table, fk_column=fk_column)

    @property
    def schema_names(self) -> set[str]:
        '''Returns all schema names in the catalog.'''
        return set(self._schemas.keys())

    @property
    def table_names(self) -> set[str]:
        '''Returns all table names in the catalog, regardless of schema.'''

        result = set()
        for schema in self._schemas.values():
            result.update(schema.table_names)
        return result

    def copy(self) -> Self:
        '''Creates a deep copy of the catalog.'''
        return deepcopy(self)

    def __repr__(self) -> str:
        # Each schema renders itself at indentation level 1.
        schemas = [schema.__repr__(1) for schema in self._schemas.values()]

        result = 'Catalog('
        for schema in schemas:
            result += '\n' + schema
        result += '\n)'

        return result


    def to_dict(self) -> dict:
        '''Converts the Catalog to a dictionary (JSON-serializable).'''
        return {
            'version': 1,  # serialization format version; currently ignored by from_dict
            'schemas': {name: sch.to_dict() for name, sch in self._schemas.items()},
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Catalog':
        '''Creates a Catalog from a dictionary.

        Schemas are keyed by their own `name` attribute from the serialized
        payload, not by the dictionary key. A missing/None 'schemas' entry
        yields an empty catalog.
        '''
        cat = cls()
        for _, sch_data in (data.get('schemas') or {}).items():
            sch = Schema.from_dict(sch_data)
            cat._schemas[sch.name] = sch
        return cat

    # String-based JSON (handy for DB/blob storage)
    def to_json(self, *, indent: int | None = 2) -> str:
        '''Converts the Catalog to a JSON string.'''
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_json(cls, s: str) -> 'Catalog':
        '''Creates a Catalog from a JSON string.'''
        return cls.from_dict(json.loads(s))

    def to_sqlglot_schema(self) -> dict[str, dict[str, dict[str, str]]]:
        '''Converts to a sqlglot-compatible catalog format.

        Shape: {schema: {table: {column: column_type}}}. Column-less tables
        are skipped, and schemas left empty as a result are omitted.
        '''

        result: dict[str, dict[str, dict[str, str]]] = {}

        for sch_name, sch in self._schemas.items():
            result[sch_name] = {}
            # NOTE(review): reaches into Schema._tables directly; consider a
            # public accessor on Schema.
            for tbl_name, tbl in sch._tables.items():
                if not tbl.columns:
                    continue
                result[sch_name][tbl_name] = {}
                for col in tbl.columns:
                    result[sch_name][tbl_name][col.name] = col.column_type
            if not result[sch_name]:
                del result[sch_name]

        return result

    # Convenience file helpers
    def save_json(self, path: str, *, indent: int | None = 2) -> None:
        '''Saves the Catalog to a JSON file.'''
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=indent)

    @classmethod
    def load_json(cls, path: str) -> 'Catalog':
        '''Loads a Catalog from a JSON file.'''
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)
@@ -0,0 +1,68 @@
1
+ from dataclasses import dataclass, field
2
+
3
+ @dataclass
4
+ class Column:
5
+ '''A database table column, with type and constraints.'''
6
+
7
+ name: str
8
+ real_name: str = field(init=False)
9
+ table_idx: int | None = None
10
+ '''Index of the table in `referenced_tables`. If None, the column is not associated with a specific table in `referenced_tables`.'''
11
+
12
+ column_type: str = 'UNKNOWN'
13
+ numeric_precision: int | None = None
14
+ numeric_scale: int | None = None
15
+ is_nullable: bool = True
16
+ is_constant: bool = False
17
+ fk_schema: str | None = None
18
+ fk_table: str | None = None
19
+ fk_column: str | None = None
20
+
21
+ def __post_init__(self):
22
+ self.real_name = self.name
23
+
24
+ @property
25
+ def is_fk(self) -> bool:
26
+ '''Returns True if the column is a foreign key.'''
27
+ return all([self.fk_schema, self.fk_table, self.fk_column])
28
+
29
+ def __repr__(self, level: int = 0, max_col_len: int = 20) -> str:
30
+ indent = ' ' * level
31
+
32
+ idx_str = f'table_idx={self.table_idx}, ' if self.table_idx is not None else ''
33
+ return f'{indent}Column(' \
34
+ f'name=\'{self.name}\',{" " * (max_col_len - len(self.name))} ' \
35
+ f'real_name=\'{self.real_name}\',{" " * (max_col_len - len(self.real_name))} ' \
36
+ f'{idx_str}' \
37
+ f'is_fk={self.is_fk}, ' \
38
+ f'is_nullable={self.is_nullable}, ' \
39
+ f'is_constant={self.is_constant}, ' \
40
+ f'type=\'{self.column_type}\'' \
41
+ f')'
42
+
43
+ def to_dict(self) -> dict:
44
+ '''Converts the Column to a dictionary.'''
45
+ return {
46
+ 'name': self.name,
47
+ 'column_type': self.column_type,
48
+ 'numeric_precision': self.numeric_precision,
49
+ 'numeric_scale': self.numeric_scale,
50
+ 'is_nullable': self.is_nullable,
51
+ 'fk_schema': self.fk_schema,
52
+ 'fk_table': self.fk_table,
53
+ 'fk_column': self.fk_column,
54
+ }
55
+
56
+ @classmethod
57
+ def from_dict(cls, data: dict) -> 'Column':
58
+ '''Creates a Column from a dictionary.'''
59
+ return cls(
60
+ name=data['name'],
61
+ column_type=data['column_type'],
62
+ numeric_precision=data.get('numeric_precision'),
63
+ numeric_scale=data.get('numeric_scale'),
64
+ is_nullable=data.get('is_nullable', True),
65
+ fk_schema=(data.get('fk_schema') or None),
66
+ fk_table=(data.get('fk_table') or None),
67
+ fk_column=(data.get('fk_column') or None),
68
+ )