aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2075 @@
1
+ """Column and table profiling, role inference, schema parsing, and partition filter injection.
2
+
3
+ Profiles columns for statistical properties (distinct count, null ratio, min/max,
4
+ top-K values), supporting both a local SQLAlchemy engine and a Databricks Spark
5
+ session. Infers column roles (CATEGORICAL, NUMERIC_MEASURE, TEMPORAL, IDENTIFIER,
6
+ BOOLEAN, FREE_TEXT, NUMERIC_CATEGORICAL, AUDIT) and table roles (FACT, DIMENSION,
7
+ BRIDGE) via a single schema-wide LLM call with a heuristic fallback. Also provides
8
+ SQL DDL parsing (sqlglot + LLM fallback), Unity Catalog constraint extraction,
9
+ and partition filter injection for Databricks Spark SQL.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import re
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ from .config import (
19
+ BOOLEAN_TRUE_FALSE_MAP,
20
+ BOOLEAN_VALUE_PATTERNS,
21
+ EngineConfig,
22
+ PolicyConfig,
23
+ QSimConfig,
24
+ )
25
+ from .contracts_base import (
26
+ ColumnMetadata,
27
+ ColumnRole,
28
+ SchemaGraph,
29
+ TableMetadata,
30
+ TableRole,
31
+ )
32
+ from .contracts_core import FilterParam, NormalizedExpr, RuntimeIntent
33
+ from .core_utils import debug, llm_chat, safe_json_loads, stable_json
34
+ from .utils import flatten_param_values
35
+
36
+
37
def _has_boolean_like_values(col: ColumnMetadata) -> bool:
    """Report whether a two-valued column looks like a boolean flag.

    Each profiled top-K value is stringified, lower-cased and stripped before
    comparison, so numeric pairs (e.g. integer 0/1) are handled the same way
    as textual pairs.

    Args:
        col: The ``ColumnMetadata`` to inspect.

    Returns:
        ``True`` when ``distinct_count`` is exactly 2, exactly two top-K values
        are present, and their normalised set is a member of
        ``BOOLEAN_VALUE_PATTERNS``; ``False`` otherwise.
    """
    top_values = col.top_k_values
    exactly_two = col.distinct_count == 2 and bool(top_values) and len(top_values) == 2
    if not exactly_two:
        return False
    normalised = frozenset(str(value).lower().strip() for value in top_values)
    return normalised in BOOLEAN_VALUE_PATTERNS
56
+
57
+
58
def _populate_boolean_mapping(col: ColumnMetadata) -> None:
    """Set ``boolean_true_value`` / ``boolean_false_value`` on a boolean-like column.

    When the column's ``data_type`` contains ``"bool"`` (case-insensitive) the
    canonical ``"true"`` / ``"false"`` pair is assigned. Otherwise, if exactly
    two top-K values were profiled and their normalised (lower-cased, stripped)
    set is a key of ``BOOLEAN_TRUE_FALSE_MAP``, the values are stored as they
    appear in ``top_k_values`` so the original casing is preserved. In every
    other case the column is left untouched.

    Args:
        col: The ``ColumnMetadata`` to update in-place.
    """
    if "bool" in (col.data_type or "").lower():
        col.boolean_true_value = "true"
        col.boolean_false_value = "false"
        return

    observed = col.top_k_values
    if not observed or len(observed) != 2:
        return

    pattern_key = frozenset(str(raw).lower().strip() for raw in observed)
    pair = BOOLEAN_TRUE_FALSE_MAP.get(pattern_key)
    if not pair:
        return

    true_token, false_token = pair
    for raw in observed:
        token = str(raw).lower().strip()
        if token == true_token:
            col.boolean_true_value = str(raw)
        elif token == false_token:
            col.boolean_false_value = str(raw)
85
+
86
+
87
def _profile_column(
    engine: Any,
    col: ColumnMetadata,
    table_name: str,
    row_count: int,
    sample_threshold: int | None = None,
    sample_size: int | None = None,
) -> None:
    """Profile a single column and update its metadata in-place.

    Computes distinct count, distinct ratio and null ratio (on a sample when
    the table is larger than ``sample_threshold``), min/max for integer,
    number and date columns, and the most frequent non-null values. Min/max
    and top-K always run against the full table. Any exception raised while
    querying is logged through ``debug`` and swallowed, leaving whatever
    statistics were already assigned.

    Args:
        engine: SQLAlchemy engine connected to the target database.
        col: The ``ColumnMetadata`` to update.
        table_name: The name of the table containing the column.
        row_count: Total row count for the table (drives the sampling decision).
        sample_threshold: Row count above which sampling is used; defaults to
            ``QSimConfig.PROFILING_SAMPLE_THRESHOLD``.
        sample_size: Number of rows to sample; defaults to
            ``QSimConfig.PROFILING_SAMPLE_SIZE``.
    """
    debug(f"[schema_profiling.profile_column] profiling {table_name}.{col.name}")
    if sample_threshold is None:
        sample_threshold = QSimConfig.PROFILING_SAMPLE_THRESHOLD
    if sample_size is None:
        sample_size = QSimConfig.PROFILING_SAMPLE_SIZE

    col.row_count = row_count
    use_sample = row_count > sample_threshold

    from sqlalchemy import text

    column = f'"{col.name}"'
    table = f'"{table_name}"'

    try:
        # Only the FROM source differs between the full-scan and sampled variants.
        if not use_sample:
            stats_source = table
        elif EngineConfig.TYPE == "postgresql":
            # Keep the percentage inside (0, 100]: with two decimals a tiny
            # ratio on a very large table rounded to 0.00 and sampled nothing.
            percent = min(100.0, max(100.0 * sample_size / row_count, 0.0001))
            stats_source = (
                f"{table} TABLESAMPLE BERNOULLI ({percent:.4f}) "
                f"REPEATABLE ({QSimConfig.RANDOM_SEED})"
            )
        else:
            stats_source = f"(SELECT {column} FROM {table} LIMIT {sample_size}) t"

        with engine.connect() as conn:
            stats_sql = f"""
                SELECT
                    COUNT(*) as cnt,
                    COUNT(DISTINCT {column}) as dist,
                    COUNT(*) - COUNT({column}) as nulls
                FROM {stats_source}
            """
            stats = conn.execute(text(stats_sql)).fetchone()
            # A zero/NULL count is replaced by 1 so the ratios below never divide by zero.
            cnt = stats[0] or 1
            dist = stats[1] or 0
            nulls = stats[2] or 0

            col.distinct_count = dist
            col.distinct_ratio = dist / cnt
            col.null_ratio = nulls / cnt

            if col.value_type in ("integer", "number", "date"):
                minmax_sql = f"SELECT MIN({column}), MAX({column}) FROM {table}"
                bounds = conn.execute(text(minmax_sql)).fetchone()
                if bounds:
                    col.min_val = str(bounds[0]) if bounds[0] is not None else None
                    col.max_val = str(bounds[1]) if bounds[1] is not None else None

            topk_sql = f"""
                SELECT {column}, COUNT(*) as freq
                FROM {table}
                WHERE {column} IS NOT NULL
                GROUP BY {column}
                ORDER BY freq DESC
                LIMIT {PolicyConfig.CATEGORICAL_SAMPLE_SIZE}
            """
            top_rows = conn.execute(text(topk_sql)).fetchall()
            col.top_k_values = [str(row[0]) for row in top_rows if row[0] is not None]
    except Exception as e:
        # Profiling is best-effort: a failing column must not abort the schema run.
        debug(f"[schema_profiling.profile_column] failed: {table_name}.{col.name}: {e}")
187
+
188
+
189
# Case-insensitive matcher for person-name column names (first/last/given/family
# name, with at most one arbitrary separator character before "name"); used by
# the composite-descriptive profilers to pick candidate column pairs.
_NAME_COLUMN_PATTERN = re.compile(r"(first.?name|last.?name|given.?name|family.?name)", re.IGNORECASE)
190
+
191
+
192
def _profile_composite_descriptive(
    engine: Any,
    table: TableMetadata,
) -> None:
    """Measure how unique each pair of name-like string columns is when combined.

    Candidate columns are string-typed, match ``_NAME_COLUMN_PATTERN`` and are
    neither primary nor foreign keys. For every unordered pair, the number of
    distinct space-joined concatenations is divided by ``table.row_count`` and
    stored in ``table.composite_descriptive_ratios[(c1, c2)]``. Nothing is done
    when fewer than two candidates exist or the row count is zero/unknown.
    Query failures are logged through ``debug`` and end the pass.

    Args:
        engine: SQLAlchemy engine connected to the target database.
        table: The ``TableMetadata`` to inspect and update in-place.
    """
    candidates = []
    for column_name, meta in table.columns.items():
        if (meta.value_type or "").lower() != "string":
            continue
        if not _NAME_COLUMN_PATTERN.search(column_name):
            continue
        if meta.is_primary_key or meta.is_foreign_key:
            continue
        candidates.append(column_name)

    if len(candidates) < 2:
        return
    total_rows = table.row_count or 0
    if total_rows == 0:
        return

    from sqlalchemy import text

    try:
        with engine.connect() as conn:
            for position, c1 in enumerate(candidates):
                for c2 in candidates[position + 1:]:
                    query = (
                        f'SELECT COUNT(DISTINCT CONCAT("{c1}", \' \', "{c2}")) '
                        f'FROM "{table.name}"'
                    )
                    distinct_pairs = conn.execute(text(query)).scalar() or 0
                    ratio = distinct_pairs / total_rows
                    table.composite_descriptive_ratios[(c1, c2)] = ratio
                    debug(
                        f"[schema_profiling._profile_composite_descriptive] "
                        f"{table.name}.({c1}, {c2}) composite_ratio={ratio:.4f}"
                    )
    except Exception as exc:
        debug(
            f"[schema_profiling._profile_composite_descriptive] "
            f"failed for {table.name}: {exc}"
        )
240
+
241
+
242
def _profile_table(engine: Any, table: TableMetadata) -> None:
    """Profile a table: row count, every column, then name-pair composites.

    The row count falls back to 0 (and is logged through ``debug``) when the
    count query fails; column profiling still runs with that value.

    Args:
        engine: SQLAlchemy engine connected to the target database.
        table: The ``TableMetadata`` to update (``row_count`` and all column stats).
    """
    from sqlalchemy import text

    debug(f"[schema_profiling.profile_table] profiling {table.name} ({len(table.columns)} columns)")

    total_rows = 0
    try:
        with engine.connect() as conn:
            total_rows = conn.execute(text(f'SELECT COUNT(*) FROM "{table.name}"')).scalar() or 0
    except Exception as e:
        debug(f"[schema_profiling.profile_table] row count failed: {table.name}: {e}")
        total_rows = 0
    table.row_count = total_rows

    for column_meta in table.columns.values():
        _profile_column(engine, column_meta, table.name, total_rows)

    _profile_composite_descriptive(engine, table)

    debug(f"[schema_profiling.profile_table] completed: {table.name}")
270
+
271
+
272
def profile_schema(engine: Any, schema: SchemaGraph) -> None:
    """Run ``_profile_table`` over every table of a schema, updating it in-place.

    Args:
        engine: SQLAlchemy engine connected to the target database.
        schema: The ``SchemaGraph`` whose tables will be profiled.
    """
    all_tables = schema.tables
    debug(f"[schema_profiling.profile_schema] profiling {len(all_tables)} tables")
    for table_meta in all_tables.values():
        _profile_table(engine, table_meta)
    debug("[schema_profiling.profile_schema] completed")
285
+
286
+
287
def _profile_column_spark(
    spark,
    catalog: str,
    schema_name: str,
    col: ColumnMetadata,
    table_name: str,
    row_count: int,
    sample_threshold: int | None = None,
    sample_size: int | None = None,
) -> None:
    """Profile a single column of a Databricks table via Spark SQL, in-place.

    Computes distinct count, distinct ratio and null ratio (using
    ``TABLESAMPLE (n ROWS)`` when the table is larger than
    ``sample_threshold``), min/max for integer, number and date columns, and
    the most frequent non-null values. Min/max and top-K always run against
    the full table. Any exception is logged through ``debug`` and swallowed.

    Args:
        spark: Active ``SparkSession``.
        catalog: The Unity Catalog name.
        schema_name: The schema (database) name within the catalog.
        col: The ``ColumnMetadata`` to update.
        table_name: The table name within the schema.
        row_count: Total row count used for the sampling decision.
        sample_threshold: Row count above which sampling is used; defaults to
            ``QSimConfig.PROFILING_SAMPLE_THRESHOLD``.
        sample_size: Number of rows to sample; defaults to
            ``QSimConfig.PROFILING_SAMPLE_SIZE``.
    """
    debug(f"[schema_profiling.profile_column_spark] profiling {table_name}.{col.name}")
    if sample_threshold is None:
        sample_threshold = QSimConfig.PROFILING_SAMPLE_THRESHOLD
    if sample_size is None:
        sample_size = QSimConfig.PROFILING_SAMPLE_SIZE

    col.row_count = row_count
    use_sample = row_count > sample_threshold

    try:
        full_table = f"`{catalog}`.`{schema_name}`.`{table_name}`"
        column = f"`{col.name}`"
        # The sampled and full-scan statistics queries differ only in the FROM source.
        stats_source = f"{full_table} TABLESAMPLE ({sample_size} ROWS)" if use_sample else full_table

        stats_sql = f"""
            SELECT
                COUNT(*) as cnt,
                COUNT(DISTINCT {column}) as dist,
                COUNT(*) - COUNT({column}) as nulls
            FROM {stats_source}
        """
        stats = spark.sql(stats_sql).collect()[0]
        # A zero/NULL count is replaced by 1 so the ratios below never divide by zero.
        cnt = stats["cnt"] or 1
        dist = stats["dist"] or 0
        nulls = stats["nulls"] or 0

        col.distinct_count = dist
        col.distinct_ratio = dist / cnt
        col.null_ratio = nulls / cnt

        if col.value_type in ("integer", "number", "date"):
            minmax_sql = f"SELECT MIN({column}), MAX({column}) FROM {full_table}"
            bounds = spark.sql(minmax_sql).collect()[0]
            if bounds:
                col.min_val = str(bounds[0]) if bounds[0] is not None else None
                col.max_val = str(bounds[1]) if bounds[1] is not None else None

        topk_sql = f"""
            SELECT {column}, COUNT(*) as freq
            FROM {full_table}
            WHERE {column} IS NOT NULL
            GROUP BY {column}
            ORDER BY freq DESC
            LIMIT {PolicyConfig.CATEGORICAL_SAMPLE_SIZE}
        """
        top_rows = spark.sql(topk_sql).collect()
        col.top_k_values = [str(row[col.name]) for row in top_rows if row[col.name] is not None]
    except Exception as e:
        # Profiling is best-effort: a failing column must not abort the schema run.
        debug(f"[schema_profiling.profile_column_spark] failed: {table_name}.{col.name}: {e}")
379
+
380
+
381
def _profile_composite_descriptive_spark(
    spark,
    catalog: str,
    schema_name: str,
    table: TableMetadata,
) -> None:
    """Measure combined uniqueness of name-like string column pairs via Spark SQL.

    Candidate columns are string-typed, match ``_NAME_COLUMN_PATTERN`` and are
    neither primary nor foreign keys. For each unordered pair the distinct
    count of the space-joined concatenation is divided by ``table.row_count``
    and stored in ``table.composite_descriptive_ratios``. Returns early when
    fewer than two candidates exist or the row count is zero/unknown; query
    failures are logged through ``debug``.

    Args:
        spark: Active ``SparkSession``.
        catalog: The Unity Catalog name.
        schema_name: The schema (database) name.
        table: The ``TableMetadata`` to inspect and update in-place.
    """
    candidates = []
    for column_name, meta in table.columns.items():
        if (meta.value_type or "").lower() != "string":
            continue
        if not _NAME_COLUMN_PATTERN.search(column_name):
            continue
        if meta.is_primary_key or meta.is_foreign_key:
            continue
        candidates.append(column_name)

    if len(candidates) < 2:
        return
    total_rows = table.row_count or 0
    if total_rows == 0:
        return

    full_table = f"`{catalog}`.`{schema_name}`.`{table.name}`"
    try:
        for position, c1 in enumerate(candidates):
            for c2 in candidates[position + 1:]:
                query = (
                    f"SELECT COUNT(DISTINCT CONCAT(`{c1}`, ' ', `{c2}`)) "
                    f"FROM {full_table}"
                )
                distinct_pairs = spark.sql(query).collect()[0][0] or 0
                ratio = distinct_pairs / total_rows
                table.composite_descriptive_ratios[(c1, c2)] = ratio
                debug(
                    f"[schema_profiling._profile_composite_descriptive_spark] "
                    f"{table.name}.({c1}, {c2}) composite_ratio={ratio:.4f}"
                )
    except Exception as exc:
        debug(
            f"[schema_profiling._profile_composite_descriptive_spark] "
            f"failed for {table.name}: {exc}"
        )
425
+
426
+
427
def _profile_table_spark(spark, catalog: str, schema_name: str, table: TableMetadata) -> None:
    """Profile a Databricks table via Spark: row count, columns, name-pair composites.

    The row count falls back to 0 (and is logged through ``debug``) when the
    count query fails; column profiling still runs with that value.

    Args:
        spark: Active ``SparkSession``.
        catalog: The Unity Catalog name.
        schema_name: The schema (database) name.
        table: The ``TableMetadata`` to update.
    """
    debug(f"[schema_profiling.profile_table_spark] profiling {table.name} ({len(table.columns)} columns)")

    total_rows = 0
    try:
        qualified = f"`{catalog}`.`{schema_name}`.`{table.name}`"
        total_rows = spark.sql(f"SELECT COUNT(*) FROM {qualified}").collect()[0][0] or 0
    except Exception as e:
        debug(f"[schema_profiling.profile_table_spark] row count failed: {table.name}: {e}")
        total_rows = 0
    table.row_count = total_rows

    for column_meta in table.columns.values():
        _profile_column_spark(spark, catalog, schema_name, column_meta, table.name, total_rows)

    _profile_composite_descriptive_spark(spark, catalog, schema_name, table)

    debug(f"[schema_profiling.profile_table_spark] completed: {table.name}")
457
+
458
+
459
def profile_schema_spark(spark, catalog: str, schema_name: str, schema: SchemaGraph) -> None:
    """Run ``_profile_table_spark`` over every table of a Databricks schema.

    Args:
        spark: Active ``SparkSession``.
        catalog: The Unity Catalog name.
        schema_name: The schema (database) name.
        schema: The ``SchemaGraph`` whose tables will be profiled in-place.
    """
    all_tables = schema.tables
    debug(f"[schema_profiling.profile_schema_spark] profiling {len(all_tables)} tables")
    for table_meta in all_tables.values():
        _profile_table_spark(spark, catalog, schema_name, table_meta)
    debug("[schema_profiling.profile_schema_spark] completed")
476
+
477
+
478
def _cursor_rows_as_dicts(cursor) -> list[dict]:
    """Return all remaining cursor rows as dicts keyed by column name.

    Column names are taken from the first element of each entry in
    ``cursor.description``. An empty list is returned, without fetching,
    when the cursor has no description.
    """
    description = cursor.description
    if not description:
        return []
    keys = [entry[0] for entry in description]
    records = []
    for values in cursor.fetchall():
        records.append(dict(zip(keys, values)))
    return records
484
+
485
+
486
def _profile_column_sql_connector(
    connection,
    catalog: str,
    schema_name: str,
    col: ColumnMetadata,
    table_name: str,
    row_count: int,
    sample_threshold: int | None = None,
    sample_size: int | None = None,
) -> None:
    """Profile a single column via databricks-sql-connector, updating it in-place.

    Computes distinct count, distinct ratio and null ratio (using
    ``TABLESAMPLE (n ROWS)`` when the table is larger than
    ``sample_threshold``), min/max for integer, number and date columns, and
    the most frequent non-null values. Min/max and top-K always run against
    the full table. Any exception is logged through ``debug`` and swallowed.

    Args:
        connection: Active ``databricks.sql`` connection.
        catalog: The Unity Catalog name.
        schema_name: The schema (database) name within the catalog.
        col: The ``ColumnMetadata`` to update.
        table_name: The table name within the schema.
        row_count: Total row count used for the sampling decision.
        sample_threshold: Row count above which sampling is used; defaults to
            ``QSimConfig.PROFILING_SAMPLE_THRESHOLD``.
        sample_size: Number of rows to sample; defaults to
            ``QSimConfig.PROFILING_SAMPLE_SIZE``.
    """
    if sample_threshold is None:
        sample_threshold = QSimConfig.PROFILING_SAMPLE_THRESHOLD
    if sample_size is None:
        sample_size = QSimConfig.PROFILING_SAMPLE_SIZE

    col.row_count = row_count
    use_sample = row_count > sample_threshold
    full_table = f"`{catalog}`.`{schema_name}`.`{table_name}`"
    column = f"`{col.name}`"
    # The sampled and full-scan statistics queries differ only in the FROM source.
    stats_source = f"{full_table} TABLESAMPLE ({sample_size} ROWS)" if use_sample else full_table

    try:
        with connection.cursor() as cursor:
            cursor.execute(
                f"""
                SELECT
                    COUNT(*) as cnt,
                    COUNT(DISTINCT {column}) as dist,
                    COUNT(*) - COUNT({column}) as nulls
                FROM {stats_source}
                """
            )
            stats_rows = _cursor_rows_as_dicts(cursor)
            if stats_rows:
                stats = stats_rows[0]
                # A zero/NULL count is replaced by 1 so the ratios never divide by zero.
                cnt = stats.get("cnt") or 1
                dist = stats.get("dist") or 0
                nulls = stats.get("nulls") or 0
                col.distinct_count = dist
                col.distinct_ratio = dist / cnt
                col.null_ratio = nulls / cnt

            if col.value_type in ("integer", "number", "date"):
                cursor.execute(f"SELECT MIN({column}) as mn, MAX({column}) as mx FROM {full_table}")
                minmax_rows = _cursor_rows_as_dicts(cursor)
                if minmax_rows:
                    bounds = minmax_rows[0]
                    col.min_val = str(bounds["mn"]) if bounds.get("mn") is not None else None
                    col.max_val = str(bounds["mx"]) if bounds.get("mx") is not None else None

            cursor.execute(
                f"""
                SELECT {column} as topval, COUNT(*) as freq
                FROM {full_table}
                WHERE {column} IS NOT NULL
                GROUP BY {column}
                ORDER BY freq DESC
                LIMIT {PolicyConfig.CATEGORICAL_SAMPLE_SIZE}
                """
            )
            col.top_k_values = [
                str(r["topval"])
                for r in _cursor_rows_as_dicts(cursor)
                if r and r.get("topval") is not None
            ]
    except Exception as e:
        # Profiling is best-effort: a failing column must not abort the schema run.
        debug(f"[schema_profiling._profile_column_sql_connector] failed: {table_name}.{col.name}: {e}")
556
+
557
+
558
def _profile_composite_descriptive_sql_connector(
    connection,
    catalog: str,
    schema_name: str,
    table: TableMetadata,
) -> None:
    """Measure combined uniqueness of name-like string column pairs via SQL connector.

    Candidate columns are string-typed, match ``_NAME_COLUMN_PATTERN`` and are
    neither primary nor foreign keys. For each unordered pair the distinct
    count of the space-joined concatenation is divided by ``table.row_count``
    and stored in ``table.composite_descriptive_ratios``. Returns early when
    fewer than two candidates exist or the row count is zero/unknown; query
    failures are logged through ``debug``.

    Args:
        connection: Active ``databricks.sql`` connection.
        catalog: The Unity Catalog name.
        schema_name: The schema (database) name.
        table: The ``TableMetadata`` to inspect and update in-place.
    """
    candidates = []
    for column_name, meta in table.columns.items():
        if (meta.value_type or "").lower() != "string":
            continue
        if not _NAME_COLUMN_PATTERN.search(column_name):
            continue
        if meta.is_primary_key or meta.is_foreign_key:
            continue
        candidates.append(column_name)

    if len(candidates) < 2:
        return
    total_rows = table.row_count or 0
    if total_rows == 0:
        return

    full_table = f"`{catalog}`.`{schema_name}`.`{table.name}`"
    try:
        with connection.cursor() as cursor:
            for position, c1 in enumerate(candidates):
                for c2 in candidates[position + 1:]:
                    cursor.execute(
                        f"SELECT COUNT(DISTINCT CONCAT(`{c1}`, ' ', `{c2}`)) FROM {full_table}"
                    )
                    result_rows = _cursor_rows_as_dicts(cursor)
                    distinct_pairs = 0
                    if result_rows and result_rows[0]:
                        # Single unnamed aggregate: take the first (only) value.
                        distinct_pairs = list(result_rows[0].values())[0] or 0
                    ratio = distinct_pairs / total_rows
                    table.composite_descriptive_ratios[(c1, c2)] = ratio
                    debug(
                        f"[schema_profiling._profile_composite_descriptive_sql_connector] "
                        f"{table.name}.({c1}, {c2}) composite_ratio={ratio:.4f}"
                    )
    except Exception as exc:
        debug(
            f"[schema_profiling._profile_composite_descriptive_sql_connector] "
            f"failed for {table.name}: {exc}"
        )
601
+
602
+
603
def _profile_table_sql_connector(
    connection,
    catalog: str,
    schema_name: str,
    table: TableMetadata,
) -> None:
    """Profile a Databricks table via databricks-sql-connector.

    Fetches the row count (0 on failure, logged through ``debug``), profiles
    every column, then computes name-pair composite ratios.

    Args:
        connection: Active ``databricks.sql`` connection.
        catalog: The Unity Catalog name.
        schema_name: The schema (database) name.
        table: The ``TableMetadata`` to update in-place.
    """
    debug(f"[schema_profiling._profile_table_sql_connector] profiling {table.name}")
    qualified = f"`{catalog}`.`{schema_name}`.`{table.name}`"

    try:
        with connection.cursor() as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM {qualified}")
            count_rows = _cursor_rows_as_dicts(cursor)
            if count_rows:
                first = count_rows[0]
                # The count column is unnamed, so read whichever key comes first.
                total_rows = first.get(list(first.keys())[0], 0) or 0
            else:
                total_rows = 0
    except Exception as e:
        debug(f"[schema_profiling._profile_table_sql_connector] row count failed: {table.name}: {e}")
        total_rows = 0

    table.row_count = total_rows
    for column_meta in table.columns.values():
        _profile_column_sql_connector(
            connection, catalog, schema_name, column_meta, table.name, total_rows
        )
    _profile_composite_descriptive_sql_connector(connection, catalog, schema_name, table)
    debug(f"[schema_profiling._profile_table_sql_connector] completed: {table.name}")
627
+
628
+
629
def profile_schema_sql_connector(
    connection,
    catalog: str,
    schema_name: str,
    schema: SchemaGraph,
) -> None:
    """Run ``_profile_table_sql_connector`` over every table of a Databricks schema.

    Args:
        connection: Active ``databricks.sql`` connection.
        catalog: The Unity Catalog name.
        schema_name: The schema (database) name.
        schema: The ``SchemaGraph`` whose tables will be profiled in-place.
    """
    all_tables = schema.tables
    debug(f"[schema_profiling.profile_schema_sql_connector] profiling {len(all_tables)} tables")
    for table_meta in all_tables.values():
        _profile_table_sql_connector(connection, catalog, schema_name, table_meta)
    debug("[schema_profiling.profile_schema_sql_connector] completed")
651
+
652
+
653
def extract_tables_from_catalog_sql_connector(
    connection,
    catalog: str,
    schema: str,
) -> dict[str, dict]:
    """Extract full table metadata from a Databricks Unity Catalog schema via SQL connector.

    For every table listed by ``SHOW TABLES`` the function reads column names
    and types (``DESCRIBE TABLE``), key constraints and partition columns
    (``SHOW CREATE TABLE``, falling back to the describe-detail helper for
    partitions), table properties (``SHOW TBLPROPERTIES``) and the table
    comment (``DESCRIBE TABLE EXTENDED``). The DDL, properties and comment
    lookups are best-effort: failures are logged through ``debug`` and the
    corresponding field keeps its empty default. Failures of ``SHOW TABLES``
    or ``DESCRIBE TABLE`` propagate to the caller.

    Args:
        connection: Active ``databricks.sql`` connection.
        catalog: The catalog name.
        schema: The schema (database) name.

    Returns:
        Dict mapping table name to a metadata dict with keys
        ``table_name_original``, ``column_names_original``, ``column_types``,
        ``primary_keys``, ``foreign_keys``, ``partition_columns``,
        ``table_comment``, and ``properties``.
    """
    log_prefix = "[schema_profiling.extract_tables_from_catalog_sql_connector]"
    tables: dict[str, dict] = {}

    with connection.cursor() as cursor:
        cursor.execute(f"SHOW TABLES IN {catalog}.{schema}")
        table_rows = _cursor_rows_as_dicts(cursor)

    # The name column's casing varies; fall back to the first column when
    # neither known spelling is present.
    table_col = "tableName"
    if table_rows and table_rows[0]:
        row0 = table_rows[0]
        if "tableName" not in row0:
            table_col = "tablename" if "tablename" in row0 else next(iter(row0))

    for row in table_rows or []:
        table_name = row.get(table_col) if row else None
        if not table_name:
            continue
        full_table = f"{catalog}.{schema}.{table_name}"
        debug(f"{log_prefix} extracting: {full_table}")

        with connection.cursor() as cursor:
            cursor.execute(f"DESCRIBE TABLE {full_table}")
            described = _cursor_rows_as_dicts(cursor)
        column_names = []
        column_types = []
        for col in described:
            cname = col.get("col_name") or col.get("colname")
            # Stop at the first blank or '#'-prefixed row; rows after it are
            # not treated as columns.
            if not cname or str(cname).startswith("#"):
                break
            column_names.append(cname)
            column_types.append(col.get("data_type") or "STRING")

        primary_keys = []
        foreign_keys = []
        partition_columns: list[str] = []
        try:
            with connection.cursor() as cursor:
                cursor.execute(f"SHOW CREATE TABLE {full_table}")
                create_rows = _cursor_rows_as_dicts(cursor)
            if create_rows:
                ddl_row = create_rows[0] or {}
                stmt = ddl_row.get("createtab_stmt") or (
                    list(ddl_row.values())[0] if ddl_row else None
                )
                if stmt:
                    primary_keys, foreign_keys = _parse_unity_catalog_constraints(stmt)
                    partition_columns = _parse_partition_columns_from_create_stmt(stmt)
                    debug(f"{log_prefix} ddl_found: {full_table}")
        except Exception as e:
            debug(f"{log_prefix} ddl_error: {full_table} {e}")

        if not partition_columns:
            partition_columns = _extract_partition_columns_from_describe_detail_sql_connector(
                connection, full_table
            )

        properties = {}
        try:
            with connection.cursor() as cursor:
                cursor.execute(f"SHOW TBLPROPERTIES {full_table}")
                prop_rows = _cursor_rows_as_dicts(cursor)
            for r in prop_rows or []:
                k = r.get("key")
                if k is not None:
                    properties[k] = r.get("value")
        except Exception as e:
            # Best-effort: keep whatever was collected, but do not hide the failure.
            debug(f"{log_prefix} properties_error: {full_table} {e}")

        table_comment = None
        try:
            with connection.cursor() as cursor:
                cursor.execute(f"DESCRIBE TABLE EXTENDED {full_table}")
                ext_rows = _cursor_rows_as_dicts(cursor)
            for r in ext_rows or []:
                cname = r.get("col_name") or r.get("colname")
                if cname == "Comment":
                    table_comment = r.get("data_type")
                    break
        except Exception as e:
            # Best-effort: a missing comment is not fatal, but log why it is missing.
            debug(f"{log_prefix} comment_error: {full_table} {e}")

        tables[table_name] = {
            "table_name_original": table_name,
            "column_names_original": column_names,
            "column_types": column_types,
            "primary_keys": primary_keys,
            "foreign_keys": foreign_keys,
            "partition_columns": partition_columns,
            "table_comment": table_comment,
            "properties": properties,
        }

    debug(f"{log_prefix} complete: {len(tables)} tables")
    return tables
764
+
765
+
766
def _enrich_fk_column_descriptions(schema: SchemaGraph) -> None:
    """Add a "join <table> ..." hint to foreign-key column descriptions.

    A column is enriched when it has an ``fk_target`` whose table exists in
    the schema and whose name is not already mentioned (case-insensitively)
    in the column's description. The hint names the target table and up to
    three of its non-PK, non-FK columns, preferring those whose ``role`` is
    neither ``"identifier"`` nor empty. The hint is appended after an em dash
    (trailing dots/spaces of the existing text are trimmed) or becomes the
    whole description when there was none.

    Args:
        schema: The ``SchemaGraph`` whose column descriptions are updated in-place.
    """
    for table in schema.tables.values():
        for col in table.columns.values():
            target = col.fk_target
            if not target:
                continue
            dst_table_name, _dst_col = target
            dst_table = schema.tables.get(dst_table_name)
            if not dst_table:
                continue

            current_text = col.description or ""
            if dst_table_name.lower() in current_text.lower():
                continue

            plain_cols = [
                c for c in dst_table.columns.values()
                if not c.is_primary_key and not c.is_foreign_key
            ]
            preferred = [c.name for c in plain_cols if c.role not in ("identifier", "")]
            chosen = (preferred or [c.name for c in plain_cols])[:3]

            suffix = f"join {dst_table_name}"
            if chosen:
                suffix = f"join {dst_table_name} for {', '.join(chosen)}"

            trimmed = current_text.rstrip(". ")
            col.description = f"{trimmed} — {suffix}" if trimmed else suffix
796
+
797
+
798
def _infer_column_role(col: ColumnMetadata) -> ColumnRole:
    """Classify a column heuristically (fallback when no LLM role is available).

    Rules are tried in this order, first match wins:

    1. ``value_type == "boolean"`` -> BOOLEAN
    2. primary or foreign key -> IDENTIFIER
    3. boolean-like value pair (``_has_boolean_like_values``) -> BOOLEAN
    4. ``value_type == "date"`` -> TEMPORAL
    5. ``distinct_ratio >= PolicyConfig.IDENTIFIER_MIN_UNIQUENESS`` -> FREE_TEXT
    6. low cardinality (count or ratio within the ``PolicyConfig`` categorical
       limits) -> NUMERIC_CATEGORICAL for integer/number columns, else CATEGORICAL
    7. integer/number -> NUMERIC_MEASURE, anything else -> FREE_TEXT

    Args:
        col: The ``ColumnMetadata`` to classify.

    Returns:
        The inferred ``ColumnRole`` enum value.
    """
    value_type = col.value_type

    if value_type == "boolean":
        return ColumnRole.BOOLEAN
    if col.is_primary_key or col.is_foreign_key:
        return ColumnRole.IDENTIFIER
    if _has_boolean_like_values(col):
        return ColumnRole.BOOLEAN
    if value_type == "date":
        return ColumnRole.TEMPORAL
    if col.distinct_ratio >= PolicyConfig.IDENTIFIER_MIN_UNIQUENESS:
        return ColumnRole.FREE_TEXT

    numeric = value_type in ("integer", "number")
    low_cardinality = (
        col.distinct_count <= PolicyConfig.CATEGORICAL_MAX_CARDINALITY
        or col.distinct_ratio <= PolicyConfig.CATEGORICAL_MAX_RATIO
    )
    if low_cardinality:
        return ColumnRole.NUMERIC_CATEGORICAL if numeric else ColumnRole.CATEGORICAL
    return ColumnRole.NUMERIC_MEASURE if numeric else ColumnRole.FREE_TEXT
842
+
843
+
844
def _validate_column_classification(col: ColumnMetadata, role: str) -> tuple[list[str], list[str]]:
    """Check an LLM-assigned column role against the column's type and profile.

    Hard errors: NUMERIC_MEASURE on a non-numeric column, TEMPORAL on a
    non-date column (unless it looks like a year/domain column, which is only
    a warning), and BOOLEAN with more than two distinct values. Soft warnings:
    CATEGORICAL above 1000 distinct values, NUMERIC_MEASURE with at most five
    distinct values, and IDENTIFIER on a column that is neither PK nor FK.

    Args:
        col: The ``ColumnMetadata`` with profiling statistics.
        role: The role string assigned by the LLM.

    Returns:
        Tuple of ``(hard_errors, soft_warnings)``, each a list of messages.
    """
    errors: list[str] = []
    warnings: list[str] = []

    name = col.name
    distinct = col.distinct_count
    numeric_type = col.value_type in ("integer", "number")
    date_type = col.value_type == "date"
    upper_dtype = (col.data_type or "").upper()

    wants_measure = role == ColumnRole.NUMERIC_MEASURE.value
    if wants_measure and not numeric_type:
        errors.append(f"{col.name}: NUMERIC_MEASURE requires numeric type, got '{col.data_type}'")

    if role == ColumnRole.TEMPORAL.value and not date_type:
        year_like = "year" in name.lower() or "DOMAIN" in upper_dtype or "YEAR" in upper_dtype
        if year_like:
            warnings.append(f"{col.name}: TEMPORAL on year/domain column, recommend NUMERIC_CATEGORICAL")
        else:
            errors.append(f"{col.name}: TEMPORAL requires date/time type, got '{col.data_type}'")

    if role == ColumnRole.BOOLEAN.value and distinct and distinct > 2:
        errors.append(f"{col.name}: BOOLEAN requires distinct_count <= 2, got {col.distinct_count}")

    if role == ColumnRole.CATEGORICAL.value and distinct and distinct > 1000:
        warnings.append(f"{col.name}: CATEGORICAL with high cardinality ({col.distinct_count})")

    if wants_measure and distinct and distinct <= 5:
        warnings.append(f"{col.name}: NUMERIC_MEASURE with low cardinality ({col.distinct_count})")

    if role == ColumnRole.IDENTIFIER.value and not (col.is_primary_key or col.is_foreign_key):
        warnings.append(f"{col.name}: IDENTIFIER on non-PK/FK column")

    return errors, warnings
887
+
888
+
889
def _build_column_profile_for_llm(col: ColumnMetadata) -> dict:
    """Summarise one column as a plain dict for the LLM classification prompt.

    Args:

        col: The ``ColumnMetadata`` to summarise.

    Returns:

        Dict with ``name``, ``data_type``, ``is_primary_key`` and ``is_foreign_key``. A ``profile_hints`` sub-dict is added only when at least one of ``distinct_count``, ``distinct_ratio`` or ``null_ratio`` is not ``None``; the two ratios are rounded to three decimals.
    """
    stats: dict = {}
    if col.distinct_count is not None:
        stats["distinct_count"] = col.distinct_count
    # Both ratios get the same treatment: include when known, rounded to 3 places.
    for ratio_attr in ("distinct_ratio", "null_ratio"):
        ratio = getattr(col, ratio_attr)
        if ratio is not None:
            stats[ratio_attr] = round(ratio, 3)

    summary = {
        "name": col.name,
        "data_type": col.data_type,
        "is_primary_key": col.is_primary_key,
        "is_foreign_key": col.is_foreign_key,
    }
    if stats:
        summary["profile_hints"] = stats
    return summary
916
+
917
+
918
def _llm_classify_schema(
    schema: SchemaGraph,
) -> dict[str, tuple[str, str, dict[str, tuple[str, str]]]]:
    """Use a single LLM call to classify all table roles, column roles, column semantic hints, and table descriptions.

    The LLM response is treated as untrusted: non-string roles, ``null`` descriptions/hints and a non-dict ``columns`` value are coerced to safe defaults instead of raising. Unrecognised table roles become ``"unknown"`` and unrecognised column roles become ``ColumnRole.FREE_TEXT``. Top-level entries whose value is not a dict are skipped.

    Args:

        schema: The ``SchemaGraph`` containing all tables and columns to classify.

    Returns:

        Dict mapping table name to a tuple of ``(table_role, description, {col_name: (role, hint), ...})``.

    Raises:

        ValueError: If the LLM returns invalid or non-dict JSON.
    """
    tables_data = []
    for table in schema.tables.values():
        fks = [",".join(fk.src_cols) + "->" + fk.dst_table + "." + ",".join(fk.dst_cols) for fk in table.foreign_keys]
        column_profiles = [_build_column_profile_for_llm(col) for col in table.columns.values()]
        tables_data.append(
            {
                "table": table.name,
                "fks": fks,
                "columns": column_profiles,
            }
        )
    system = (
        "Classify every table's role and every column's role in this schema.\n\n"
        "TABLE ROLES:\n"
        "- dimension: reference/lookup table referenced by others, descriptive attributes\n"
        "- fact: transactional/event table with FKs to dimensions, contains measures\n"
        "- bridge: junction table for many-to-many, mostly FKs, few own columns\n"
        "- unknown: cannot confidently classify\n"
        "Use FK topology: tables referenced by many others are dimension; tables with many outbound FKs are fact; tables with only 2+ FKs and minimal columns are bridge.\n\n"
        "COLUMN ROLE DECISION PRIORITY (evaluate in order, first match wins):\n"
        "1. is_primary_key or is_foreign_key → identifier\n"
        "2. data_type is date/time/timestamp → temporal\n"
        "3. name suggests binary state (is_*, has_*, active) and distinct_count = 2 → boolean\n"
        "4. numeric and name suggests quantity/amount/duration/size/distance/price/count → numeric_measure (integer or decimal — data type does not restrict this)\n"
        "5. numeric and name suggests code/rating/level/rank/status/tier/type → numeric_categorical\n"
        "6. numeric with no clear name signal → numeric_measure (default for numeric)\n"
        "7. text and very high distinct_ratio → free_text\n"
        "8. text → categorical (default for text)\n\n"
        "PROFILE HINTS (supporting evidence only — never override name/type signals):\n"
        "Each column may include a profile_hints object with distinct_count, distinct_ratio, and null_ratio. Use these to confirm or disambiguate when name and type are ambiguous.\n"
        "Do NOT use profile_hints as the primary reason to choose a role.\n\n"
        "CROSS-TABLE CONSISTENCY:\n"
        "- Columns with the same name and data type across tables MUST receive the same role.\n\n"
        "COLUMN HINTS:\n"
        "For each column, provide a short semantic hint (max 8 words) describing what the column represents in business terms. Include common synonyms a user might use when asking questions about this data. The hint should help map natural language to the correct column.\n"
        "Role-based guidance for hints:\n"
        "- identifier columns: describe what entity the ID refers to.\n"
        "- numeric_measure columns: state the unit or what is measured.\n"
        "- categorical columns: mention common category values or groupings.\n"
        "- temporal columns: state what event the date/time marks.\n"
        "- boolean columns: describe the yes/no condition.\n"
        "- FK columns: MUST state what business data the target table provides when joined. Name the key descriptive columns on the target table (e.g. 'links to target_table for name, title, description').\n\n"
        "TABLE DESCRIPTIONS:\n"
        "For each table provide a one-line business purpose that includes: (a) what entity or event the table represents, (b) which related tables it connects to via foreign keys, and (c) the notable descriptive or measure columns it provides that users commonly ask about.\n\n"
        "Reason internally, output only JSON:\n"
        '{"table1": {"table_role": "...", "description": "one-line business purpose including related tables and key columns", "columns": {"col1": {"role": "...", "hint": "..."}, ...}}, ...}'
    )
    user = stable_json({"tables": tables_data})
    raw = llm_chat(system, user, timeout=120.0, task="schema")
    result = safe_json_loads(raw)
    if not result or not isinstance(result, dict):
        # str() keeps the documented ValueError even when raw is None or not a string.
        raise ValueError(f"LLM returned invalid JSON for schema classification: {str(raw)[:200]}")
    valid_table_roles = {"dimension", "fact", "bridge", "unknown"}
    valid_col_roles = {r.value for r in ColumnRole}
    classifications = {}
    for table_name, table_data in result.items():
        if not isinstance(table_data, dict):
            continue
        # `or` guards against JSON null, which .get()'s default does not cover.
        table_role = str(table_data.get("table_role") or "unknown").strip().lower()
        if table_role not in valid_table_roles:
            table_role = "unknown"
        description = str(table_data.get("description") or "").strip()
        columns_data = table_data.get("columns")
        if not isinstance(columns_data, dict):
            # The caller reports every column of this table as missing and retries.
            columns_data = {}
        column_classifications: dict[str, tuple[str, str]] = {}
        for col_name, classification in columns_data.items():
            if isinstance(classification, dict):
                role = str(classification.get("role") or "").strip().lower()
                hint = str(classification.get("hint") or "").strip()
            else:
                # Bare-string form: {"col": "identifier"}.
                role = str(classification or "").strip().lower()
                hint = ""
            if role not in valid_col_roles:
                role = ColumnRole.FREE_TEXT.value
            column_classifications[col_name] = (role, hint)
        classifications[table_name] = (table_role, description, column_classifications)
    return classifications
1011
+
1012
+
1013
def apply_column_roles_llm(schema: SchemaGraph) -> None:
    """Apply LLM-inferred roles and table descriptions to the schema in-place.

    Makes up to ``QSimConfig.MAX_ROLE_CLASSIFICATION_RETRIES + 1`` LLM attempts. An attempt is rejected (nothing is written) when a table or column is missing from the response or when ``_validate_column_classification`` reports a hard error. After an accepted attempt, hard overrides are applied: PK/FK columns become identifiers, boolean-value-pattern columns become booleans, and low-cardinality free-text columns become categorical.

    If every attempt fails, table roles are derived from FK topology and column roles from ``_infer_column_role``. Finally FK column descriptions are enriched and boolean mappings are populated for boolean-like columns.

    Args:

        schema: The ``SchemaGraph`` to update in-place.
    """
    debug(f"[schema_profiling.apply_column_roles_llm] classifying {len(schema.tables)} tables via LLM (single-call)")
    total_columns = sum(len(table.columns) for table in schema.tables.values())
    debug(f"[schema_profiling.apply_column_roles_llm] total columns: {total_columns}")
    role_counts: dict[str, int] = {}
    table_role_counts: dict[str, int] = {}
    llm_success = 0
    llm_fallback = 0
    success = False
    for attempt in range(QSimConfig.MAX_ROLE_CLASSIFICATION_RETRIES + 1):
        # Broad catch is deliberate: any LLM/parsing failure just consumes one
        # attempt, and the heuristic fallback below covers total failure.
        try:
            classifications = _llm_classify_schema(schema)
            all_hard_errors = []
            all_soft_warnings = []
            # Validate the whole response before mutating anything so a
            # rejected attempt leaves the schema untouched.
            for table in schema.tables.values():
                if table.name not in classifications:
                    all_hard_errors.append(f"{table.name}: missing from LLM response")
                    continue
                _role, _desc, column_classifications = classifications[table.name]
                for col in table.columns.values():
                    if col.name not in column_classifications:
                        all_hard_errors.append(f"{table.name}.{col.name}: missing from LLM response")
                        continue
                    role, _hint = column_classifications[col.name]
                    hard_errors, soft_warnings = _validate_column_classification(col, role)
                    all_hard_errors.extend([f"{table.name}.{e}" for e in hard_errors])
                    all_soft_warnings.extend([f"{table.name}.{w}" for w in soft_warnings])
            for warning in all_soft_warnings:
                debug(f"[apply_column_roles_llm] WARNING: {warning}")
            if all_hard_errors:
                debug(
                    f"[apply_column_roles_llm] {len(all_hard_errors)} hard errors (attempt {attempt + 1}): {all_hard_errors[:5]}"
                )
                continue
            for table in schema.tables.values():
                if table.name in classifications:
                    table_role, description, column_classifications = classifications[table.name]
                    table.role = table_role
                    table.description = description
                    table_role_counts[table_role] = table_role_counts.get(table_role, 0) + 1
                    for col in table.columns.values():
                        if col.name in column_classifications:
                            role, hint = column_classifications[col.name]
                            col.role = role
                            col.description = hint
                            if col.is_primary_key or col.is_foreign_key:
                                if col.role != ColumnRole.IDENTIFIER.value:
                                    debug(
                                        f"[apply_column_roles_llm] override {table.name}.{col.name}: {col.role} → identifier (pk/fk)"
                                    )
                                    col.role = ColumnRole.IDENTIFIER.value
                            elif _has_boolean_like_values(col):
                                if col.role != ColumnRole.BOOLEAN.value:
                                    debug(
                                        f"[apply_column_roles_llm] override {table.name}.{col.name}: {col.role} → boolean (value pattern)"
                                    )
                                    col.role = ColumnRole.BOOLEAN.value
                            elif (
                                col.role == ColumnRole.FREE_TEXT.value
                                and col.distinct_count is not None
                                and col.distinct_count <= PolicyConfig.FREE_TEXT_CATEGORICAL_MAX_CARDINALITY
                            ):
                                debug(
                                    f"[apply_column_roles_llm] override {table.name}.{col.name}: free_text → categorical (distinct={col.distinct_count})"
                                )
                                col.role = ColumnRole.CATEGORICAL.value
                            role_counts[col.role] = role_counts.get(col.role, 0) + 1
            success = True
            llm_success = len(schema.tables)
            debug("[apply_column_roles_llm] single-call classification successful")
            break
        except Exception as e:
            debug(f"[apply_column_roles_llm] attempt {attempt + 1} failed: {e}")
    if not success:
        debug("[apply_column_roles_llm] LLM failed, using heuristic fallback for all tables")
        # Count inbound FKs once instead of rescanning every table per table.
        inbound_fk_counts: dict[str, int] = {}
        for source in schema.tables.values():
            for fk in source.foreign_keys:
                inbound_fk_counts[fk.dst_table] = inbound_fk_counts.get(fk.dst_table, 0) + 1
        for table in schema.tables.values():
            llm_fallback += 1
            fk_out = len(table.foreign_keys)
            fk_in = inbound_fk_counts.get(table.name, 0)
            # The bridge test must come before the generic fact test: a bridge
            # also has fk_out >= 2, so the reverse order never reaches it.
            if fk_out == 2 and fk_in == 0 and len(table.columns) <= 4:
                table_role = TableRole.BRIDGE.value
            elif fk_out >= 2:
                table_role = TableRole.FACT.value
            elif fk_out == 0 and fk_in >= 1:
                table_role = TableRole.DIMENSION.value
            else:
                table_role = TableRole.UNKNOWN.value
            table.role = table_role
            table_role_counts[table_role] = table_role_counts.get(table_role, 0) + 1
            for col in table.columns.values():
                role = _infer_column_role(col)
                col.role = role.value
                role_counts[role.value] = role_counts.get(role.value, 0) + 1
    debug(f"[apply_column_roles_llm] completed: {llm_success} LLM, {llm_fallback} fallback")
    debug(f"[apply_column_roles_llm] table distribution: {table_role_counts}")
    debug(f"[apply_column_roles_llm] column distribution: {role_counts}")
    _enrich_fk_column_descriptions(schema)
    boolean_mapped = 0
    for table in schema.tables.values():
        for col in table.columns.values():
            if col.is_boolean_like or "bool" in (col.data_type or "").lower():
                _populate_boolean_mapping(col)
                if col.boolean_true_value is not None:
                    boolean_mapped += 1
    if boolean_mapped:
        debug(f"[apply_column_roles_llm] boolean mappings populated: {boolean_mapped}")
1128
+
1129
+
1130
def assign_column_ops(schema: SchemaGraph) -> None:
    """Populate each column's valid filter, aggregation and HAVING operators from its final role.

    Deterministic; intended to run after role assignment. AUDIT columns get no operators at all, and PK/FK columns are treated as identifiers regardless of their role. After the role-based assignment, pattern operators are removed from non-string columns and ``sum``/``avg`` from non-numeric columns. A column with ``is_filterable=False`` ends up with no filter operators, except that a FREE_TEXT column keeps whatever pattern and null checks it still has.

    Args:

        schema: The ``SchemaGraph`` to update in-place.
    """
    debug("[schema_profiling.assign_column_ops] assigning ops to columns")

    # Building blocks the per-role operator lists are composed from.
    null_checks = ["is null", "is not null"]
    pattern_filters = ["like", "ilike", "not like", "not ilike"]
    ordering = ["<", "<=", ">", ">="]
    comparisons = ["=", "!="] + ordering
    membership = ["=", "!=", "in", "not in"]
    range_filters = comparisons + ["between", "in", "not in"]

    text_only_filters = set(pattern_filters)
    numeric_only_aggs = {"sum", "avg"}
    non_filterable_text_ops = {"like", "ilike", "not like", "not ilike", "is null", "is not null"}

    # (role value, filter ops before null checks, aggregations, HAVING ops).
    # Scanned with == rather than used as a dict so matching behaves exactly
    # like a chain of equality tests.
    role_specs = [
        (ColumnRole.CATEGORICAL.value, membership + pattern_filters, ["count", "min", "max"], comparisons),
        (ColumnRole.NUMERIC_CATEGORICAL.value, membership + ordering + ["between"], ["count", "min", "max"], comparisons),
        (ColumnRole.NUMERIC_MEASURE.value, range_filters, ["sum", "avg", "min", "max", "count"], comparisons),
        (ColumnRole.TEMPORAL.value, range_filters, ["min", "max", "count"], comparisons),
        (ColumnRole.BOOLEAN.value, membership, ["count"], comparisons),
        (ColumnRole.FREE_TEXT.value, pattern_filters, ["count"], []),
    ]
    fallback_spec = (["=", "!="], ["count"], comparisons)

    for table in schema.tables.values():
        for col in table.columns.values():
            role = col.role
            value_type = col.value_type

            if role == ColumnRole.AUDIT.value:
                filters, aggs, having = [], [], []
            elif col.is_primary_key or col.is_foreign_key or role == ColumnRole.IDENTIFIER.value:
                filters = range_filters + null_checks
                aggs = ["count"]
                having = list(comparisons)
            else:
                for role_value, base_filters, base_aggs, base_having in role_specs:
                    if role == role_value:
                        break
                else:
                    base_filters, base_aggs, base_having = fallback_spec
                # Copy so no two columns ever share a list object.
                filters = base_filters + null_checks
                aggs = list(base_aggs)
                having = list(base_having)

            if value_type != "string":
                filters = [op for op in filters if op not in text_only_filters]
            if value_type not in ("integer", "number"):
                aggs = [agg for agg in aggs if agg not in numeric_only_aggs]

            if not col.is_filterable:
                if role == ColumnRole.FREE_TEXT.value:
                    filters = [op for op in filters if op in non_filterable_text_ops]
                else:
                    filters = []

            col.valid_filter_ops = filters
            col.valid_aggregations = aggs
            col.valid_having_ops = having

    debug("[schema_profiling.assign_column_ops] completed")
1259
+
1260
+
1261
def extract_tables_from_catalog(spark, catalog: str, schema: str) -> dict[str, dict]:
    """Extract full table metadata from a Databricks Unity Catalog schema.

    For every table listed by ``SHOW TABLES`` this runs ``DESCRIBE TABLE`` (columns and types), ``SHOW CREATE TABLE`` (constraints and partitioning), ``SHOW TBLPROPERTIES`` and ``DESCRIBE TABLE EXTENDED`` (table comment). Failures of the last three are tolerated and leave the corresponding fields empty.

    Args:

        spark: Active ``SparkSession``.

        catalog: The catalog name.

        schema: The schema (database) name.

    Returns:

        Dict mapping table name to a metadata dict with keys ``table_name_original``, ``column_names_original``, ``column_types``, ``primary_keys``, ``foreign_keys``, ``partition_columns``, ``table_comment``, and ``properties``.
    """
    tables: dict[str, dict] = {}

    table_list = spark.sql(f"SHOW TABLES IN {catalog}.{schema}").collect()

    for table_row in table_list:
        table_name = table_row["tableName"]
        full_table = f"{catalog}.{schema}.{table_name}"

        debug(f"[schema_profiling.extract_tables_from_catalog] extracting: {full_table}")

        cols = spark.sql(f"DESCRIBE TABLE {full_table}").collect()

        column_names = []
        column_types = []

        for col in cols:
            col_name = col["col_name"]

            # Blank separator rows are not columns (assumed to occur in some
            # DESCRIBE outputs ahead of the "#" sections — TODO confirm).
            if not col_name:
                continue
            # A "#"-prefixed row starts a metadata section; columns end here.
            if col_name.startswith("#"):
                break

            column_names.append(col_name)
            column_types.append(col["data_type"])

        # Reset per table: previously foreign_keys was only assigned inside the
        # `if create_result` branch, so an empty SHOW CREATE TABLE result either
        # raised NameError or reused the previous table's foreign keys.
        primary_keys: list[str] = []
        foreign_keys: list[dict] = []
        partition_columns: list[str] = []
        try:
            create_result = spark.sql(f"SHOW CREATE TABLE {full_table}").collect()
            if create_result:
                create_stmt = create_result[0]["createtab_stmt"]
                primary_keys, foreign_keys = _parse_unity_catalog_constraints(create_stmt)
                partition_columns = _parse_partition_columns_from_create_stmt(create_stmt)
                debug(f"[schema_profiling.extract_tables_from_catalog] ddl_found: {full_table}")
        except Exception as e:
            debug(f"[schema_profiling.extract_tables_from_catalog] ddl_error: {full_table} {e}")
            primary_keys = []
            foreign_keys = []

        if not partition_columns:
            partition_columns = _extract_partition_columns_from_describe_detail_spark(
                spark, full_table
            )

        try:
            props = spark.sql(f"SHOW TBLPROPERTIES {full_table}").collect()
            properties = {p["key"]: p["value"] for p in props}
        except Exception:
            # Best effort: properties are optional metadata.
            properties = {}

        table_comment = None
        try:
            extended = spark.sql(f"DESCRIBE TABLE EXTENDED {full_table}").collect()
            for detail_row in extended:
                if detail_row["col_name"] == "Comment":
                    table_comment = detail_row["data_type"]
                    break
        except Exception:
            # Best effort: the comment is optional metadata.
            table_comment = None

        tables[table_name] = {
            "table_name_original": table_name,
            "column_names_original": column_names,
            "column_types": column_types,
            "primary_keys": primary_keys,
            "foreign_keys": foreign_keys,
            "partition_columns": partition_columns,
            "table_comment": table_comment,
            "properties": properties,
        }

    debug(f"[schema_profiling.extract_tables_from_catalog] complete: {len(tables)} tables")
    return tables
1348
+
1349
+
1350
def _extract_partition_columns_from_describe_detail_spark(
    spark, full_table: str
) -> list[str]:
    """Extract partition column names via DESCRIBE DETAIL, fallback INFORMATION_SCHEMA.

    Collected rows are converted with ``asDict()`` when they offer it, because Spark ``Row`` objects support ``row["key"]`` but not ``row.get("key")``. Calling ``.get`` on them raised ``AttributeError``, which the surrounding ``except`` swallowed, so both lookups silently failed.

    Args:

        spark: Active SparkSession.

        full_table: Fully qualified table name (catalog.schema.table).

    Returns:

        List of partition column name strings; empty when neither lookup yields any.
    """

    def as_mapping(row):
        """Return a dict view of a collected row (plain mappings pass through)."""
        to_dict = getattr(row, "asDict", None)
        return to_dict() if callable(to_dict) else row

    try:
        detail_rows = spark.sql(f"DESCRIBE DETAIL {full_table}").collect()
        if detail_rows:
            detail = as_mapping(detail_rows[0])
            cols = detail.get("partitionColumns") or detail.get("partition_columns")
            if isinstance(cols, (list, tuple)) and cols:
                return [str(c) for c in cols]
    except Exception as e:
        debug(f"[schema_profiling._extract_partition_from_detail] DESCRIBE DETAIL failed: {e}")

    try:
        parts = full_table.split(".")
        if len(parts) >= 3:
            # The raw catalog part stays as the identifier; back-ticks are only
            # stripped for the string literals, where they would never match.
            info_schema = f"{parts[0]}.information_schema.columns"
            catalog_name, schema_name, table_name = (p.strip().strip("`") for p in parts[:3])
            q = f"""
                SELECT column_name FROM {info_schema}
                WHERE table_catalog = '{catalog_name}'
                AND table_schema = '{schema_name}'
                AND table_name = '{table_name}'
                AND partition_ordinal_position IS NOT NULL
                ORDER BY partition_ordinal_position
            """
            info_rows = [as_mapping(r) for r in spark.sql(q).collect()]
            return [str(r["column_name"]) for r in info_rows if r.get("column_name")]
    except Exception as e:
        debug(f"[schema_profiling._extract_partition_from_detail] INFORMATION_SCHEMA failed: {e}")

    return []
1395
+
1396
+
1397
def _extract_partition_columns_from_describe_detail_sql_connector(
    connection, full_table: str
) -> list[str]:
    """Extract partition column names via DESCRIBE DETAIL, fallback INFORMATION_SCHEMA.

    The ``partitionColumns`` value is accepted as a Python list/tuple, as any object exposing ``tolist()`` (array-like), or as a JSON-encoded list string. NOTE(review): which of these the connector delivers is assumed to depend on its complex-type settings — confirm against the deployed connector version.

    Args:

        connection: Active databricks.sql connection.

        full_table: Fully qualified table name (catalog.schema.table).

    Returns:

        List of partition column name strings; empty when neither lookup yields any.
    """
    import json

    def as_name_list(value) -> list[str]:
        """Coerce an ARRAY cell in any supported representation to ``list[str]``."""
        if value is None:
            return []
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except ValueError:
                return []
        if hasattr(value, "tolist"):
            value = value.tolist()
        if isinstance(value, (list, tuple)):
            return [str(c) for c in value]
        return []

    try:
        with connection.cursor() as cursor:
            cursor.execute(f"DESCRIBE DETAIL {full_table}")
            rows = _cursor_rows_as_dicts(cursor)
            if rows:
                detail = rows[0]
                # Check each key separately: `a or b` on an array-like value
                # raises instead of selecting one.
                for key in ("partitionColumns", "partition_columns"):
                    names = as_name_list(detail.get(key))
                    if names:
                        return names
    except Exception as e:
        debug(
            f"[schema_profiling._extract_partition_sql_connector] DESCRIBE DETAIL failed: {e}"
        )

    try:
        parts = full_table.split(".")
        if len(parts) >= 3:
            # The raw catalog part stays as the identifier; back-ticks are only
            # stripped for the string literals, where they would never match.
            info_schema = f"{parts[0]}.information_schema.columns"
            catalog_name, schema_name, table_name = (p.strip().strip("`") for p in parts[:3])
            q = f"""
                SELECT column_name FROM {info_schema}
                WHERE table_catalog = '{catalog_name}'
                AND table_schema = '{schema_name}'
                AND table_name = '{table_name}'
                AND partition_ordinal_position IS NOT NULL
                ORDER BY partition_ordinal_position
            """
            with connection.cursor() as cursor:
                cursor.execute(q)
                info_rows = _cursor_rows_as_dicts(cursor)
                return [str(r["column_name"]) for r in info_rows if r.get("column_name")]
    except Exception as e:
        debug(
            f"[schema_profiling._extract_partition_sql_connector] INFORMATION_SCHEMA failed: {e}"
        )

    return []
1449
+
1450
+
1451
def _parse_partition_columns_from_create_stmt(create_stmt: str) -> list[str]:
    """Extract partition column names from a CREATE TABLE DDL string.

    Finds the first ``PARTITIONED BY (...)`` clause (case-insensitive) and splits its contents on top-level commas. Both bare names (``PARTITIONED BY (dt, region)``) and typed definitions (``PARTITIONED BY (dt STRING, amt DECIMAL(10,2))``) are supported; only the column name is returned. Back-tick or double-quote quoting is removed.

    Parentheses or commas inside string literals (e.g. a ``COMMENT``) are not treated specially.

    Args:

        create_stmt: The raw CREATE TABLE DDL string.

    Returns:

        List of partition column name strings, or empty list if no well-formed clause is found.
    """
    header = re.search(r"PARTITIONED\s+BY\s*\(", create_stmt, re.IGNORECASE)
    if not header:
        return []

    # Walk to the matching close paren, splitting only on commas at depth 1 so
    # that type arguments such as DECIMAL(10,2) stay inside their entry.
    entries: list[str] = []
    current: list[str] = []
    depth = 1
    for ch in create_stmt[header.end():]:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                break
        if ch == "," and depth == 1:
            entries.append("".join(current))
            current = []
        else:
            current.append(ch)
    else:
        # The clause never closes: treat it as absent.
        return []
    entries.append("".join(current))

    columns: list[str] = []
    for entry in entries:
        entry = entry.strip()
        if not entry:
            continue
        if entry[0] in "`\"":
            # A quoted identifier may contain spaces; read up to the closing quote.
            closing = entry.find(entry[0], 1)
            name = entry[1:closing] if closing > 0 else entry[1:]
        else:
            # Unquoted: the name is the first token; anything after it is a type.
            name = entry.split()[0]
        name = name.strip().strip("`").strip('"')
        if name:
            columns.append(name)
    return columns
1468
+
1469
+
1470
def _parse_unity_catalog_constraints(create_stmt: str) -> tuple[list[str], list[dict]]:
    """Parse PRIMARY KEY and FOREIGN KEY constraints from a CREATE TABLE DDL string.

    Recognises named constraints of the form ``CONSTRAINT <name> PRIMARY KEY (...)`` and ``CONSTRAINT <name> FOREIGN KEY (...) REFERENCES <table> (...)``. Constraint names and the referenced table may be bare, back-tick quoted or double-quoted, and the referenced table may be qualified (``catalog.schema.table``). Only the final component of the referenced table is kept.

    Args:

        create_stmt: The raw CREATE TABLE DDL string from ``SHOW CREATE TABLE``.

    Returns:

        Tuple of ``(primary_keys, foreign_keys)`` where ``primary_keys`` is a list of column name strings and ``foreign_keys`` is a list of dicts with keys ``src_cols``, ``dst_table``, and ``dst_cols``.
    """
    # One identifier: `quoted`, "quoted" or bare word.
    ident = r'(?:`[^`]+`|"[^"]+"|\w+)'
    pk_pattern = rf"CONSTRAINT\s+{ident}\s+PRIMARY\s+KEY\s*\(([^)]+)\)"
    fk_pattern = (
        rf"CONSTRAINT\s+{ident}\s+FOREIGN\s+KEY\s*\(([^)]+)\)\s*REFERENCES\s+"
        rf"({ident}(?:\s*\.\s*{ident})*)\s*\(([^)]+)\)"
    )

    def split_columns(column_list: str) -> list[str]:
        """Split a comma-separated column list and drop quoting and empty entries."""
        cleaned = (c.strip().strip("`").strip('"') for c in column_list.split(","))
        return [c for c in cleaned if c]

    primary_keys: list[str] = []
    for column_list in re.findall(pk_pattern, create_stmt, re.IGNORECASE):
        primary_keys.extend(split_columns(column_list))

    foreign_keys: list[dict] = []
    for src_list, qualified_table, dst_list in re.findall(fk_pattern, create_stmt, re.IGNORECASE):
        # Keep only the table component of a possibly qualified reference.
        ref_table = re.findall(ident, qualified_table)[-1].strip("`").strip('"')
        foreign_keys.append(
            {
                "src_cols": split_columns(src_list),
                "dst_table": ref_table,
                "dst_cols": split_columns(dst_list),
            }
        )

    return primary_keys, foreign_keys
1500
+
1501
+
1502
def parse_sql_file(sql_path: Path) -> dict[str, dict]:
    """Parse CREATE TABLE statements from a SQL file using sqlglot with LLM fallback.

    Attempts sqlglot first; if it returns 0 tables, falls back to an LLM call to extract metadata. The LLM result is validated before it is returned: entries that are not dicts are dropped, and the keys ``table_name_original``, ``column_names_original``, ``column_types``, ``primary_keys``, ``foreign_keys`` and ``partition_columns`` are defaulted when missing, so both parsing paths yield the same keys.

    Args:

        sql_path: Path to the SQL file to parse.

    Returns:

        Dict mapping table name to a metadata dict, or an empty dict when neither parser yields usable tables.
    """
    # utf-8-sig drops a leading BOM if the file has one.
    with open(sql_path, encoding="utf-8-sig") as f:
        sql_content = f.read()

    debug(f"[schema_profiling.parse_sql_file] reading: {len(sql_content)} chars")

    tables = _parse_sql_file_fallback(sql_content)
    if tables:
        debug(f"[schema_profiling.parse_sql_file] sqlglot parsed: {len(tables)} tables")
        return tables

    debug("[schema_profiling.parse_sql_file] sqlglot returned 0 tables, falling back to LLM")

    system = """You are a deterministic SQL parser. Extract CREATE TABLE statements and output ONLY valid JSON. Be precise and consistent. Follow the output format exactly."""

    user = stable_json(
        {
            "task": "Parse all CREATE TABLE statements from the SQL and extract complete metadata",
            "sql_content": sql_content,
            "output_format": {
                "tables": {
                    "<table_name>": {
                        "table_name_original": "exact table name from SQL (without schema prefix)",
                        "column_names_original": ["column1", "column2"],
                        "column_types": ["TYPE1", "TYPE2"],
                        "primary_keys": ["column1"],
                        "foreign_keys": [
                            {
                                "src_cols": ["column1"],
                                "dst_table": "other_table",
                                "dst_cols": ["column1"],
                            }
                        ],
                    }
                }
            },
            "rules": [
                "Extract ALL CREATE TABLE statements, ignore other SQL commands",
                "Preserve exact table and column names (case-sensitive)",
                "Strip schema prefixes from table names (e.g., 'public.users' → 'users')",
                "Remove quotes from identifiers (e.g., '\"user_id\"' → 'user_id')",
                "Capture PRIMARY KEY constraints (inline and separate CONSTRAINT clauses)",
                "Capture FOREIGN KEY constraints with all source/destination columns",
                "Normalize data types to uppercase (e.g., 'varchar(50)' → 'VARCHAR(50)')",
                "Convert SERIAL → INTEGER, BIGSERIAL → BIGINT",
                "Use empty arrays [] for tables with no PKs or FKs",
                "Ignore CHECK constraints, DEFAULT values, and UNIQUE constraints",
                "Handle multi-line statements and SQL comments (-- and /* */)",
                "Output ONLY the JSON object, no markdown code blocks, no explanations",
            ],
            "examples": [
                {
                    "input": "CREATE TABLE public.table1 (column1 SERIAL PRIMARY KEY, column2 VARCHAR(100));",
                    "output": {
                        "tables": {
                            "table1": {
                                "table_name_original": "table1",
                                "column_names_original": ["column1", "column2"],
                                "column_types": ["INTEGER", "VARCHAR(100)"],
                                "primary_keys": ["column1"],
                                "foreign_keys": [],
                            }
                        }
                    },
                },
                {
                    "input": "CREATE TABLE table2 (column1 INT, column2 INT, FOREIGN KEY (column2) REFERENCES table1(column1));",
                    "output": {
                        "tables": {
                            "table2": {
                                "table_name_original": "table2",
                                "column_names_original": ["column1", "column2"],
                                "column_types": ["INT", "INT"],
                                "primary_keys": [],
                                "foreign_keys": [
                                    {
                                        "src_cols": ["column2"],
                                        "dst_table": "table1",
                                        "dst_cols": ["column1"],
                                    }
                                ],
                            }
                        }
                    },
                },
            ],
        }
    )

    response = llm_chat(system, user, task="schema")
    parsed = safe_json_loads(response)

    raw_tables = parsed.get("tables") if isinstance(parsed, dict) else None
    if not isinstance(raw_tables, dict):
        debug("[schema_profiling.parse_sql_file] llm also failed, returning empty")
        return {}

    # LLM output is untrusted: keep only dict entries and fill in the keys the
    # sqlglot path always provides.
    list_keys = (
        "column_names_original",
        "column_types",
        "primary_keys",
        "foreign_keys",
        "partition_columns",
    )
    tables = {}
    for tname, tinfo in raw_tables.items():
        if not isinstance(tinfo, dict):
            debug(f"[schema_profiling.parse_sql_file] skipping malformed llm entry: {tname}")
            continue
        if not tinfo.get("table_name_original"):
            tinfo["table_name_original"] = tname
        for key in list_keys:
            if tinfo.get(key) is None:
                tinfo[key] = []
        tables[tname] = tinfo

    debug(f"[schema_profiling.parse_sql_file] llm parsed: {len(tables)} tables")

    if PolicyConfig.DEBUG:
        for tname, tinfo in tables.items():
            debug(
                f"[schema_profiling.parse_sql_file] table: {tname} cols={len(tinfo.get('column_names_original', []))}"
            )

    return tables
1620
+
1621
+
1622
def _parse_sql_file_fallback(sql_content: str) -> dict[str, dict]:
    """Parse CREATE TABLE statements from SQL content using sqlglot.

    Uses dialect from EngineConfig.TYPE (spark for Databricks, postgres for PostgreSQL). Only ``CREATE TABLE`` statements contribute entries; other ``CREATE`` kinds (views, indexes, ...) are skipped so they do not appear as tables.

    Args:

        sql_content: The full SQL file content as a string.

    Returns:

        Dict mapping table name to a metadata dict with keys ``table_name_original``, ``column_names_original``, ``column_types``, ``primary_keys``, ``foreign_keys`` and ``partition_columns``, or an empty dict if parsing fails or no tables are found.
    """
    import sqlglot
    from sqlglot import exp

    dialect_name = "spark" if EngineConfig.TYPE == "databricks" else "postgres"
    try:
        statements = sqlglot.parse(sql_content, dialect=dialect_name)
    except Exception as e:
        debug(f"[schema_profiling._parse_sql_file_fallback] parse failed: {e}")
        return {}

    debug(f"[schema_profiling._parse_sql_file_fallback] statements: {len(statements)}")
    tables: dict[str, dict] = {}

    for stmt in statements:
        if not isinstance(stmt, exp.Create) or not stmt.this:
            continue

        # exp.Create also covers CREATE VIEW / INDEX / ...; a missing kind is
        # let through so statements without one are still processed.
        kind = stmt.args.get("kind")
        if kind and str(kind).upper() != "TABLE":
            continue

        table_ref = stmt.this
        if hasattr(table_ref, "this") and table_ref.this is not None:
            table_name = getattr(table_ref.this, "name", None) or str(table_ref.this)
        else:
            table_name = getattr(table_ref, "name", None) or str(table_ref)
        # If the name is missing or still qualified, keep the last dotted component.
        if not table_name or "." in str(table_name):
            table_name = str(table_ref).split(".")[-1].strip("`\"")
        table_name = str(table_name).strip("`\"")

        if not table_name:
            continue

        debug(f"[schema_profiling._parse_sql_file_fallback] parsing: {table_name}")

        col_block = _extract_column_block_from_create(stmt)
        columns, types, pks, fks = _parse_columns_and_constraints(col_block)
        full_stmt = stmt.sql(dialect=dialect_name)
        partition_cols = _parse_partition_columns_from_create_stmt(full_stmt)

        debug(f"[schema_profiling._parse_sql_file_fallback] ddl_generated: {table_name}")

        tables[table_name] = {
            "table_name_original": table_name,
            "column_names_original": columns,
            "column_types": types,
            "primary_keys": pks,
            "foreign_keys": fks,
            "partition_columns": partition_cols,
        }

    debug(f"[schema_profiling._parse_sql_file_fallback] complete: {len(tables)} tables")
    return tables
1684
+
1685
+
1686
+ def _extract_column_block_from_create(create_expr: Any) -> str:
1687
+ """Extract the inner content of the column definition block from a sqlglot Create.
1688
+
1689
+ In sqlglot, Create.this is a Schema (table + columns); Create.expression is often None.
1690
+ Schema.expressions holds ColumnDef nodes. We join their SQL to form the column block.
1691
+
1692
+ Args:
1693
+
1694
+ create_expr: A parsed sqlglot ``Create`` expression.
1695
+
1696
+ Returns:
1697
+
1698
+ The column definition string with surrounding parentheses stripped.
1699
+ """
1700
+ schema = create_expr.this if create_expr.this else create_expr.expression
1701
+ if schema is None:
1702
+ return ""
1703
+ expressions = getattr(schema, "expressions", None)
1704
+ if expressions:
1705
+ return ", ".join(e.sql() for e in expressions if hasattr(e, "sql"))
1706
+ schema_sql = schema.sql()
1707
+ if schema_sql.startswith("(") and schema_sql.endswith(")"):
1708
+ return schema_sql[1:-1].strip()
1709
+ full_sql = create_expr.sql()
1710
+ match = re.search(r"\(([\s\S]*)\)\s*(?:PARTITIONED|STORED|LOCATION|AS|$)", full_sql)
1711
+ if match:
1712
+ return match.group(1).strip()
1713
+ return schema_sql.strip()
1714
+
1715
+
1716
+ def _split_by_top_level_comma(s: str) -> list[str]:
1717
+ """Split string by commas that are outside parentheses.
1718
+
1719
+ Preserves commas inside type definitions such as NUMERIC(4,2) or VARCHAR(10,2).
1720
+
1721
+ Args:
1722
+
1723
+ s: The string to split (e.g. a CREATE TABLE column block).
1724
+
1725
+ Returns:
1726
+
1727
+ List of non-empty segments between top-level commas.
1728
+ """
1729
+ result: list[str] = []
1730
+ current: list[str] = []
1731
+ depth = 0
1732
+ for char in s:
1733
+ if char == "(":
1734
+ depth += 1
1735
+ current.append(char)
1736
+ elif char == ")":
1737
+ depth -= 1
1738
+ current.append(char)
1739
+ elif char == "," and depth == 0:
1740
+ segment = "".join(current).strip()
1741
+ if segment:
1742
+ result.append(segment)
1743
+ current = []
1744
+ else:
1745
+ current.append(char)
1746
+ segment = "".join(current).strip()
1747
+ if segment:
1748
+ result.append(segment)
1749
+ return result
1750
+
1751
+
1752
def _leading_type_token(definition: str) -> str:
    """Return the leading type of a column definition tail, keeping bracketed arguments whole."""
    depth = 0
    for idx, char in enumerate(definition):
        if char in "(<":
            depth += 1
        elif char in ")>":
            depth = max(depth - 1, 0)
        elif depth == 0 and char.isspace():
            # First whitespace outside brackets ends the type, so
            # "DECIMAL(4, 2) NOT NULL" yields "DECIMAL(4, 2)".
            return definition[:idx]
    return definition


def _parse_columns_and_constraints(col_block: str) -> tuple[list, list, list, list]:
    """Parse column definitions and key constraints from a column block string.

    Each top-level, comma-separated entry of the block is classified as one of:

    * a table-level ``PRIMARY KEY (...)`` or ``FOREIGN KEY (...) REFERENCES ...``
      constraint, optionally introduced by ``CONSTRAINT <name>``;
    * another table-level constraint (any other named constraint, ``UNIQUE (...)``,
      ``CHECK (...)``), which is skipped because it declares no column;
    * a column definition ``<name> <type> ...``, whose name has surrounding
      backticks or double quotes removed and whose type keeps its bracketed
      arguments intact even when they contain whitespace.

    A column whose definition contains ``PRIMARY KEY`` is also recorded as a
    primary key. Entries with no type after the name are ignored.

    Args:

        col_block: The inner column definition block (contents between ``(`` and ``)``) from a CREATE TABLE statement.

    Returns:

        Tuple of ``(columns, types, pks, fks)``: column names, their types in the same
        order, primary-key column names, and foreign-key dicts as produced by
        ``_extract_fk_definition``.
    """
    columns: list = []
    types: list = []
    pks: list = []
    fks: list = []

    for segment in _split_by_top_level_comma(col_block):
        line = segment.strip()

        # Remove an optional "CONSTRAINT <name>" prefix so that named table
        # constraints are recognised instead of being read as a column
        # called CONSTRAINT.
        body = re.sub(
            r'^CONSTRAINT\s+(?:`[^`]*`|"[^"]*"|\S+)\s+',
            "",
            line,
            count=1,
            flags=re.IGNORECASE,
        )
        is_named_constraint = body != line

        if re.match(r"PRIMARY\s+KEY\b", body, re.IGNORECASE):
            pks.extend(_extract_pk_columns(body))
            continue

        if re.match(r"FOREIGN\s+KEY\b", body, re.IGNORECASE):
            fk_def = _extract_fk_definition(body)
            if fk_def:
                fks.append(fk_def)
            continue

        # Remaining table-level constraints declare no column. The "(" requirement
        # keeps columns such as `unique_id` or `check_date` from being skipped.
        if is_named_constraint or re.match(r"(?:UNIQUE|CHECK)\b\s*\(", body, re.IGNORECASE):
            continue

        # <name> <rest>; the name may be quoted and then may contain spaces.
        column_match = re.match(r'(`[^`]*`|"[^"]*"|\S+)\s+(\S.*)', line, re.DOTALL)
        if not column_match:
            continue

        col_name = column_match.group(1).strip("`").strip('"')
        remainder = column_match.group(2)

        columns.append(col_name)
        types.append(_leading_type_token(remainder))

        if re.search(r"\bPRIMARY\s+KEY\b", remainder, re.IGNORECASE):
            pks.append(col_name)

    return columns, types, pks, fks
1798
+
1799
+
1800
+ def _extract_pk_columns(line: str) -> list[str]:
1801
+ """Extract column names from a PRIMARY KEY (col1, col2) definition line.
1802
+
1803
+ Args:
1804
+
1805
+ line: A single DDL line starting with ``PRIMARY KEY``.
1806
+
1807
+ Returns:
1808
+
1809
+ List of unquoted column name strings.
1810
+ """
1811
+ match = re.search(r"PRIMARY\s+KEY\s*\(([^)]+)\)", line, re.IGNORECASE)
1812
+ if match:
1813
+ return [c.strip().strip("`").strip('"') for c in match.group(1).split(",")]
1814
+ return []
1815
+
1816
+
1817
+ def _extract_fk_definition(line: str) -> dict:
1818
+ """Extract a FOREIGN KEY definition from a DDL constraint line.
1819
+
1820
+ Args:
1821
+
1822
+ line: A single DDL line starting with ``FOREIGN KEY``.
1823
+
1824
+ Returns:
1825
+
1826
+ Dict with keys ``src_cols``, ``dst_table``, and ``dst_cols``, or ``None`` if the pattern does not match.
1827
+ """
1828
+ match = re.search(
1829
+ r"FOREIGN\s+KEY\s*\(([^)]+)\)\s+REFERENCES\s+(\w+)\s*\(([^)]+)\)",
1830
+ line,
1831
+ re.IGNORECASE,
1832
+ )
1833
+ if match:
1834
+ return {
1835
+ "src_cols": [c.strip().strip("`").strip('"') for c in match.group(1).split(",")],
1836
+ "dst_table": match.group(2).strip("`").strip('"'),
1837
+ "dst_cols": [c.strip().strip("`").strip('"') for c in match.group(3).split(",")],
1838
+ }
1839
+ return None
1840
+
1841
+
1842
def inject_partition_filters(
    spark_sql: str,
    schema: SchemaGraph,
    intent: RuntimeIntent,
) -> str:
    """Make sure partition-column predicates derived from the intent are in the WHERE clause.

    Predicates are built from the intent's parameterised filters for partition
    columns and joined with ``AND``. The SQL is returned untouched when there are
    no such predicates or when all of them are already present; otherwise the
    joined predicate is appended to the WHERE clause.

    Args:

        spark_sql: Spark SQL string (with params already substituted).

        schema: SchemaGraph with partition_columns per table.

        intent: RuntimeIntent with filters_param and param_values.

    Returns:

        Spark SQL with partition predicates ensured in WHERE.
    """
    param_lookup = dict(flatten_param_values(intent))
    partition_preds = _build_partition_predicates(schema, intent, param_lookup)
    if partition_preds:
        joined = " AND ".join(partition_preds)
        if not _predicate_already_in_sql(spark_sql, joined, partition_preds):
            return _append_to_where(spark_sql, joined)
    return spark_sql
1872
+
1873
+
1874
def _build_partition_predicates(
    schema: SchemaGraph,
    intent: RuntimeIntent,
    params: dict[str, Any],
) -> list[str]:
    """Build Spark-formatted partition predicates from filters_param.

    Every filter whose left-hand column is a partition column of one of the
    intent's tables is grouped under ``(table, column)`` and each group is rendered
    by ``_format_grouped_predicate``.

    Attribution rules:

    * A table-qualified filter (``table.col``) is attached only to that table, and
      only if ``col`` is one of its partition columns.
    * An unqualified filter (``col``) is attached to the first table in
      ``intent.tables`` that has ``col`` as a partition column.

    Each filter therefore ends up in at most one group, so a group never contains
    the same filter twice.

    Args:

        schema: SchemaGraph whose ``tables`` mapping provides ``partition_columns`` per table.

        intent: RuntimeIntent providing ``tables`` and ``filters_param``.

        params: Mapping from parameter key to bound value.

    Returns:

        List of predicate strings, one per ``(table, column)`` group that could be
        formatted; empty when there are no tables, no filters or no matches.
    """
    tables = intent.tables or []
    filters = intent.filters_param or []
    if not tables or not filters:
        return []

    grouped: dict[tuple[str, str], list[FilterParam]] = {}
    # Original-case (table, column) names for each lower-cased group key.
    display: dict[tuple[str, str], tuple[str, str]] = {}
    # ids of unqualified filters that were already attributed to a table.
    claimed: set[int] = set()
    visited: set[str] = set()

    for table_name in tables:
        table_lower = table_name.lower()
        if table_lower in visited:
            # A table listed twice must not contribute its filters twice.
            continue
        visited.add(table_lower)

        table_meta = schema.tables.get(table_name)
        if not table_meta or not table_meta.partition_columns:
            continue
        part_cols_lower = {c.lower(): c for c in table_meta.partition_columns}

        for fp in filters:
            table_part, col_part = _get_column_ref(fp.left_expr)
            if not col_part:
                continue
            actual_col = part_cols_lower.get(col_part.lower())
            if actual_col is None:
                continue

            if table_part:
                # Qualified reference: it must name the table being inspected.
                if table_part.lower() != table_lower:
                    continue
            else:
                # Unqualified reference: first table owning the column wins.
                if id(fp) in claimed:
                    continue
                claimed.add(id(fp))

            key = (table_lower, actual_col.lower())
            if key not in grouped:
                grouped[key] = []
                display[key] = (table_name, actual_col)
            grouped[key].append(fp)

    result: list[str] = []
    for key, fps in grouped.items():
        table_name, col_name = display[key]
        pred = _format_grouped_predicate(table_name, col_name, fps, params)
        if pred:
            result.append(pred)

    return result
1927
+
1928
+
1929
def _format_grouped_predicate(
    table: str,
    col: str,
    fps: list[FilterParam],
    params: dict[str, Any],
) -> str | None:
    """Format all filters on one partition column as a single Spark predicate.

    * Several ``=`` filters become ``col IN (v1, v2, ...)``.
    * Exactly one ``>=`` plus one ``<=`` filter become a parenthesised range.
    * A single filter is delegated to ``_format_partition_predicate``.
    * Any other combination is not rendered.

    A filter's value is the bound parameter (``params[fp.param_key]``) when it is
    present and not ``None``, otherwise ``fp.raw_value``. Falsy bound values such
    as ``0``, ``False`` or ``""`` are legitimate values and are used as-is.

    Args:

        table: Table name used to qualify the column.

        col: Partition column name.

        fps: Filters that all target ``table.col``.

        params: Mapping from parameter key to bound value.

    Returns:

        The predicate string, or ``None`` when the group cannot be rendered or no
        value is available.
    """

    def _value_of(fp: FilterParam) -> Any:
        bound = params.get(fp.param_key) if fp.param_key else None
        return fp.raw_value if bound is None else bound

    qual = f"`{table}`.`{col}`"
    ops = {fp.op for fp in fps}

    if ops <= {"="} and len(fps) > 1:
        literals = [
            _format_partition_literal(value)
            for value in (_value_of(fp) for fp in fps)
            if value is not None
        ]
        return f"{qual} IN ({', '.join(literals)})" if literals else None

    if ops <= {">=", "<="} and len(fps) == 2:
        ge = next((f for f in fps if f.op == ">="), None)
        le = next((f for f in fps if f.op == "<="), None)
        if ge and le:
            lower, upper = _value_of(ge), _value_of(le)
            if lower is not None and upper is not None:
                return (
                    f"({qual} >= {_format_partition_literal(lower)}"
                    f" AND {qual} <= {_format_partition_literal(upper)})"
                )
        return None

    if len(fps) == 1:
        return _format_partition_predicate(table, col, fps[0], params)
    return None
1959
+
1960
+
1961
+ def _get_column_ref(expr: NormalizedExpr) -> tuple[str | None, str | None]:
1962
+ """Extract (table, column) from primary_term like 'table.col' or 'col'."""
1963
+ term = (expr.primary_term or "").strip()
1964
+ if not term:
1965
+ return None, None
1966
+ if "." in term:
1967
+ parts = term.rsplit(".", 1)
1968
+ return parts[0].strip() or None, parts[1].strip() or None
1969
+ return None, term
1970
+
1971
+
1972
def _format_partition_predicate(
    table: str,
    col: str,
    fp: FilterParam,
    params: dict[str, Any],
) -> str | None:
    """Format a single partition predicate for Spark SQL.

    The value is the bound parameter (``params[fp.param_key]``) when ``param_key``
    is set and the bound value is not ``None``; otherwise ``fp.raw_value``.
    Supported operators are ``=``, ``>=``, ``<=``, ``>``, ``<`` and ``in``. For
    ``in`` a list or tuple is expanded element by element and any other value is
    rendered as a one-element list.

    Args:

        table: Table name used to qualify the column.

        col: Partition column name.

        fp: The filter to render.

        params: Mapping from parameter key to bound value.

    Returns:

        The predicate string, or ``None`` when no value is available, the operator
        is unsupported, or an ``in`` filter has no elements (``IN ()`` is not valid SQL).
    """
    # An empty/missing param_key must not mask raw_value.
    val = params.get(fp.param_key) if fp.param_key else None
    if val is None:
        val = fp.raw_value
    if val is None:
        return None

    qual = f"`{table}`.`{col}`"

    if fp.op in ("=", ">=", "<=", ">", "<"):
        return f"{qual} {fp.op} {_format_partition_literal(val)}"

    if fp.op == "in":
        members = list(val) if isinstance(val, (list, tuple)) else [val]
        if not members:
            return None
        rendered = ", ".join(_format_partition_literal(v) for v in members)
        return f"{qual} IN ({rendered})"

    return None
2003
+
2004
+
2005
+ def _format_partition_literal(val: Any) -> str:
2006
+ """Format a value as a Spark SQL literal."""
2007
+ if isinstance(val, str):
2008
+ escaped = val.replace("\\", "\\\\").replace("'", "\\'")
2009
+ return f"'{escaped}'"
2010
+ if isinstance(val, bool):
2011
+ return "TRUE" if val else "FALSE"
2012
+ if isinstance(val, (int, float)):
2013
+ return str(val)
2014
+ return f"'{str(val)}'"
2015
+
2016
+
2017
+ def _predicate_already_in_sql(
2018
+ sql: str,
2019
+ combined: str,
2020
+ predicates: list[str],
2021
+ ) -> bool:
2022
+ """Check if partition predicates are already present in the SQL."""
2023
+ sql_norm = sql.replace(" ", "").replace("\n", " ").lower()
2024
+ for pred in predicates:
2025
+ pred_norm = pred.replace(" ", "").lower()
2026
+ if pred_norm not in sql_norm:
2027
+ return False
2028
+ return True
2029
+
2030
+
2031
+ def _append_to_where(sql: str, predicate: str) -> str:
2032
+ """Append predicate to the WHERE clause, creating WHERE if absent."""
2033
+ where_match = re.search(r"\bWHERE\b", sql, re.IGNORECASE)
2034
+ if where_match:
2035
+ insert_pos = where_match.end()
2036
+ next_clause = re.search(
2037
+ r"\b(GROUP\s+BY|ORDER\s+BY|LIMIT|HAVING)\b",
2038
+ sql[insert_pos:],
2039
+ re.IGNORECASE,
2040
+ )
2041
+ if next_clause:
2042
+ end_pos = insert_pos + next_clause.start()
2043
+ clause = sql[insert_pos:end_pos].rstrip()
2044
+ new_clause = clause + " AND " + predicate
2045
+ return sql[:insert_pos] + new_clause + sql[end_pos:]
2046
+ return sql.rstrip() + " AND " + predicate
2047
+
2048
+ from_match = re.search(r"\bFROM\b", sql, re.IGNORECASE)
2049
+ if not from_match:
2050
+ return sql
2051
+ group_match = re.search(
2052
+ r"\bGROUP\s+BY\b",
2053
+ sql[from_match.end() :],
2054
+ re.IGNORECASE,
2055
+ )
2056
+ order_match = re.search(
2057
+ r"\bORDER\s+BY\b",
2058
+ sql[from_match.end() :],
2059
+ re.IGNORECASE,
2060
+ )
2061
+ limit_match = re.search(r"\bLIMIT\b", sql[from_match.end() :], re.IGNORECASE)
2062
+ having_match = re.search(r"\bHAVING\b", sql[from_match.end() :], re.IGNORECASE)
2063
+
2064
+ insert_pos = len(sql)
2065
+ for m in [group_match, order_match, limit_match, having_match]:
2066
+ if m:
2067
+ pos = from_match.end() + m.start()
2068
+ if pos < insert_pos:
2069
+ insert_pos = pos
2070
+
2071
+ before = sql[:insert_pos].rstrip()
2072
+ after = sql[insert_pos:].lstrip()
2073
+ if after:
2074
+ return before + " WHERE " + predicate + " " + after
2075
+ return before + " WHERE " + predicate