aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
text2sql/schema.py ADDED
@@ -0,0 +1,973 @@
1
+ """Schema reflection, graph building, and profiling orchestration.
2
+
3
+ Reflects table and column metadata from a live database or SQL file, infers missing foreign keys by naming convention, and builds a shortest-path join graph between all table pairs. A local JSON cache avoids repeated reflection; the cache is invalidated by a content hash of the schema structure. Profiling data (statistics, value domains, LLM-assigned roles, and allowed operations) is layered on top of the reflected schema before caching. Adaptive limits derived from schema statistics are also stored.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import os
10
+ from dataclasses import asdict
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ from .config import EngineConfig, PolicyConfig
16
+ from .contracts_base import (
17
+ ColumnMetadata,
18
+ FKEdge,
19
+ SchemaGraph,
20
+ SchemaLimits,
21
+ TableMetadata,
22
+ )
23
+ from .core_utils import debug, schema_hash_fp
24
+ from .dialect import get_dialect
25
+ from .schema_profiling import (
26
+ apply_column_roles_llm,
27
+ assign_column_ops,
28
+ extract_tables_from_catalog,
29
+ extract_tables_from_catalog_sql_connector,
30
+ parse_sql_file,
31
+ profile_schema,
32
+ profile_schema_spark,
33
+ profile_schema_sql_connector,
34
+ )
35
+
36
+
37
def compute_schema_stats(schema: SchemaGraph) -> dict[str, Any]:
    """Summarise how many filterable, groupable and aggregatable columns the schema has.

    Args:
        schema: ``SchemaGraph`` whose column metadata exposes the ``is_filterable``,
            ``is_groupable`` and ``is_aggregatable`` flags.

    Returns:
        Dictionary with the keys ``total_filterable``, ``total_groupable``,
        ``total_aggregatable``, ``min_filterable_per_table``,
        ``max_filterable_per_table``, ``min_groupable_per_table``,
        ``max_groupable_per_table``, ``table_count`` and ``filterable_per_table``
        (the filterable count of every table that has at least one such column).
    """
    debug("[schema.compute_schema_stats] computing schema statistics for adaptive limits")

    def _flagged(table, flag: str) -> int:
        # Number of columns in ``table`` whose boolean ``flag`` attribute is set.
        return sum(1 for column in table.columns.values() if getattr(column, flag))

    per_table = [
        {
            "table": name,
            "filterable": _flagged(table, "is_filterable"),
            "groupable": _flagged(table, "is_groupable"),
            "aggregatable": _flagged(table, "is_aggregatable"),
        }
        for name, table in schema.tables.items()
    ]

    # Tables without any filterable/groupable column do not take part in the
    # min/max figures; an empty selection yields 0 for both.
    filterable_hits = [row["filterable"] for row in per_table if row["filterable"] > 0]
    groupable_hits = [row["groupable"] for row in per_table if row["groupable"] > 0]

    stats = {
        "total_filterable": sum(row["filterable"] for row in per_table),
        "total_groupable": sum(row["groupable"] for row in per_table),
        "total_aggregatable": sum(row["aggregatable"] for row in per_table),
        "min_filterable_per_table": min(filterable_hits, default=0),
        "max_filterable_per_table": max(filterable_hits, default=0),
        "min_groupable_per_table": min(groupable_hits, default=0),
        "max_groupable_per_table": max(groupable_hits, default=0),
        "table_count": len(schema.tables),
        "filterable_per_table": filterable_hits,
    }

    debug("[schema.compute_schema_stats] per-table column counts:")
    for row in per_table:
        debug(
            f" {row['table']}: filterable={row['filterable']}, groupable={row['groupable']}, aggregatable={row['aggregatable']}"
        )

    debug("[schema.compute_schema_stats] schema-wide statistics:")
    debug(f" table_count: {stats['table_count']}")
    debug(f" total_filterable: {stats['total_filterable']}")
    debug(f" total_groupable: {stats['total_groupable']}")
    debug(f" total_aggregatable: {stats['total_aggregatable']}")
    debug(f" min_filterable_per_table: {stats['min_filterable_per_table']}")
    debug(f" max_filterable_per_table: {stats['max_filterable_per_table']}")
    debug(f" min_groupable_per_table: {stats['min_groupable_per_table']}")
    debug(f" max_groupable_per_table: {stats['max_groupable_per_table']}")
    debug(f" filterable_per_table distribution: {stats['filterable_per_table']}")

    return stats
113
+
114
+
115
def compute_schema_limits(schema_stats: dict[str, Any]) -> SchemaLimits:
    """Derive adaptive pipeline limits from schema statistics.

    Args:
        schema_stats: Statistics dictionary in the shape produced by
            ``compute_schema_stats``. Missing keys fall back to ``table_count=1``
            and zero totals.

    Returns:
        ``SchemaLimits`` carrying ``max_filters``, ``max_groupby`` and ``max_tables``.
    """
    table_count = schema_stats.get("table_count", 1)

    if table_count > 0:
        # Average columns per table, but never below one.
        max_filters = max(1, schema_stats.get("total_filterable", 0) // table_count)
        max_groupby = max(1, schema_stats.get("total_groupable", 0) // table_count)
    else:
        max_filters = 1
        max_groupby = 1

    # Small schemas may join every table; larger ones are capped at 3 or 4.
    max_tables = table_count if table_count <= 3 else (3 if table_count <= 10 else 4)

    return SchemaLimits(max_filters=max_filters, max_groupby=max_groupby, max_tables=max_tables)
145
+
146
+
147
def _edge_key(e: FKEdge) -> tuple[str, tuple[str, ...], str, tuple[str, ...]]:
    """Build a hashable, sortable key that identifies an FK edge.

    Args:
        e: ``FKEdge`` to key.

    Returns:
        ``(src_table, src_cols, dst_table, dst_cols)`` with both column lists
        converted to tuples.
    """
    source_columns = tuple(e.src_cols)
    target_columns = tuple(e.dst_cols)
    return e.src_table, source_columns, e.dst_table, target_columns
159
+
160
+
161
def _reflect_enum_values(engine: Any, schema_name: str) -> dict[str, list[str]]:
    """Ask the active dialect for the enum types defined in a database schema.

    Args:
        engine: SQLAlchemy ``Engine`` connected to the target database.
        schema_name: Database schema name to introspect.

    Returns:
        Whatever ``dialect.reflect_enums`` returns (enum type name mapped to its
        values), or an empty dict when reflection raises.
    """
    # Resolving the dialect is deliberately outside the try: only the
    # reflection itself is treated as best-effort.
    active_dialect = get_dialect()
    try:
        reflected = active_dialect.reflect_enums(engine, schema_name)
        debug(f"[schema.reflect_enum_values] found {len(reflected)} enum types")
    except Exception as exc:
        debug(f"[schema.reflect_enum_values] failed to reflect enums: {exc}")
        return {}
    return reflected
182
+
183
+
184
def _table_to_dict(table: TableMetadata) -> dict[str, Any]:
    """Convert a ``TableMetadata`` into a plain, JSON-friendly dictionary.

    Args:
        table: ``TableMetadata`` to serialize.

    Returns:
        Dictionary with ``name``, ``columns``, ``primary_key``, ``foreign_keys``,
        ``partition_columns``, ``role``, ``row_count`` and ``description``.
        Columns and foreign keys are expanded with ``dataclasses.asdict``.
    """
    serialized_columns = {}
    for column_name, column_meta in table.columns.items():
        serialized_columns[column_name] = asdict(column_meta)

    serialized_fks = list(map(asdict, table.foreign_keys))

    return dict(
        name=table.name,
        columns=serialized_columns,
        primary_key=table.primary_key,
        foreign_keys=serialized_fks,
        partition_columns=table.partition_columns,
        role=table.role,
        row_count=table.row_count,
        description=table.description,
    )
205
+
206
+
207
def _table_from_dict(d: dict[str, Any]) -> TableMetadata:
    """Rebuild a ``TableMetadata`` from its dictionary form.

    Args:
        d: Dictionary with the required keys ``name``, ``columns``,
            ``primary_key`` and ``foreign_keys``; ``partition_columns``,
            ``role``, ``row_count`` and ``description`` are optional.

    Returns:
        ``TableMetadata`` whose columns are ``ColumnMetadata`` objects and whose
        foreign keys are ``FKEdge`` objects.
    """
    column_objects = {}
    for column_name, column_fields in d["columns"].items():
        column_objects[column_name] = ColumnMetadata.from_dict(column_fields)

    # Optional fields fall back to neutral defaults when absent from the dict.
    optional_fields = {
        "partition_columns": d.get("partition_columns", []),
        "role": d.get("role"),
        "row_count": d.get("row_count", 0),
        "description": d.get("description", ""),
    }

    return TableMetadata(
        name=d["name"],
        columns=column_objects,
        primary_key=d["primary_key"],
        foreign_keys=[FKEdge(**edge_fields) for edge_fields in d["foreign_keys"]],
        **optional_fields,
    )
229
+
230
+
231
def _reverse_fk_path(path: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return the same join path walked in the opposite direction.

    Args:
        path: List of FK edge dicts with ``src_table``, ``src_cols``,
            ``dst_table`` and ``dst_cols``.

    Returns:
        New list in which the edge order is reversed and every edge has its
        source and destination swapped. The input list is not modified.
    """
    return [
        {
            "src_table": edge["dst_table"],
            "src_cols": edge["dst_cols"],
            "dst_table": edge["src_table"],
            "dst_cols": edge["src_cols"],
        }
        for edge in path[::-1]
    ]
253
+
254
+
255
def _analyze_fk_path_topology(path: list[dict[str, Any]]) -> tuple[str, str, list[str]]:
    """Classify the shape of a join path from how often each table appears in it.

    A table that occurs once across all edge endpoints is a leaf; a table that
    occurs more than once is a hub.

    Args:
        path: List of FK edge dicts representing a join path.

    Returns:
        Tuple ``(topology_type, anchor_table, leaf_tables)`` where
        ``topology_type`` is ``'none'``, ``'linear'``, ``'star'`` or ``'tree'``.
        For linear paths the anchor is the lexicographically smallest leaf; for
        star/tree paths it is the hub with the most occurrences (ties broken by
        name).
    """
    if not path:
        return ("none", "", [])

    occurrences: dict[str, int] = {}
    for edge in path:
        for endpoint in (edge["src_table"], edge["dst_table"]):
            occurrences[endpoint] = occurrences.get(endpoint, 0) + 1
    if not occurrences:
        return ("none", "", [])

    leaves = sorted(name for name, seen in occurrences.items() if seen == 1)
    hubs = sorted(
        (name for name, seen in occurrences.items() if seen > 1),
        key=lambda name: (-occurrences[name], name),
    )

    if len(leaves) == 2 and len(hubs) == len(occurrences) - 2:
        return ("linear", min(leaves), leaves)
    if hubs:
        shape = "star" if len(hubs) == 1 else "tree"
        return (shape, hubs[0], leaves)
    # No hubs and not exactly two leaves (e.g. disconnected edges): report the
    # tables in first-seen order and anchor on the smallest name.
    return ("linear", min(occurrences), list(occurrences))
288
+
289
+
290
def _normalize_fk_path(path: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Normalize an FK join path to a canonical form based on its topology.

    For linear paths the result starts at the anchor reported by
    ``_analyze_fk_path_topology`` (the lexicographically smallest leaf); the
    path is returned unchanged when it already starts there and reversed
    otherwise. For star or tree paths, edges touching the anchor (hub) are
    oriented to flow outward from it and grouped into one branch per
    neighbouring table; remaining edges are appended to the first branch (in
    sorted key order) that already contains one of their tables.

    Edges are never discarded: an edge that cannot be attached when first seen
    is retried once more branches exist, and any edge that still cannot be
    attached is appended after all branches in its original order.

    Args:
        path: List of FK edge dicts to normalize.

    Returns:
        Reordered and/or flipped list of edge dicts in canonical form. Empty
        input is returned as-is.
    """
    if not path:
        return path
    topology_type, anchor, _ = _analyze_fk_path_topology(path)
    if topology_type == "none":
        return path
    if topology_type == "linear":
        if path[0]["src_table"] == anchor:
            return path
        return _reverse_fk_path(path)

    # Branch key = the table adjacent to the anchor that roots the branch.
    branches: dict[str, list[dict[str, Any]]] = {}

    def _attach(edge: dict[str, Any]) -> bool:
        # Append ``edge`` to the first branch (sorted by key) sharing a table with it.
        for branch_key in sorted(branches):
            branch_tables = set()
            for member in branches[branch_key]:
                branch_tables.add(member["src_table"])
                branch_tables.add(member["dst_table"])
            if edge["src_table"] in branch_tables or edge["dst_table"] in branch_tables:
                branches[branch_key].append(edge)
                return True
        return False

    pending: list[dict[str, Any]] = []
    for e in path:
        src = e["src_table"]
        dst = e["dst_table"]
        if src == anchor:
            branches.setdefault(dst, []).append(e)
        elif dst == anchor:
            # Flip so the edge points away from the anchor.
            branches.setdefault(src, []).append(
                {
                    "src_table": dst,
                    "src_cols": e["dst_cols"],
                    "dst_table": src,
                    "dst_cols": e["src_cols"],
                }
            )
        elif not _attach(e):
            # The branch this edge belongs to may only appear later in ``path``;
            # keep it for a retry instead of silently dropping it.
            pending.append(e)

    # Retry deferred edges until a full pass attaches nothing new.
    while pending:
        unresolved = [e for e in pending if not _attach(e)]
        if len(unresolved) == len(pending):
            break
        pending = unresolved

    normalized: list[dict[str, Any]] = []
    for branch_key in sorted(branches):
        normalized.extend(branches[branch_key])
    # Edges not connected to any branch are kept at the end rather than lost.
    normalized.extend(pending)
    return normalized if normalized else path
340
+
341
+
342
def _infer_missing_fks(tables: dict[str, TableMetadata]) -> list[FKEdge]:
    """Guess undeclared foreign keys from column and table names.

    A column that is neither a primary key nor already a foreign key is a
    candidate when its name ends in ``'_id'``, ``'_key'``, ``'id'`` or ``'key'``
    (checked in that order). The remaining stem, and its underscore-separated
    tails, are looked up as table names. A candidate table is accepted only if
    the first column of its primary key is that same stem followed by one of
    the suffixes.

    Args:
        tables: Mapping of table name to ``TableMetadata``.

    Returns:
        List of single-column ``FKEdge`` instances, at most one per source column.
    """
    key_suffixes = ("_id", "_key", "id", "key")
    guessed: list[FKEdge] = []

    for owner_name, owner in tables.items():
        for column_name, column in owner.columns.items():
            if column.is_foreign_key or column.is_primary_key:
                continue

            suffix_hit = next((s for s in key_suffixes if column_name.endswith(s)), None)
            if suffix_hit is None:
                continue
            stem = column_name[: -len(suffix_hit)]
            if not stem:
                continue

            # Try the full stem first, then progressively longer tails of it.
            stems = [stem]
            if "_" in stem:
                pieces = stem.split("_")
                stems.append(pieces[-1])
                stems.extend("_".join(pieces[start:]) for start in range(len(pieces) - 1, 0, -1))

            for candidate in stems:
                if candidate not in tables or candidate == owner_name:
                    continue
                referenced = tables[candidate]
                if not referenced.primary_key:
                    continue
                pk_column = referenced.primary_key[0]
                pk_follows_convention = any(
                    pk_column.endswith(s) and pk_column[: -len(s)] == candidate for s in key_suffixes
                )
                if not pk_follows_convention:
                    continue
                debug(f"[schema.infer_missing_fks] candidate: {owner_name}.{column_name} -> {candidate}.{pk_column}")
                guessed.append(
                    FKEdge(
                        src_table=owner_name,
                        src_cols=[column_name],
                        dst_table=candidate,
                        dst_cols=[pk_column],
                    )
                )
                # First accepted target wins for this column.
                break

    return guessed
407
+
408
+
409
def _reflect_schema(engine: Any, schema_name: str | None = None) -> SchemaGraph:
    """Reflect a database schema using SQLAlchemy and build a join-path graph.

    Reflects all tables and columns, records explicit FK constraints, adds FKs
    inferred by naming convention, reflects enum types, and runs a breadth-first
    search from every table to record one shortest join path to every other
    reachable table.

    Args:
        engine: SQLAlchemy ``Engine`` connected to the target database.
        schema_name: Database schema to reflect. When ``None``, uses
            ``EngineConfig.RUNTIME.SCHEMA`` if that attribute exists, else
            ``'public'``.

    Returns:
        ``SchemaGraph`` with ``tables``, ``join_paths_multi``, ``schema_hash``,
        ``created_at`` (ISO timestamp) and ``enum_values``. In
        ``join_paths_multi[s][t]`` a table maps to itself as ``[[]]`` and to an
        unreachable table as ``[]``.
    """
    if schema_name is None:
        schema_name = EngineConfig.RUNTIME.SCHEMA if hasattr(EngineConfig.RUNTIME, "SCHEMA") else "public"

    from collections import deque

    from sqlalchemy import MetaData

    debug(f"[schema.reflect_schema] reflecting from database schema '{schema_name}'")
    md = MetaData(schema=schema_name)
    md.reflect(bind=engine)

    tables: dict[str, TableMetadata] = {}

    for t in md.tables.values():
        # Collect the primary-key names once per table instead of once per column.
        pk_names = [pk.name for pk in t.primary_key.columns]
        pk_lookup = set(pk_names)

        columns: dict[str, ColumnMetadata] = {}
        for c in t.columns:
            columns[c.name] = ColumnMetadata(
                name=c.name,
                data_type=str(c.type),
                is_primary_key=c.name in pk_lookup,
                is_foreign_key=False,
                fk_target=None,
            )

        tables[t.name] = TableMetadata(
            name=t.name,
            columns=columns,
            primary_key=pk_names,
            foreign_keys=[],
        )

    debug(f"[schema.reflect_schema] found {len(tables)} tables")

    for t in md.tables.values():
        for fk in t.foreign_key_constraints:
            e = FKEdge(
                src_table=t.name,
                src_cols=[el.parent.name for el in fk.elements],
                dst_table=fk.elements[0].column.table.name,
                dst_cols=[el.column.name for el in fk.elements],
            )
            tables[t.name].foreign_keys.append(e)

            # Mark each participating source column with its own target column.
            for i, src_col in enumerate(e.src_cols):
                if src_col in tables[t.name].columns:
                    tables[t.name].columns[src_col].is_foreign_key = True
                    tables[t.name].columns[src_col].fk_target = (
                        e.dst_table,
                        e.dst_cols[i],
                    )

            debug(
                f"[schema.reflect_schema] explicit FK: {e.src_table}.{e.src_cols[0]} -> {e.dst_table}.{e.dst_cols[0]}"
            )

    fk_count = sum(len(tbl.foreign_keys) for tbl in tables.values())
    debug(f"[schema.reflect_schema] found {fk_count} foreign key edges")

    inferred_fks = _infer_missing_fks(tables)
    if inferred_fks:
        debug(f"[schema.reflect_schema] inferred {len(inferred_fks)} missing FKs from naming conventions")
        for e in inferred_fks:
            tables[e.src_table].foreign_keys.append(e)
            for i, src_col in enumerate(e.src_cols):
                if src_col in tables[e.src_table].columns:
                    tables[e.src_table].columns[src_col].is_foreign_key = True
                    tables[e.src_table].columns[src_col].fk_target = (
                        e.dst_table,
                        e.dst_cols[i],
                    )
            debug(f"[schema.reflect_schema] {e.src_table}.{e.src_cols[0]} -> {e.dst_table}.{e.dst_cols[0]}")

    enum_values = _reflect_enum_values(engine, schema_name)

    # Undirected adjacency: every FK is traversable in both directions. Edges
    # whose endpoint was not reflected (not a key of ``tables``) are skipped.
    adj: dict[str, list[FKEdge]] = {t: [] for t in tables}
    for tbl in tables.values():
        for e in tbl.foreign_keys:
            if e.src_table not in tables or e.dst_table not in tables:
                continue
            adj[e.src_table].append(e)
            adj[e.dst_table].append(
                FKEdge(
                    src_table=e.dst_table,
                    src_cols=e.dst_cols,
                    dst_table=e.src_table,
                    dst_cols=e.src_cols,
                )
            )
    # Sorted adjacency makes the BFS, and therefore the chosen paths, deterministic.
    for t in adj:
        adj[t] = sorted(adj[t], key=_edge_key)

    debug("[schema.reflect_schema] computing shortest join paths")
    join_paths_multi: dict[str, dict[str, list[list[dict[str, Any]]]]] = {}

    tlist = sorted(tables.keys())
    for s in tlist:
        join_paths_multi[s] = {}
        queue = deque([s])
        prev: dict[str, tuple[str, FKEdge]] = {}
        seen = {s}
        while queue:
            cur = queue.popleft()
            for e in adj[cur]:
                nxt = e.dst_table
                if nxt in seen:
                    continue
                seen.add(nxt)
                prev[nxt] = (cur, e)
                queue.append(nxt)

        for t in tlist:
            if t == s:
                join_paths_multi[s][t] = [[]]
                continue
            if t in prev:
                # Walk predecessors back from the target, then flip into s -> t order.
                path = []
                cur = t
                while cur != s:
                    p, e = prev[cur]
                    path.append(e)
                    cur = p
                path.reverse()
                sp = [asdict(e) for e in path]
                join_paths_multi[s][t] = [_normalize_fk_path(sp)]
            else:
                join_paths_multi[s][t] = []

    schema_hash = schema_hash_fp({k: _table_to_dict(v) for k, v in tables.items()})

    sg = SchemaGraph(
        tables=tables,
        join_paths_multi=join_paths_multi,
        schema_hash=schema_hash,
        created_at=datetime.now().isoformat(),
        enum_values=enum_values,
    )

    return sg
560
+
561
+
562
def _load_schema_graph_from_cache(schema_json_path: str) -> SchemaGraph | None:
    """Load the cached schema graph, or return ``None`` when it must be rebuilt.

    The cache is rejected when the file is not valid JSON or when the hash
    recomputed from its ``tables`` section differs from the stored
    ``schema_hash``; in the latter case the file is deleted.
    """
    debug(f"[schema.load_or_create_schema_graph] loading from cache '{schema_json_path}'")
    try:
        with open(schema_json_path, encoding="utf-8") as f:
            d = json.load(f)
    except json.JSONDecodeError as exc:
        # An unreadable cache (e.g. a truncated write) is rebuilt instead of
        # crashing; the rebuild overwrites the file.
        debug(f"[schema.load_or_create_schema_graph] cache is not valid JSON, rebuilding: {exc}")
        return None

    recomputed_hash = schema_hash_fp(d["tables"])
    cached_hash = d.get("schema_hash", "")
    if recomputed_hash != cached_hash:
        print("Schema cache outdated, rebuilding...")
        debug(
            f"[schema.load_or_create_schema_graph] hash mismatch: cached={cached_hash[:16]} recomputed={recomputed_hash[:16]}"
        )
        os.remove(schema_json_path)
        return None

    # Deserialize only once the cache is known to be current.
    tables = {k: _table_from_dict(v) for k, v in d["tables"].items()}

    join_paths_multi = d.get("join_paths_multi")
    if not join_paths_multi:
        # Older cache layout: a single path (or null) per table pair under "join_paths".
        join_paths_multi = {}
        jp = d.get("join_paths", {})
        for a in jp:
            join_paths_multi[a] = {}
            for b in jp[a]:
                join_paths_multi[a][b] = [jp[a][b]] if jp[a][b] is not None else []

    debug(f"[schema.load_or_create_schema_graph] loaded {len(tables)} tables from cache")

    graph_fields = {
        "tables": tables,
        "join_paths_multi": join_paths_multi,
        "schema_hash": cached_hash,
        "enum_values": d.get("enum_values", {}),
    }
    cached_stats = d.get("schema_stats")
    if cached_stats is not None:
        return SchemaGraph(**graph_fields, schema_stats=cached_stats)

    # Cache without statistics: derive them from the cached column flags
    # rather than failing on the missing key.
    sg = SchemaGraph(**graph_fields)
    sg.schema_stats = compute_schema_stats(sg)
    return sg


def load_or_create_schema_graph(engine: Any = None) -> SchemaGraph:
    """Load the schema graph from the JSON cache or build it fresh.

    The cache at ``EngineConfig.SCHEMA_JSON_PATH`` is used when
    ``PolicyConfig.REGENERATE_SCHEMA_GRAPH`` is false, the file exists, it
    parses as JSON, and its stored hash matches the hash recomputed from its
    tables. Otherwise the schema is built for the configured engine type,
    profiled, given schema statistics, and written back to the cache.

    Args:
        engine: SQLAlchemy ``Engine`` required for PostgreSQL reflection. Not
            used to build the Databricks schema.

    Returns:
        Fully populated and profiled ``SchemaGraph``.

    Raises:
        ValueError: If ``engine`` is ``None`` for PostgreSQL or the engine type
            is unsupported.
        RuntimeError: If the databricks-sql-connector connection cannot be
            established.
    """
    schema_json_path = EngineConfig.SCHEMA_JSON_PATH

    debug(f"[schema.load_or_create_schema_graph] engine_type={EngineConfig.TYPE}")

    if not PolicyConfig.REGENERATE_SCHEMA_GRAPH and os.path.exists(schema_json_path):
        cached_graph = _load_schema_graph_from_cache(schema_json_path)
        if cached_graph is not None:
            return cached_graph

    debug("[schema.load_or_create_schema_graph] no cache found, building schema")

    spark_session = None
    databricks_connection = None
    try:
        if EngineConfig.TYPE == "databricks":
            if getattr(EngineConfig.RUNTIME, "has_native_connection", lambda: False)():
                try:
                    from databricks import sql as dbsql

                    databricks_connection = dbsql.connect(
                        server_hostname=EngineConfig.RUNTIME.SERVER_HOSTNAME,
                        http_path=EngineConfig.RUNTIME.HTTP_PATH,
                        access_token=EngineConfig.RUNTIME.ACCESS_TOKEN,
                    )
                except Exception as e:
                    raise RuntimeError(f"Failed to connect via databricks-sql-connector: {e}") from e
            else:
                from pyspark.sql import SparkSession

                spark_session = SparkSession.builder.getOrCreate()
            sg = _load_or_create_schema_databricks(
                spark_session=spark_session,
                connection=databricks_connection,
            )
        elif EngineConfig.TYPE == "postgresql":
            if engine is None:
                raise ValueError("Engine is required for PostgreSQL schema reflection")
            sg = _load_or_create_schema_postgresql(engine)
        else:
            raise ValueError(f"Unsupported engine type: {EngineConfig.TYPE}")

        # Profiling runs inside the try so the connector connection is still open.
        _add_profiling_data(engine, sg, spark_session, databricks_connection)
    finally:
        if databricks_connection is not None:
            try:
                databricks_connection.close()
            except Exception as close_error:
                # Closing is best-effort, but leave a trace instead of hiding the failure.
                debug(f"[schema.load_or_create_schema_graph] failed to close databricks connection: {close_error}")

    debug("[schema.load_or_create_schema_graph] computing schema stats after profiling")
    sg.schema_stats = compute_schema_stats(sg)

    _save_schema_to_cache(sg, schema_json_path)
    debug("[schema.load_or_create_schema_graph] cache saved with profiling data")

    return sg
667
+
668
+
669
def _tables_meta_to_schema_graph(tables_meta: dict[str, dict]) -> SchemaGraph:
    """Convert a raw table metadata dictionary to a ``SchemaGraph`` with join paths.

    Builds ``ColumnMetadata`` and ``FKEdge`` objects, then runs a breadth-first
    search from every table to record one shortest join path to every other
    reachable table.

    Args:
        tables_meta: Mapping of table name to a metadata dict. The keys read
            are ``column_names_original``, ``column_types``, ``primary_keys``,
            ``foreign_keys`` (dicts with ``src_cols``, ``dst_table``,
            ``dst_cols``) and ``partition_columns``; all are optional.

    Returns:
        ``SchemaGraph`` with ``tables``, ``join_paths_multi``, ``schema_hash``,
        ``created_at`` (ISO timestamp) and an empty ``enum_values``.
    """
    from collections import deque

    tables: dict[str, TableMetadata] = {}

    for table_name, meta in tables_meta.items():
        columns: dict[str, ColumnMetadata] = {}
        col_names = meta.get("column_names_original", [])
        col_types = meta.get("column_types", [])
        pk_cols = meta.get("primary_keys", [])

        for i, col_name in enumerate(col_names):
            # A column without a matching type entry is recorded as "UNKNOWN".
            col_type = col_types[i] if i < len(col_types) else "UNKNOWN"
            columns[col_name] = ColumnMetadata(
                name=col_name,
                data_type=col_type,
                is_primary_key=col_name in pk_cols,
                is_foreign_key=False,
                fk_target=None,
            )

        fk_edges = []
        for fk in meta.get("foreign_keys", []):
            edge = FKEdge(
                src_table=table_name,
                src_cols=fk["src_cols"],
                dst_table=fk["dst_table"],
                dst_cols=fk["dst_cols"],
            )
            fk_edges.append(edge)

            # Mark each participating source column with its own target column.
            for i, src_col in enumerate(fk["src_cols"]):
                if src_col in columns:
                    columns[src_col].is_foreign_key = True
                    columns[src_col].fk_target = (fk["dst_table"], fk["dst_cols"][i])

        partition_cols = meta.get("partition_columns", [])
        tables[table_name] = TableMetadata(
            name=table_name,
            columns=columns,
            primary_key=pk_cols,
            foreign_keys=fk_edges,
            partition_columns=partition_cols,
        )

    fk_count = sum(len(tbl.foreign_keys) for tbl in tables.values())
    debug(f"[schema.tables_meta_to_schema_graph] {len(tables)} tables, {fk_count} FK edges")

    # Undirected adjacency: every FK is traversable in both directions. Edges
    # pointing at a table missing from ``tables`` are skipped.
    adj: dict[str, list[FKEdge]] = {t: [] for t in tables}
    for tbl in tables.values():
        for e in tbl.foreign_keys:
            if e.src_table not in tables or e.dst_table not in tables:
                continue
            adj[e.src_table].append(e)
            adj[e.dst_table].append(
                FKEdge(
                    src_table=e.dst_table,
                    src_cols=e.dst_cols,
                    dst_table=e.src_table,
                    dst_cols=e.src_cols,
                )
            )
    # Sorted adjacency makes the BFS, and therefore the chosen paths, deterministic.
    for t in adj:
        adj[t] = sorted(adj[t], key=_edge_key)

    debug("[schema.tables_meta_to_schema_graph] computing shortest join paths")

    join_paths_multi: dict[str, dict[str, list[list[dict[str, Any]]]]] = {}
    tlist = sorted(tables.keys())

    for s in tlist:
        join_paths_multi[s] = {}
        queue = deque([s])
        prev: dict[str, tuple[str, FKEdge]] = {}
        seen = {s}
        while queue:
            cur = queue.popleft()
            for e in adj[cur]:
                nxt = e.dst_table
                if nxt in seen:
                    continue
                seen.add(nxt)
                prev[nxt] = (cur, e)
                queue.append(nxt)

        for t in tlist:
            if t == s:
                join_paths_multi[s][t] = [[]]
                continue
            if t in prev:
                # Walk predecessors back from the target, then flip into s -> t order.
                path = []
                cur = t
                while cur != s:
                    p, e = prev[cur]
                    path.append(e)
                    cur = p
                path.reverse()
                sp = [asdict(e) for e in path]
                join_paths_multi[s][t] = [_normalize_fk_path(sp)]
            else:
                join_paths_multi[s][t] = []

    schema_hash = schema_hash_fp({k: _table_to_dict(v) for k, v in tables.items()})

    sg = SchemaGraph(
        tables=tables,
        join_paths_multi=join_paths_multi,
        schema_hash=schema_hash,
        created_at=datetime.now().isoformat(),
        enum_values={},
    )

    return sg
792
+
793
+
794
def _load_or_create_schema_postgresql(engine: Any) -> SchemaGraph:
    """Build the PostgreSQL ``SchemaGraph``, preferring live reflection over a DDL file.

    Live reflection is attempted first. If it raises, the SQL file named by
    ``EngineConfig.RUNTIME.SQL_FILE_PATH`` is parsed instead, provided that
    attribute is set and the file exists.

    Args:
        engine: SQLAlchemy ``Engine`` connected to the PostgreSQL database.

    Returns:
        ``SchemaGraph`` built from the database or from the SQL file.

    Raises:
        ValueError: If reflection fails and either no SQL file is available or
            parsing it yields no tables. The reflection error is chained as
            the cause.
    """
    try:
        debug("[schema.load_or_create_schema_postgresql] reflecting_database")
        reflected = _reflect_schema(engine)
        debug(f"[schema.load_or_create_schema_postgresql] reflected: {len(reflected.tables)} tables")
        return reflected
    except Exception as e:
        debug(f"[schema.load_or_create_schema_postgresql] reflection_failed: {e}")
        ddl_path = getattr(EngineConfig.RUNTIME, "SQL_FILE_PATH", None)

        if not (ddl_path and os.path.exists(ddl_path)):
            raise ValueError(f"Database reflection failed and no SQL file available: {e}") from e

        debug(f"[schema.load_or_create_schema_postgresql] parsing_sql_file: {ddl_path}")
        parsed_tables = parse_sql_file(Path(ddl_path))
        if not parsed_tables:
            raise ValueError("Both database reflection and SQL file parsing failed") from e

        return _tables_meta_to_schema_graph(parsed_tables)
830
+
831
+
832
def _add_profiling_data(
    engine: Any,
    sg: SchemaGraph,
    spark_session=None,
    connection=None,
) -> None:
    """Enrich a ``SchemaGraph`` in place with profiling results.

    Three steps run in order: (1) column statistics are collected with the
    profiler matching the engine type, (2) ``apply_column_roles_llm`` assigns
    table and column roles, and (3) ``assign_column_ops`` sets the operation
    flags.

    Args:
        engine: SQLAlchemy ``Engine`` handed to ``profile_schema`` for
            non-Databricks engines.
        sg: ``SchemaGraph`` to enrich in place.
        spark_session: ``SparkSession`` for Spark-based Databricks profiling;
            one is obtained via ``SparkSession.builder.getOrCreate()`` when
            omitted and no ``connection`` is given.
        connection: ``databricks.sql`` connection; when provided it takes
            precedence over Spark for Databricks profiling.
    """
    debug("[schema.add_profiling_data] Step 1: profiling columns (statistics)")

    if EngineConfig.TYPE != "databricks":
        profile_schema(engine, sg)
    else:
        runtime = EngineConfig.RUNTIME
        catalog, schema_name = runtime.CATALOG, runtime.SCHEMA
        if connection is None:
            if spark_session is None:
                from pyspark.sql import SparkSession

                spark_session = SparkSession.builder.getOrCreate()
            profile_schema_spark(spark_session, catalog, schema_name, sg)
        else:
            profile_schema_sql_connector(connection, catalog, schema_name, sg)

    debug("[schema.add_profiling_data] Step 2: inferring column and table roles via LLM")
    apply_column_roles_llm(sg)

    debug("[schema.add_profiling_data] Step 3: assigning column operations (deterministic)")
    assign_column_ops(sg)
878
+
879
+
880
def _save_schema_to_cache(sg: SchemaGraph, schema_json_path: str) -> None:
    """Write a ``SchemaGraph`` to the JSON cache file.

    The top-level object is written by hand with its keys in sorted order:
    dict-valued sections are pretty-printed (two-space indent, sorted keys)
    and every other value is written in compact form.

    Args:
        sg: ``SchemaGraph`` to persist.
        schema_json_path: Path of the JSON file to (over)write.
    """

    def _dump_mixed(fh, data: dict[str, Any]):
        # Emit ``data`` as one JSON object, choosing the layout per value type.
        fh.write("{")
        for position, key in enumerate(sorted(data)):
            if position:
                fh.write(",")
            fh.write(json.dumps(key, ensure_ascii=False))
            fh.write(":")
            value = data[key]
            if isinstance(value, dict):
                layout = {"indent": 2, "sort_keys": True}
            else:
                layout = {"separators": (",", ":")}
            json.dump(value, fh, ensure_ascii=False, **layout)
        fh.write("}")

    serialized_tables = {}
    for table_name, table in sg.tables.items():
        serialized_tables[table_name] = _table_to_dict(table)

    cache_data = dict(
        tables=serialized_tables,
        join_paths_multi=sg.join_paths_multi,
        schema_hash=sg.schema_hash,
        created_at=sg.created_at,
        enum_values=sg.enum_values or {},
        schema_stats=sg.schema_stats or {},
    )

    debug(f"[schema.save_schema_to_cache] saving to '{schema_json_path}'")
    with open(schema_json_path, "w", encoding="utf-8") as f:
        _dump_mixed(f, cache_data)
918
+
919
+
920
def _load_or_create_schema_databricks(
    spark_session=None,
    connection=None,
) -> SchemaGraph:
    """Build the Databricks ``SchemaGraph`` from a SQL DDL file or from the catalog.

    The file at ``EngineConfig.RUNTIME.SQL_FILE_PATH`` is parsed when it is
    configured and exists. If there is no such file, or parsing yields no
    tables, table metadata is extracted from the catalog instead: through the
    ``databricks.sql`` connection when one is given, otherwise through Spark.

    Args:
        spark_session: ``SparkSession`` for Spark-based catalog extraction;
            obtained via ``SparkSession.builder.getOrCreate()`` when omitted.
        connection: ``databricks.sql`` connection for connector-based extraction.

    Returns:
        ``SchemaGraph`` built from the SQL file or the catalog.

    Raises:
        RuntimeError: If catalog extraction is needed and fails.
    """
    ddl_path = getattr(EngineConfig.RUNTIME, "SQL_FILE_PATH", None)
    tables_meta = None

    if ddl_path and os.path.exists(ddl_path):
        debug(f"[schema.load_or_create_schema_databricks] parsing SQL file '{ddl_path}'")
        tables_meta = parse_sql_file(Path(ddl_path))
        if not tables_meta:
            debug(
                "[schema.load_or_create_schema_databricks] SQL file parsing returned 0 tables, falling back to catalog extraction"
            )
            tables_meta = None

    if tables_meta is None:
        debug("[schema.load_or_create_schema_databricks] SQL file not found or empty, attempting catalog extraction")
        catalog = EngineConfig.RUNTIME.CATALOG
        schema = EngineConfig.RUNTIME.SCHEMA
        try:
            if connection is None:
                from pyspark.sql import SparkSession

                spark = spark_session or SparkSession.builder.getOrCreate()
                tables_meta = extract_tables_from_catalog(spark, catalog, schema)
            else:
                tables_meta = extract_tables_from_catalog_sql_connector(connection, catalog, schema)
        except Exception as e:
            raise RuntimeError(f"Cannot build schema: no SQL file and catalog extraction failed: {e}") from e

    return _tables_meta_to_schema_graph(tables_meta)