aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/schema.py
ADDED
|
@@ -0,0 +1,973 @@
|
|
|
1
|
+
"""Schema reflection, graph building, and profiling orchestration.
|
|
2
|
+
|
|
3
|
+
Reflects table and column metadata from a live database or SQL file, infers missing foreign keys by naming convention, and builds a shortest-path join graph between all table pairs. A local JSON cache avoids repeated reflection; the cache is invalidated by a content hash of the schema structure. Profiling data (statistics, value domains, LLM-assigned roles, and allowed operations) is layered on top of the reflected schema before caching. Adaptive limits derived from schema statistics are also stored.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from dataclasses import asdict
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from .config import EngineConfig, PolicyConfig
|
|
16
|
+
from .contracts_base import (
|
|
17
|
+
ColumnMetadata,
|
|
18
|
+
FKEdge,
|
|
19
|
+
SchemaGraph,
|
|
20
|
+
SchemaLimits,
|
|
21
|
+
TableMetadata,
|
|
22
|
+
)
|
|
23
|
+
from .core_utils import debug, schema_hash_fp
|
|
24
|
+
from .dialect import get_dialect
|
|
25
|
+
from .schema_profiling import (
|
|
26
|
+
apply_column_roles_llm,
|
|
27
|
+
assign_column_ops,
|
|
28
|
+
extract_tables_from_catalog,
|
|
29
|
+
extract_tables_from_catalog_sql_connector,
|
|
30
|
+
parse_sql_file,
|
|
31
|
+
profile_schema,
|
|
32
|
+
profile_schema_spark,
|
|
33
|
+
profile_schema_sql_connector,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def compute_schema_stats(schema: SchemaGraph) -> dict[str, Any]:
    """Tally filterable, groupable and aggregatable columns across the schema.

    Args:
        schema: ``SchemaGraph`` whose column metadata exposes the
            ``is_filterable``, ``is_groupable`` and ``is_aggregatable`` flags.

    Returns:
        Dictionary with keys ``total_filterable``, ``total_groupable``,
        ``total_aggregatable``, ``min/max_filterable_per_table``,
        ``min/max_groupable_per_table``, ``table_count`` and
        ``filterable_per_table`` (counts for tables that have at least one
        filterable column). Per-table minimums ignore tables with a zero
        count and fall back to ``0`` when no table qualifies.
    """
    debug("[schema.compute_schema_stats] computing schema statistics for adaptive limits")

    total_filterable = 0
    total_groupable = 0
    total_aggregatable = 0
    # ``None`` means "no table with a non-zero count seen yet".
    min_filterable = None
    min_groupable = None
    max_filterable = 0
    max_groupable = 0
    filterable_distribution = []
    table_details = []

    for table_name, table in schema.tables.items():
        cols = list(table.columns.values())
        n_filterable = len([c for c in cols if c.is_filterable])
        n_groupable = len([c for c in cols if c.is_groupable])
        n_aggregatable = len([c for c in cols if c.is_aggregatable])

        table_details.append(
            {
                "table": table_name,
                "filterable": n_filterable,
                "groupable": n_groupable,
                "aggregatable": n_aggregatable,
            }
        )

        total_filterable += n_filterable
        total_groupable += n_groupable
        total_aggregatable += n_aggregatable

        # Tables without any filterable/groupable column do not influence
        # the per-table min/max values.
        if n_filterable > 0:
            min_filterable = n_filterable if min_filterable is None else min(min_filterable, n_filterable)
            max_filterable = max(max_filterable, n_filterable)
            filterable_distribution.append(n_filterable)

        if n_groupable > 0:
            min_groupable = n_groupable if min_groupable is None else min(min_groupable, n_groupable)
            max_groupable = max(max_groupable, n_groupable)

    stats = {
        "total_filterable": total_filterable,
        "total_groupable": total_groupable,
        "total_aggregatable": total_aggregatable,
        "min_filterable_per_table": 0 if min_filterable is None else min_filterable,
        "max_filterable_per_table": max_filterable,
        "min_groupable_per_table": 0 if min_groupable is None else min_groupable,
        "max_groupable_per_table": max_groupable,
        "table_count": len(schema.tables),
        "filterable_per_table": filterable_distribution,
    }

    debug("[schema.compute_schema_stats] per-table column counts:")
    for td in table_details:
        debug(
            f" {td['table']}: filterable={td['filterable']}, groupable={td['groupable']}, aggregatable={td['aggregatable']}"
        )

    debug("[schema.compute_schema_stats] schema-wide statistics:")
    debug(f" table_count: {stats['table_count']}")
    debug(f" total_filterable: {stats['total_filterable']}")
    debug(f" total_groupable: {stats['total_groupable']}")
    debug(f" total_aggregatable: {stats['total_aggregatable']}")
    debug(f" min_filterable_per_table: {stats['min_filterable_per_table']}")
    debug(f" max_filterable_per_table: {stats['max_filterable_per_table']}")
    debug(f" min_groupable_per_table: {stats['min_groupable_per_table']}")
    debug(f" max_groupable_per_table: {stats['max_groupable_per_table']}")
    debug(f" filterable_per_table distribution: {stats['filterable_per_table']}")

    return stats
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def compute_schema_limits(schema_stats: dict[str, Any]) -> SchemaLimits:
    """Derive adaptive pipeline limits from schema-wide column statistics.

    Args:
        schema_stats: Dictionary as returned by ``compute_schema_stats``;
            only ``table_count``, ``total_filterable`` and
            ``total_groupable`` are read.

    Returns:
        ``SchemaLimits`` with ``max_filters``, ``max_groupby``, and
        ``max_tables``.
    """
    n_tables = schema_stats.get("table_count", 1)
    filterable_total = schema_stats.get("total_filterable", 0)
    groupable_total = schema_stats.get("total_groupable", 0)

    # Average columns per table, never below one; an empty schema also
    # gets the floor of one.
    if n_tables > 0:
        filter_cap = max(1, filterable_total // n_tables)
        groupby_cap = max(1, groupable_total // n_tables)
    else:
        filter_cap = 1
        groupby_cap = 1

    # Small schemas may join every table; larger ones are capped at 3 or 4.
    table_cap = n_tables if n_tables <= 3 else (3 if n_tables <= 10 else 4)

    return SchemaLimits(
        max_filters=filter_cap,
        max_groupby=groupby_cap,
        max_tables=table_cap,
    )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _edge_key(e: FKEdge) -> tuple[str, tuple[str, ...], str, tuple[str, ...]]:
    """Build a hashable, sortable key describing an FK edge.

    Args:
        e: ``FKEdge`` to key.

    Returns:
        Tuple of ``(src_table, src_cols_tuple, dst_table, dst_cols_tuple)``.
    """
    # Column lists are converted to tuples so the key can be compared/sorted.
    source_side = (e.src_table, tuple(e.src_cols))
    target_side = (e.dst_table, tuple(e.dst_cols))
    return source_side + target_side
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _reflect_enum_values(engine: Any, schema_name: str) -> dict[str, list[str]]:
    """Ask the active dialect for the enum types defined in a schema.

    Args:
        engine: SQLAlchemy ``Engine`` connected to the target database.
        schema_name: Database schema name to introspect.

    Returns:
        Whatever ``dialect.reflect_enums`` returns (a mapping of enum type
        name to its values), or an empty dict if reflection raises.
    """
    active_dialect = get_dialect()
    try:
        reflected = active_dialect.reflect_enums(engine, schema_name)
        debug(f"[schema.reflect_enum_values] found {len(reflected)} enum types")
    except Exception as exc:
        # Enum reflection is best-effort: log and continue without enums.
        debug(f"[schema.reflect_enum_values] failed to reflect enums: {exc}")
        reflected = {}
    return reflected
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _table_to_dict(table: TableMetadata) -> dict[str, Any]:
    """Convert a ``TableMetadata`` into a plain, JSON-friendly dictionary.

    Args:
        table: ``TableMetadata`` to serialize.

    Returns:
        Dictionary with ``name``, ``columns``, ``primary_key``,
        ``foreign_keys``, ``partition_columns``, ``role``, ``row_count``,
        and ``description`` fields. Columns and foreign keys are expanded
        with ``dataclasses.asdict``.
    """
    serialized_columns = {}
    for col_name, col_meta in table.columns.items():
        serialized_columns[col_name] = asdict(col_meta)

    serialized_fks = []
    for edge in table.foreign_keys:
        serialized_fks.append(asdict(edge))

    # Key order is kept stable because the result feeds the schema hash/cache.
    payload = {
        "name": table.name,
        "columns": serialized_columns,
        "primary_key": table.primary_key,
        "foreign_keys": serialized_fks,
        "partition_columns": table.partition_columns,
        "role": table.role,
        "row_count": table.row_count,
        "description": table.description,
    }
    return payload
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _table_from_dict(d: dict[str, Any]) -> TableMetadata:
    """Rebuild a ``TableMetadata`` from its dictionary form.

    Args:
        d: Dictionary with keys matching ``TableMetadata`` fields. ``name``,
            ``columns``, ``primary_key`` and ``foreign_keys`` are required;
            ``partition_columns``, ``role``, ``row_count`` and
            ``description`` are optional.

    Returns:
        Populated ``TableMetadata`` with ``ColumnMetadata`` and ``FKEdge``
        sub-objects.
    """
    column_objs = {}
    for col_name, col_payload in d["columns"].items():
        column_objs[col_name] = ColumnMetadata.from_dict(col_payload)

    table_name = d["name"]
    pk_cols = d["primary_key"]
    fk_edges = [FKEdge(**fk_payload) for fk_payload in d["foreign_keys"]]

    return TableMetadata(
        name=table_name,
        columns=column_objs,
        primary_key=pk_cols,
        foreign_keys=fk_edges,
        # Optional fields fall back to empty/neutral values when absent.
        partition_columns=d.get("partition_columns", []),
        role=d.get("role"),
        row_count=d.get("row_count", 0),
        description=d.get("description", ""),
    )
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _reverse_fk_path(path: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return the path walked in the opposite direction.

    Args:
        path: List of FK edge dicts with ``src_table``, ``src_cols``,
            ``dst_table``, and ``dst_cols``.

    Returns:
        New list whose edges appear in reverse order with source and
        destination swapped. The input list is not modified; the column
        lists are reused, not copied.
    """
    return [
        {
            "src_table": edge["dst_table"],
            "src_cols": edge["dst_cols"],
            "dst_table": edge["src_table"],
            "dst_cols": edge["src_cols"],
        }
        for edge in path[::-1]
    ]
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _analyze_fk_path_topology(path: list[dict[str, Any]]) -> tuple[str, str, list[str]]:
    """Classify an FK path by how often each table appears as an edge endpoint.

    Args:
        path: List of FK edge dicts representing a join path.

    Returns:
        Tuple of ``(topology_type, anchor_table, leaf_tables)`` where
        ``topology_type`` is one of ``'none'``, ``'linear'``, ``'star'``, or
        ``'tree'``. For linear paths the anchor is the lexicographically
        smallest leaf; for star/tree paths it is the most-referenced table
        (ties broken by name).
    """
    if not path:
        return ("none", "", [])

    # Endpoint count per table: 1 => leaf, >1 => hub.
    table_counts: dict[str, int] = {}
    for e in path:
        for endpoint in (e["src_table"], e["dst_table"]):
            table_counts[endpoint] = table_counts.get(endpoint, 0) + 1

    leaves = sorted(t for t, c in table_counts.items() if c == 1)
    hubs = sorted(
        (t for t, c in table_counts.items() if c > 1),
        key=lambda t: (-table_counts[t], t),
    )

    if len(leaves) == 2 and len(hubs) == len(table_counts) - 2:
        return ("linear", min(leaves), leaves)
    if len(hubs) == 1:
        return ("star", hubs[0], leaves)
    if hubs:
        return ("tree", hubs[0], leaves)
    # No hub at all (e.g. disconnected single edges): treat as linear.
    return ("linear", min(table_counts.keys()), list(table_counts.keys()))


def _attach_edge_to_branch(edge_map: dict[str, list[dict[str, Any]]], edge: dict[str, Any]) -> bool:
    """Append ``edge`` to the first branch (in sorted key order) that already mentions one of its tables.

    Args:
        edge_map: Mapping of branch key to the edges collected for that branch.
        edge: FK edge dict that does not touch the anchor table.

    Returns:
        ``True`` if the edge was appended to a branch, ``False`` otherwise.
    """
    src = edge["src_table"]
    dst = edge["dst_table"]
    for branch_key in sorted(edge_map.keys()):
        branch_tables = set()
        for be in edge_map[branch_key]:
            branch_tables.add(be["src_table"])
            branch_tables.add(be["dst_table"])
        if src in branch_tables or dst in branch_tables:
            edge_map[branch_key].append(edge)
            return True
    return False


def _normalize_fk_path(path: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Normalize an FK join path to a canonical form based on its topology.

    For linear paths, the path is returned as-is when it already starts at
    the lexicographically smallest leaf and reversed otherwise. For star or
    tree paths, edges touching the hub (anchor) table are oriented to start
    at the anchor and grouped per branch, with branches emitted in sorted
    key order; edges that do not touch the anchor keep their orientation and
    are appended to the branch that shares a table with them.

    Every input edge is preserved: an edge that cannot be attached when it
    is first seen (because its branch is registered later in the path) is
    retried once all anchor edges are known, and edges that share no table
    with any branch are appended at the end instead of being dropped.

    Args:
        path: List of FK edge dicts to normalize.

    Returns:
        Reordered and/or flipped list of edge dicts in canonical form.
    """
    if not path:
        return path
    topology_type, anchor, _leaves = _analyze_fk_path_topology(path)
    if topology_type == "none":
        return path
    if topology_type == "linear":
        if path[0]["src_table"] == anchor:
            return path
        return _reverse_fk_path(path)

    edge_map: dict[str, list[dict[str, Any]]] = {}
    deferred: list[dict[str, Any]] = []
    for e in path:
        src = e["src_table"]
        dst = e["dst_table"]
        if src == anchor:
            edge_map.setdefault(dst, []).append(e)
        elif dst == anchor:
            # Re-orient so the edge flows outward from the anchor.
            flipped = {
                "src_table": dst,
                "src_cols": e["dst_cols"],
                "dst_table": src,
                "dst_cols": e["src_cols"],
            }
            edge_map.setdefault(src, []).append(flipped)
        elif not _attach_edge_to_branch(edge_map, e):
            # Branch not registered yet; retry after the first pass rather
            # than silently losing the edge.
            deferred.append(e)

    # Attaching one deferred edge can make another attachable, so repeat
    # until a full pass makes no progress.
    while deferred:
        still_unattached = []
        for e in deferred:
            if not _attach_edge_to_branch(edge_map, e):
                still_unattached.append(e)
        made_progress = len(still_unattached) < len(deferred)
        deferred = still_unattached
        if not made_progress:
            break

    normalized = []
    for branch_key in sorted(edge_map.keys()):
        normalized.extend(edge_map[branch_key])
    # Edges that share no table with any branch are kept at the end.
    normalized.extend(deferred)
    return normalized if normalized else path
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def _infer_missing_fks(tables: dict[str, TableMetadata]) -> list[FKEdge]:
    """Guess undeclared foreign keys from column and table names.

    A column that is neither a primary key nor an already-declared foreign
    key is considered when its name ends in ``'_id'``, ``'_key'``, ``'id'``,
    or ``'key'``. The text before the suffix (and, for multi-word prefixes,
    its trailing word groups) is looked up as a table name; a match is
    accepted only when the first primary-key column of that table is the
    same prefix followed by one of the suffixes.

    Args:
        tables: Mapping of table name to ``TableMetadata`` as reflected from
            the database.

    Returns:
        List of single-column ``FKEdge`` instances, at most one per source
        column.
    """
    suffixes = ["_id", "_key", "id", "key"]
    inferred = []

    for table_name, table in tables.items():
        for col_name, col in table.columns.items():
            if col.is_foreign_key or col.is_primary_key:
                continue

            # First suffix in priority order that the column name ends with.
            matched_suffix = next((s for s in suffixes if col_name.endswith(s)), None)
            if not matched_suffix:
                continue

            prefix_full = col_name[: -len(matched_suffix)]
            if not prefix_full:
                continue

            # Candidates: whole prefix, last word, then progressively
            # longer trailing word groups (e.g. "c", "b_c" for "a_b_c").
            candidate_prefixes = [prefix_full]
            if "_" in prefix_full:
                parts = prefix_full.split("_")
                candidate_prefixes.append(parts[-1])
                candidate_prefixes.extend("_".join(parts[i:]) for i in range(len(parts) - 1, 0, -1))

            for prefix in candidate_prefixes:
                # Skip unknown tables and self-references.
                if prefix not in tables or table_name == prefix:
                    continue
                target_table = tables[prefix]
                if not target_table.primary_key:
                    continue
                target_pk = target_table.primary_key[0]

                # Target PK must read "<prefix><suffix>" for some suffix.
                target_matched = any(
                    target_pk.endswith(target_suffix) and target_pk[: -len(target_suffix)] == prefix
                    for target_suffix in suffixes
                )
                if not target_matched:
                    continue

                debug(f"[schema.infer_missing_fks] candidate: {table_name}.{col_name} -> {prefix}.{target_pk}")
                inferred.append(
                    FKEdge(
                        src_table=table_name,
                        src_cols=[col_name],
                        dst_table=prefix,
                        dst_cols=[target_pk],
                    )
                )
                break

    return inferred
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _reflect_schema(engine: Any, schema_name: str = None) -> SchemaGraph:
    """Reflect a live database schema and derive its join-path graph.

    Steps: reflect tables/columns through SQLAlchemy ``MetaData``, record
    declared FK constraints, add FKs inferred by ``_infer_missing_fks``,
    reflect enum types, build an undirected adjacency list from all FK
    edges, and run a breadth-first search from every table to obtain one
    shortest join path per table pair.

    Args:
        engine: SQLAlchemy ``Engine`` connected to the target database.
        schema_name: Database schema to reflect. Defaults to
            ``EngineConfig.RUNTIME.SCHEMA`` or ``'public'`` if not set.

    Returns:
        ``SchemaGraph`` with ``tables``, ``join_paths_multi``,
        ``schema_hash``, ``created_at`` timestamp, and ``enum_values``.
    """
    if schema_name is None:
        schema_name = getattr(EngineConfig.RUNTIME, "SCHEMA", "public")

    from sqlalchemy import MetaData

    debug(f"[schema.reflect_schema] reflecting from database schema '{schema_name}'")
    md = MetaData(schema=schema_name)
    md.reflect(bind=engine)

    tables: dict[str, TableMetadata] = {}
    for sa_table in md.tables.values():
        pk_names = [pk_col.name for pk_col in sa_table.primary_key.columns]
        reflected_cols: dict[str, ColumnMetadata] = {}
        for sa_col in sa_table.columns:
            reflected_cols[sa_col.name] = ColumnMetadata(
                name=sa_col.name,
                data_type=str(sa_col.type),
                is_primary_key=sa_col.name in pk_names,
                is_foreign_key=False,
                fk_target=None,
            )
        tables[sa_table.name] = TableMetadata(
            name=sa_table.name,
            columns=reflected_cols,
            primary_key=list(pk_names),
            foreign_keys=[],
        )

    debug(f"[schema.reflect_schema] found {len(tables)} tables")

    def _register_edge(edge: FKEdge) -> None:
        # Store the edge on its source table and flag each source column.
        owner = tables[edge.src_table]
        owner.foreign_keys.append(edge)
        for pos, src_col in enumerate(edge.src_cols):
            if src_col in owner.columns:
                owner.columns[src_col].is_foreign_key = True
                owner.columns[src_col].fk_target = (
                    edge.dst_table,
                    edge.dst_cols[pos],
                )

    for sa_table in md.tables.values():
        for constraint in sa_table.foreign_key_constraints:
            e = FKEdge(
                src_table=sa_table.name,
                src_cols=[el.parent.name for el in constraint.elements],
                dst_table=constraint.elements[0].column.table.name,
                dst_cols=[el.column.name for el in constraint.elements],
            )
            _register_edge(e)
            debug(
                f"[schema.reflect_schema] explicit FK: {e.src_table}.{e.src_cols[0]} -> {e.dst_table}.{e.dst_cols[0]}"
            )

    fk_count = sum(len(tbl.foreign_keys) for tbl in tables.values())
    debug(f"[schema.reflect_schema] found {fk_count} foreign key edges")

    inferred_fks = _infer_missing_fks(tables)
    if inferred_fks:
        debug(f"[schema.reflect_schema] inferred {len(inferred_fks)} missing FKs from naming conventions")
        for e in inferred_fks:
            _register_edge(e)
            debug(f"[schema.reflect_schema] {e.src_table}.{e.src_cols[0]} -> {e.dst_table}.{e.dst_cols[0]}")

    enum_values = _reflect_enum_values(engine, schema_name)

    # Undirected adjacency: every FK is traversable in both directions.
    adj: dict[str, list[FKEdge]] = {name: [] for name in tables}
    for tbl in tables.values():
        for e in tbl.foreign_keys:
            if e.src_table not in tables or e.dst_table not in tables:
                continue
            adj[e.src_table].append(e)
            adj[e.dst_table].append(
                FKEdge(
                    src_table=e.dst_table,
                    src_cols=e.dst_cols,
                    dst_table=e.src_table,
                    dst_cols=e.src_cols,
                )
            )
    # Sorted neighbours make the BFS (and thus the chosen paths) deterministic.
    for name in adj:
        adj[name].sort(key=_edge_key)

    debug("[schema.reflect_schema] computing shortest join paths")
    join_paths_multi: dict[str, dict[str, list[list[dict[str, Any]]]]] = {}

    tlist = sorted(tables.keys())
    for s in tlist:
        # Breadth-first search from ``s``; ``prev`` maps each reached table
        # to (predecessor, edge used to reach it).
        prev: dict[str, tuple[str, FKEdge]] = {}
        seen = {s}
        frontier = [s]
        head = 0
        while head < len(frontier):
            cur = frontier[head]
            head += 1
            for e in adj[cur]:
                nxt = e.dst_table
                if nxt in seen:
                    continue
                seen.add(nxt)
                prev[nxt] = (cur, e)
                frontier.append(nxt)

        paths_from_s: dict[str, list[list[dict[str, Any]]]] = {}
        join_paths_multi[s] = paths_from_s
        for t in tlist:
            if t == s:
                paths_from_s[t] = [[]]
            elif t in prev:
                # Walk predecessors back to ``s`` and flip into forward order.
                hops = []
                node = t
                while node != s:
                    node, hop = prev[node]
                    hops.append(hop)
                hops.reverse()
                paths_from_s[t] = [_normalize_fk_path([asdict(h) for h in hops])]
            else:
                paths_from_s[t] = []

    schema_hash = schema_hash_fp({k: _table_to_dict(v) for k, v in tables.items()})

    return SchemaGraph(
        tables=tables,
        join_paths_multi=join_paths_multi,
        schema_hash=schema_hash,
        created_at=datetime.now().isoformat(),
        enum_values=enum_values,
    )
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def load_or_create_schema_graph(engine: Any = None) -> SchemaGraph:
    """Load the schema graph from the JSON cache or build it fresh.

    When regeneration is not forced and the cache file exists, the cached
    tables are re-hashed and compared with the stored ``schema_hash``. On a
    match the cached graph is returned without touching the database; on a
    mismatch the cache file is deleted and the schema is rebuilt. A rebuilt
    schema is profiled, its statistics are computed, and the result is
    written back to the cache.

    Args:
        engine: SQLAlchemy ``Engine`` required for PostgreSQL reflection.
            Not used for Databricks.

    Returns:
        Fully populated and profiled ``SchemaGraph``.

    Raises:
        ValueError: If ``engine`` is ``None`` for PostgreSQL or the engine
            type is unsupported.
        RuntimeError: If the Databricks SQL connector cannot be imported or
            fails to connect.
    """
    schema_json_path = EngineConfig.SCHEMA_JSON_PATH

    debug(f"[schema.load_or_create_schema_graph] engine_type={EngineConfig.TYPE}")

    if not PolicyConfig.REGENERATE_SCHEMA_GRAPH and os.path.exists(schema_json_path):
        debug(f"[schema.load_or_create_schema_graph] loading from cache '{schema_json_path}'")
        with open(schema_json_path, encoding="utf-8") as f:
            d = json.load(f)

        tables = {k: _table_from_dict(v) for k, v in d["tables"].items()}

        recomputed_hash = schema_hash_fp(d["tables"])
        cached_hash = d.get("schema_hash", "")
        if recomputed_hash != cached_hash:
            print("Schema cache outdated, rebuilding...")
            debug(
                f"[schema.load_or_create_schema_graph] hash mismatch: cached={cached_hash[:16]} recomputed={recomputed_hash[:16]}"
            )
            os.remove(schema_json_path)
        else:
            join_paths_multi = d.get("join_paths_multi")
            if not join_paths_multi:
                # Legacy cache layout: a single path per pair under
                # "join_paths"; wrap each into the multi-path list form.
                jp = d.get("join_paths", {})
                join_paths_multi = {
                    a: {b: ([jp[a][b]] if jp[a][b] is not None else []) for b in jp[a]}
                    for a in jp
                }

            debug(f"[schema.load_or_create_schema_graph] loaded {len(tables)} tables from cache")

            graph_kwargs = {
                "tables": tables,
                "join_paths_multi": join_paths_multi,
                "schema_hash": d["schema_hash"],
                "enum_values": d.get("enum_values", {}),
            }
            cached_stats = d.get("schema_stats")
            if cached_stats is not None:
                sg = SchemaGraph(schema_stats=cached_stats, **graph_kwargs)
            else:
                # A hash-valid cache without stats used to raise KeyError;
                # the stats depend only on the cached column metadata, so
                # recompute them instead of failing.
                sg = SchemaGraph(**graph_kwargs)
                sg.schema_stats = compute_schema_stats(sg)

            return sg

    debug("[schema.load_or_create_schema_graph] no cache found, building schema")

    spark_session = None
    databricks_connection = None
    try:
        if EngineConfig.TYPE == "databricks":
            if getattr(EngineConfig.RUNTIME, "has_native_connection", lambda: False)():
                try:
                    from databricks import sql as dbsql

                    databricks_connection = dbsql.connect(
                        server_hostname=EngineConfig.RUNTIME.SERVER_HOSTNAME,
                        http_path=EngineConfig.RUNTIME.HTTP_PATH,
                        access_token=EngineConfig.RUNTIME.ACCESS_TOKEN,
                    )
                except Exception as e:
                    raise RuntimeError(f"Failed to connect via databricks-sql-connector: {e}") from e
            else:
                from pyspark.sql import SparkSession

                spark_session = SparkSession.builder.getOrCreate()
            sg = _load_or_create_schema_databricks(
                spark_session=spark_session,
                connection=databricks_connection,
            )
        elif EngineConfig.TYPE == "postgresql":
            if engine is None:
                raise ValueError("Engine is required for PostgreSQL schema reflection")
            sg = _load_or_create_schema_postgresql(engine)
        else:
            raise ValueError(f"Unsupported engine type: {EngineConfig.TYPE}")

        _add_profiling_data(engine, sg, spark_session, databricks_connection)
    finally:
        if databricks_connection is not None:
            try:
                databricks_connection.close()
            except Exception as close_err:
                # Closing is best-effort, but leave a trace instead of
                # swallowing the failure silently.
                debug(f"[schema.load_or_create_schema_graph] failed to close databricks connection: {close_err}")

    debug("[schema.load_or_create_schema_graph] computing schema stats after profiling")
    sg.schema_stats = compute_schema_stats(sg)

    _save_schema_to_cache(sg, schema_json_path)
    debug("[schema.load_or_create_schema_graph] cache saved with profiling data")

    return sg
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def _tables_meta_to_schema_graph(tables_meta: dict[str, dict]) -> SchemaGraph:
    """Build a ``SchemaGraph`` (tables plus join paths) from raw table metadata.

    Creates ``ColumnMetadata``/``FKEdge`` objects for every table, builds an
    undirected adjacency list from the FK edges, and runs a breadth-first
    search from each table to obtain one shortest join path per table pair.

    Args:
        tables_meta: Dictionary mapping table name to metadata dicts with
            keys ``column_names_original``, ``column_types``,
            ``primary_keys``, ``foreign_keys`` and, optionally,
            ``partition_columns``.

    Returns:
        ``SchemaGraph`` with ``tables``, ``join_paths_multi``,
        ``schema_hash``, ``created_at`` and an empty ``enum_values``.
    """
    tables: dict[str, TableMetadata] = {}

    for table_name, meta in tables_meta.items():
        col_names = meta.get("column_names_original", [])
        col_types = meta.get("column_types", [])
        pk_cols = meta.get("primary_keys", [])

        columns: dict[str, ColumnMetadata] = {}
        for idx, col_name in enumerate(col_names):
            # Columns without a matching type entry are tagged "UNKNOWN".
            declared_type = col_types[idx] if idx < len(col_types) else "UNKNOWN"
            columns[col_name] = ColumnMetadata(
                name=col_name,
                data_type=declared_type,
                is_primary_key=col_name in pk_cols,
                is_foreign_key=False,
                fk_target=None,
            )

        fk_edges = []
        for fk in meta.get("foreign_keys", []):
            fk_edges.append(
                FKEdge(
                    src_table=table_name,
                    src_cols=fk["src_cols"],
                    dst_table=fk["dst_table"],
                    dst_cols=fk["dst_cols"],
                )
            )
            # Flag each known source column and record its target column.
            for pos, src_col in enumerate(fk["src_cols"]):
                if src_col not in columns:
                    continue
                columns[src_col].is_foreign_key = True
                columns[src_col].fk_target = (fk["dst_table"], fk["dst_cols"][pos])

        tables[table_name] = TableMetadata(
            name=table_name,
            columns=columns,
            primary_key=pk_cols,
            foreign_keys=fk_edges,
            partition_columns=meta.get("partition_columns", []),
        )

    fk_count = sum(len(tbl.foreign_keys) for tbl in tables.values())
    debug(f"[schema.tables_meta_to_schema_graph] {len(tables)} tables, {fk_count} FK edges")

    # Undirected adjacency: every FK is traversable in both directions.
    adj: dict[str, list[FKEdge]] = {name: [] for name in tables}
    for tbl in tables.values():
        for e in tbl.foreign_keys:
            if e.src_table not in tables or e.dst_table not in tables:
                continue
            adj[e.src_table].append(e)
            adj[e.dst_table].append(
                FKEdge(
                    src_table=e.dst_table,
                    src_cols=e.dst_cols,
                    dst_table=e.src_table,
                    dst_cols=e.src_cols,
                )
            )
    # Sorted neighbours keep the BFS result deterministic.
    for name in adj:
        adj[name].sort(key=_edge_key)

    debug("[schema.tables_meta_to_schema_graph] computing shortest join paths")

    join_paths_multi: dict[str, dict[str, list[list[dict[str, Any]]]]] = {}
    tlist = sorted(tables.keys())

    for s in tlist:
        # ``prev`` maps each reached table to (predecessor, edge used).
        prev: dict[str, tuple[str, FKEdge]] = {}
        seen = {s}
        frontier = [s]
        head = 0
        while head < len(frontier):
            cur = frontier[head]
            head += 1
            for e in adj[cur]:
                nxt = e.dst_table
                if nxt in seen:
                    continue
                seen.add(nxt)
                prev[nxt] = (cur, e)
                frontier.append(nxt)

        paths_from_s: dict[str, list[list[dict[str, Any]]]] = {}
        join_paths_multi[s] = paths_from_s
        for t in tlist:
            if t == s:
                paths_from_s[t] = [[]]
            elif t in prev:
                # Walk predecessors back to ``s`` and flip into forward order.
                hops = []
                node = t
                while node != s:
                    node, hop = prev[node]
                    hops.append(hop)
                hops.reverse()
                paths_from_s[t] = [_normalize_fk_path([asdict(h) for h in hops])]
            else:
                paths_from_s[t] = []

    schema_hash = schema_hash_fp({k: _table_to_dict(v) for k, v in tables.items()})

    return SchemaGraph(
        tables=tables,
        join_paths_multi=join_paths_multi,
        schema_hash=schema_hash,
        created_at=datetime.now().isoformat(),
        enum_values={},
    )
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def _load_or_create_schema_postgresql(engine: Any) -> SchemaGraph:
    """Return a ``SchemaGraph`` for PostgreSQL, preferring live reflection.

    The database is reflected through ``_reflect_schema`` first. Any exception raised
    during reflection triggers the fallback: the DDL file named by
    ``EngineConfig.RUNTIME.SQL_FILE_PATH`` is parsed with ``parse_sql_file`` and
    converted into a graph via ``_tables_meta_to_schema_graph``.

    Args:
        engine: SQLAlchemy ``Engine`` connected to the PostgreSQL database.

    Returns:
        ``SchemaGraph`` obtained from reflection, or from the SQL file when
        reflection fails.

    Raises:
        ValueError: If reflection fails and the SQL file is not configured, does not
            exist on disk, or parses to no tables. The reflection error is chained
            as the cause.
    """
    try:
        debug("[schema.load_or_create_schema_postgresql] reflecting_database")
        reflected = _reflect_schema(engine)
        debug(f"[schema.load_or_create_schema_postgresql] reflected: {len(reflected.tables)} tables")
        return reflected
    except Exception as reflect_err:
        debug(f"[schema.load_or_create_schema_postgresql] reflection_failed: {reflect_err}")
        ddl_path = getattr(EngineConfig.RUNTIME, "SQL_FILE_PATH", None)

        # Guard clause: without a DDL file on disk there is nothing to fall back to.
        if not ddl_path or not os.path.exists(ddl_path):
            raise ValueError(
                f"Database reflection failed and no SQL file available: {reflect_err}"
            ) from reflect_err

        debug(f"[schema.load_or_create_schema_postgresql] parsing_sql_file: {ddl_path}")
        parsed_tables = parse_sql_file(Path(ddl_path))

        # An empty parse result counts as a failed fallback.
        if not parsed_tables:
            raise ValueError("Both database reflection and SQL file parsing failed") from reflect_err

        return _tables_meta_to_schema_graph(parsed_tables)
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
def _add_profiling_data(
    engine: Any,
    sg: SchemaGraph,
    spark_session=None,
    connection=None,
) -> None:
    """Enrich a ``SchemaGraph`` in-place with profiling data.

    Three steps run in order:

    1. Column statistics are collected by the profiler matching the configured
       engine: ``profile_schema`` for non-Databricks engines,
       ``profile_schema_sql_connector`` when a Databricks ``connection`` is given,
       otherwise ``profile_schema_spark``.
    2. ``apply_column_roles_llm`` infers table and column roles via an LLM.
    3. ``assign_column_ops`` assigns operation flags deterministically.

    The earlier documentation states that role classification falls back to
    heuristics after ``MAX_ROLE_CLASSIFICATION_RETRIES`` LLM failures; that fallback
    is not implemented in this function (presumably it lives in
    ``apply_column_roles_llm`` -- verify there).

    Args:
        engine: SQLAlchemy ``Engine`` for PostgreSQL profiling; unused for Databricks.
        sg: ``SchemaGraph`` to enrich in-place.
        spark_session: Active ``SparkSession`` for Databricks Spark-based profiling.
            When omitted and no ``connection`` is given, one is obtained with
            ``SparkSession.builder.getOrCreate()``.
        connection: Active ``databricks.sql`` connection for connector-based
            profiling; takes precedence over ``spark_session``.
    """
    debug("[schema.add_profiling_data] Step 1: profiling columns (statistics)")

    if EngineConfig.TYPE != "databricks":
        profile_schema(engine, sg)
    else:
        catalog = EngineConfig.RUNTIME.CATALOG
        schema_name = EngineConfig.RUNTIME.SCHEMA
        if connection is None:
            session = spark_session
            if session is None:
                # Imported lazily so pyspark is only required on the Spark path.
                from pyspark.sql import SparkSession

                session = SparkSession.builder.getOrCreate()
            profile_schema_spark(session, catalog, schema_name, sg)
        else:
            profile_schema_sql_connector(connection, catalog, schema_name, sg)

    debug("[schema.add_profiling_data] Step 2: inferring column and table roles via LLM")
    apply_column_roles_llm(sg)

    debug("[schema.add_profiling_data] Step 3: assigning column operations (deterministic)")
    assign_column_ops(sg)
|
|
878
|
+
|
|
879
|
+
|
|
880
|
+
def _save_schema_to_cache(sg: SchemaGraph, schema_json_path: str) -> None:
    """Save a SchemaGraph to a JSON cache file.

    The output is a single JSON object whose top-level keys are written in sorted
    order. Dict-valued entries are pretty-printed (``indent=2``, sorted keys); every
    other entry is written compactly with no extra whitespace.

    The whole document is rendered in memory *before* the target file is opened.
    Opening with mode ``"w"`` truncates the file, so serializing first ensures that a
    serialization error (for example a non-JSON-serializable value) cannot leave an
    existing cache file truncated or half-written.

    Args:
        sg: ``SchemaGraph`` to persist.
        schema_json_path: Absolute path to the output JSON file.

    Raises:
        TypeError: Propagated from ``json.dumps`` if a value is not JSON-serializable.
        OSError: If the cache file cannot be opened or written.
    """

    def _render_mixed(data: dict[str, Any]) -> str:
        # Render each top-level entry separately so dict values can be indented
        # while scalar/list values stay compact.
        entries = []
        for key in sorted(data):
            value = data[key]
            if isinstance(value, dict):
                rendered = json.dumps(value, indent=2, sort_keys=True, ensure_ascii=False)
            else:
                rendered = json.dumps(value, separators=(",", ":"), ensure_ascii=False)
            entries.append(f"{json.dumps(key, ensure_ascii=False)}:{rendered}")
        return "{" + ",".join(entries) + "}"

    cache_data = {
        "tables": {k: _table_to_dict(v) for k, v in sg.tables.items()},
        "join_paths_multi": sg.join_paths_multi,
        "schema_hash": sg.schema_hash,
        "created_at": sg.created_at,
        "enum_values": sg.enum_values or {},
        "schema_stats": sg.schema_stats or {},
    }

    debug(f"[schema.save_schema_to_cache] saving to '{schema_json_path}'")
    # Serialize first; only touch the file once the payload is known to be valid.
    payload = _render_mixed(cache_data)
    with open(schema_json_path, "w", encoding="utf-8") as f:
        f.write(payload)
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
def _load_or_create_schema_databricks(
    spark_session=None,
    connection=None,
) -> SchemaGraph:
    """Build a ``SchemaGraph`` for Databricks from a SQL DDL file or catalog introspection.

    Parses ``EngineConfig.RUNTIME.SQL_FILE_PATH`` if that file exists and yields at
    least one table; otherwise extracts table metadata from the Databricks catalog
    via databricks-sql-connector (when ``connection`` is provided) or Spark.

    Args:
        spark_session: Active ``SparkSession`` for Spark-based catalog extraction.
            When omitted and no ``connection`` is given, one is obtained with
            ``SparkSession.builder.getOrCreate()``.
        connection: Active ``databricks.sql`` connection for connector-based
            extraction; takes precedence over ``spark_session``.

    Returns:
        ``SchemaGraph`` built from the SQL file or catalog.

    Raises:
        RuntimeError: If no usable SQL file is available and catalog extraction fails
            (including a failed ``pyspark`` import). The underlying error is chained
            as the cause.
    """
    sql_file_path = getattr(EngineConfig.RUNTIME, "SQL_FILE_PATH", None)
    tables_meta = None

    if sql_file_path and os.path.exists(sql_file_path):
        debug(f"[schema.load_or_create_schema_databricks] parsing SQL file '{sql_file_path}'")
        tables_meta = parse_sql_file(Path(sql_file_path))

        if not tables_meta:
            debug(
                "[schema.load_or_create_schema_databricks] SQL file parsing returned 0 tables, falling back to catalog extraction"
            )

    # The parse result itself decides whether to fall back. Re-checking the file on
    # disk here would be redundant and could discard a successful parse if the file
    # vanished in the meantime.
    if not tables_meta:
        debug("[schema.load_or_create_schema_databricks] SQL file not found or empty, attempting catalog extraction")
        catalog = EngineConfig.RUNTIME.CATALOG
        schema = EngineConfig.RUNTIME.SCHEMA
        try:
            if connection is not None:
                tables_meta = extract_tables_from_catalog_sql_connector(connection, catalog, schema)
            else:
                # Imported lazily so pyspark is only required on the Spark path.
                from pyspark.sql import SparkSession

                spark = spark_session if spark_session is not None else SparkSession.builder.getOrCreate()
                tables_meta = extract_tables_from_catalog(spark, catalog, schema)
        except Exception as e:
            raise RuntimeError(f"Cannot build schema: no SQL file and catalog extraction failed: {e}") from e

    return _tables_meta_to_schema_graph(tables_meta)
|