aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
text2sql/sql_gen.py ADDED
@@ -0,0 +1,1537 @@
1
+ """SQL generation, join path resolution, and canonical normalization.
2
+
3
+ Responsible for building the SQL prompt and repair prompt sent to the LLM, enumerating and ranking all valid FK-based join paths between a set of tables, normalising generated SQL to a canonical JOIN order and predicate form, and validating that the LLM-chosen join candidate matches the actual SQL structure. Also provides utilities for rendering intent expressions as SQL fragments used in the generation prompt and for post-hoc alias injection.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import re
9
+ from typing import Any
10
+
11
+ from .config import SCALAR_FUNCTIONS_LEADING_ARG, EngineConfig, PolicyConfig
12
+ from .contracts_base import SchemaGraph
13
+ from .contracts_core import FilterParam, MulGroup, NormalizedExpr, RuntimeCteStep, RuntimeIntent, SelectCol
14
+ from .core_utils import debug, llm_json, llm_sql_with_join, stable_json
15
+ from .dialect import get_dialect, render_date_diff_expr, render_date_window_expr
16
+
17
+
18
def cte_to_intent_for_ranking(cte: RuntimeCteStep) -> RuntimeIntent:
    """Promote a ``RuntimeCteStep`` into a stand-alone ``RuntimeIntent`` for CTE join ranking.

    The intent-shaped fields of the CTE step are copied verbatim; ``cte_steps`` is
    always empty because the synthetic intent describes a single CTE body.

    Args:

        cte: The ``RuntimeCteStep`` whose tables and intent fields should be promoted.

    Returns:

        A ``RuntimeIntent`` carrying the CTE's tables, grain, select/group/order/filter/having columns, param values, column map, and limit, suitable for join scoring.
    """
    promoted_fields = (
        "tables",
        "grain",
        "select_cols",
        "group_by_cols",
        "order_by_cols",
        "filters_param",
        "having_param",
        "param_values",
        "column_map",
        "limit",
    )
    return RuntimeIntent(
        **{field_name: getattr(cte, field_name) for field_name in promoted_fields},
        cte_steps=[],
    )
42
+
43
+
44
# System prompt for the LLM SQL-repair step. It asks for JSON-only output that
# follows the ``output_schema`` supplied in the user prompt, and for minimal,
# structure-preserving fixes. NOTE(review): the consumer of this constant is not
# visible in this chunk — presumably the repair-prompt builder elsewhere in the module.
SQL_REPAIR_SYSTEM_PROMPT = """You are a deterministic SQL repair assistant.

Output requirements:
- Output ONLY valid JSON that matches the specified output_schema.
- Do NOT include markdown, explanations, or commentary.
- Identical inputs must produce identical outputs.

Repair guidelines:
- Make minimal changes to fix the specific error.
- Do not change query logic unless required to fix the error.
- Preserve JOIN structure unless the error requires changing it.
- Maintain existing filters, aggregations, and ordering when possible.
"""
57
+
58
+
59
def join_candidate_map(join_hints: dict[str, Any]) -> dict[str, list[str]]:
    """Map each join candidate ID to its join path signature.

    Entries whose ``candidate_id`` is not a string, or whose
    ``join_path_signature`` is not a list, are ignored. Every signature element is
    coerced to ``str``. When an ID repeats, the later entry wins.

    Args:

        join_hints: The join hints dict produced by ``join_hints_multi``.

    Returns:

        Dict mapping ``candidate_id`` to its list of join path signature strings.
    """
    return {
        candidate["candidate_id"]: [str(segment) for segment in candidate["join_path_signature"]]
        for candidate in join_hints.get("candidates", [])
        if isinstance(candidate.get("candidate_id"), str)
        and isinstance(candidate.get("join_path_signature"), list)
    }
77
+
78
+
79
def _analyze_join_topology(sig: list[str]) -> tuple[str, str, list[str]]:
    """Classify a join signature's shape and pick its canonical root.

    Each ``"src.cols->dst.cols"`` segment contributes one degree to both of its
    tables. Degree-1 tables are leaves and higher-degree tables are hubs.
    Segments without ``->`` are ignored.

    Args:

        sig: List of join path signature strings.

    Returns:

        Tuple of ``(topology_type, anchor_table, leaf_tables)``:

        - ``topology_type`` is one of ``"none"``, ``"linear"``, ``"star"``, or ``"tree"``.
        - ``anchor_table`` is the canonical root.
        - ``leaf_tables`` is the list of endpoint tables.
    """
    if not sig:
        return ("none", "", [])

    degree: dict[str, int] = {}
    for segment in sig:
        if "->" not in segment:
            continue
        source_side, target_side = segment.split("->", 1)
        for side in (source_side, target_side):
            table_name = side.split(".")[0].strip()
            degree[table_name] = degree.get(table_name, 0) + 1

    if not degree:
        return ("none", "", [])

    leaf_tables = sorted(name for name, count in degree.items() if count == 1)
    # Hubs ordered by descending degree, then name, so the busiest table anchors.
    hub_tables = sorted(
        (name for name, count in degree.items() if count > 1),
        key=lambda name: (-degree[name], name),
    )

    # A simple chain has exactly two endpoints and every other table is interior.
    if len(leaf_tables) == 2 and len(hub_tables) == len(degree) - 2:
        return ("linear", min(leaf_tables), leaf_tables)
    if len(hub_tables) == 1:
        return ("star", hub_tables[0], leaf_tables)
    if hub_tables:
        return ("tree", hub_tables[0], leaf_tables)
    # No hubs and not a two-endpoint chain (e.g. disconnected edges): fall back to linear.
    return ("linear", min(degree), list(degree))
115
+
116
+
117
# Lookahead marking where one inner-JOIN ON clause ends. The stop points are:
# - the next inner JOIN;
# - an outer/cross/natural join (left in place and never reordered);
# - the next major clause;
# - a statement terminator, or end of text.
# ``\b`` and ``\s+BY`` keep identifiers such as ``orders`` or ``grouping_id``
# from being mistaken for ``ORDER BY`` / ``GROUP BY``.
_CANON_JOIN_STOP = (
    r"(?=\s+(?:INNER\s+)?JOIN\b|\s+(?:LEFT|RIGHT|FULL|CROSS|NATURAL)\b"
    r"|\s+WHERE\b|\s+GROUP\s+BY\b|\s+ORDER\s+BY\b|\s+LIMIT\b|\s+HAVING\b|\s*;|\s*$)"
)

# ``FROM <table>`` followed by one or more ``[INNER] JOIN <table> ON <cond>`` clauses.
# Each repetition consumes its own leading whitespace, so a chain of several
# joins is captured as a whole rather than stopping after the first JOIN.
_CANON_FROM_JOINS_RE = re.compile(
    r"\bFROM\s+(\w+)((?:\s+(?:INNER\s+)?JOIN\s+\w+\s+ON\s+[^)]+?" + _CANON_JOIN_STOP + r")+)",
    re.IGNORECASE | re.DOTALL,
)

# A single ``[INNER] JOIN <table> ON <cond>`` clause inside a captured join block.
_CANON_JOIN_CLAUSE_RE = re.compile(
    r"(?:INNER\s+)?JOIN\s+(\w+)\s+ON\s+([^)]+?)" + _CANON_JOIN_STOP,
    re.IGNORECASE | re.DOTALL,
)

# ``table.column`` references inside an ON clause (used for scope checks).
_CANON_QUALIFIED_REF_RE = re.compile(r"\b([A-Za-z_]\w*)\s*\.\s*[A-Za-z_]\w*")


def _canon_parse_join_chain(
    sql: str, expected_first: str
) -> tuple[re.Match[str], list[tuple[str, str]]] | None:
    """Return the FROM/JOIN match and its ``(table, on_clause)`` pairs, or ``None``.

    ``None`` is returned when no inner-join chain is found, or when the chain's
    FROM table is not ``expected_first`` (already lowercased by the caller).
    """
    match = _CANON_FROM_JOINS_RE.search(sql)
    if not match or match.group(1).lower() != expected_first:
        return None
    joins = [(tbl, on.strip()) for tbl, on in _CANON_JOIN_CLAUSE_RE.findall(match.group(2))]
    if not joins:
        return None
    return match, joins


def _canon_on_tables(on_clause: str) -> set[str]:
    """Return the lowercased table qualifiers referenced in an ON clause."""
    return {name.lower() for name in _CANON_QUALIFIED_REF_RE.findall(on_clause)}


def _canon_joins_scoped(start: str, joins: list[tuple[str, str]]) -> bool:
    """Check that each ON clause only references tables already in scope at that JOIN."""
    in_scope = {start.lower()}
    for tbl, on in joins:
        in_scope.add(tbl.lower())
        if not _canon_on_tables(on) <= in_scope:
            return False
    return True


def _canon_order_joins(start: str, joins: list[tuple[str, str]]) -> list[tuple[str, str]] | None:
    """Order joins alphabetically by table, but never before the tables their ON needs.

    Returns ``None`` when no valid order exists.
    """
    in_scope = {start.lower()}
    pending = list(joins)
    ordered: list[tuple[str, str]] = []
    while pending:
        ready = [j for j in pending if _canon_on_tables(j[1]) <= in_scope | {j[0].lower()}]
        if not ready:
            return None
        chosen = min(ready, key=lambda j: j[0].lower())
        pending.remove(chosen)
        ordered.append(chosen)
        in_scope.add(chosen[0].lower())
    return ordered


def reorder_sql_joins_canonical(sql: str, join_sig: list[str]) -> str:
    """Reorder the SQL FROM/JOIN chain to canonical form based on join topology.

    Behaviour depends on the topology of ``join_sig``:

    - **Linear**: a query that starts from the non-anchor endpoint is reversed so
      it starts from the anchor (the alphabetically smallest endpoint).
    - **Star/tree**: a query that already starts from the anchor hub has its
      branch joins sorted alphabetically. A join is never placed before the
      tables its ON clause references.

    Only plain/``INNER`` joins are rewritten; ``INNER`` is dropped in the output.
    Outer/cross joins and anything after them stay where they are.

    Args:

        sql: The SQL string whose FROM/JOIN clause should be reordered.

        join_sig: The join path signature that defines the canonical order.

    Returns:

        SQL string with the FROM clause reordered. The original SQL is returned
        unchanged when any of the following holds:

        - the topology is ``"none"``;
        - the join chain cannot be parsed;
        - the parsed chain does not start at the first FROM table;
        - the reordered form would not be well-scoped.
    """
    if not join_sig:
        return sql
    topology_type, anchor, leaves = _analyze_join_topology(join_sig)
    if topology_type == "none":
        return sql
    from_match = re.search(r"\bFROM\s+(\w+)", sql, re.IGNORECASE)
    if not from_match:
        return sql
    current_first_table = from_match.group(1).lower()
    anchor_lower = anchor.lower()

    if topology_type == "linear":
        if current_first_table == anchor_lower:
            return sql
        other_endpoint = [leaf for leaf in leaves if leaf.lower() != anchor_lower]
        if not other_endpoint or current_first_table != other_endpoint[0].lower():
            return sql
        parsed = _canon_parse_join_chain(sql, current_first_table)
        if parsed is None:
            return sql
        match, joins = parsed
        tables_in_order = [match.group(1)] + [tbl for tbl, _ in joins]
        new_first = tables_in_order[-1]
        # Walking the chain backwards: table i is joined with the ON clause that
        # originally introduced table i + 1.
        reversed_joins = list(zip(reversed(tables_in_order[:-1]), reversed([on for _, on in joins])))
        # Only rewrite when the reversal really lands on the anchor and every ON
        # clause still references tables that are in scope.
        if new_first.lower() != anchor_lower or not _canon_joins_scoped(new_first, reversed_joins):
            return sql
        new_from = f"FROM {new_first}" + "".join(f" JOIN {tbl} ON {on}" for tbl, on in reversed_joins)
        new_sql = sql[: match.start()] + new_from + sql[match.end() :]
        debug(f"[sql_gen.reorder_sql_joins_canonical] linear reordered: {current_first_table} -> {new_first}")
        return new_sql

    if current_first_table != anchor_lower:
        debug(
            f"[sql_gen.reorder_sql_joins_canonical] star/tree topology but FROM table {current_first_table} != anchor {anchor}"
        )
        return sql
    parsed = _canon_parse_join_chain(sql, current_first_table)
    if parsed is None:
        return sql
    match, joins = parsed
    ordered = _canon_order_joins(anchor, joins)
    if ordered is None:
        return sql
    new_from = f"FROM {anchor}" + "".join(f" JOIN {tbl} ON {on}" for tbl, on in ordered)
    new_sql = sql[: match.start()] + new_from + sql[match.end() :]
    debug("[sql_gen.reorder_sql_joins_canonical] star/tree branches sorted alphabetically")
    return new_sql
206
+
207
+
208
# SQL line-comment marker left in deterministic SQL where JOIN clauses belong.
# ``inject_join_into_deterministic_sql`` replaces each occurrence, in order, with
# a JOIN clause built from a join signature, and strips any leftovers.
JOIN_PLACEHOLDER = "-- <JOIN: integrate from join candidates>"
209
+
210
+
211
def _wrap_for_case_insensitive(expr: str, dialect_type: str) -> str:
    """Wrap ``expr`` so it can be compared case-insensitively.

    On Databricks the expression is trimmed before lowercasing, i.e.
    ``LOWER(TRIM(expr))``; every other dialect gets plain ``LOWER(expr)``.
    """
    normalized = f"TRIM({expr})" if dialect_type == "databricks" else expr
    return f"LOWER({normalized})"
220
+
221
+
222
+ def _join_clause_from_signature(signature: list[str], from_table: str = "") -> str:
223
+ """Build JOIN clause text from a join path signature.
224
+
225
+ Each segment is "src_tbl.col->dst_tbl.col". Tracks tables already in
226
+ the chain to avoid duplicate JOINs (e.g. when two edges target the
227
+ same table). When the target is already present, adds the source
228
+ table instead.
229
+ """
230
+ if not signature:
231
+ return ""
232
+ chain: set[str] = {from_table.lower()} if from_table else set()
233
+ parts: list[str] = []
234
+ for seg in signature:
235
+ seg = seg.strip()
236
+ if "->" not in seg:
237
+ continue
238
+ left_part, right_part = seg.split("->", 1)
239
+ left_part = left_part.strip()
240
+ right_part = right_part.strip()
241
+ if "." not in left_part or "." not in right_part:
242
+ continue
243
+ left_tbl, left_cols = left_part.split(".", 1)
244
+ right_tbl, right_cols = right_part.split(".", 1)
245
+ left_col_list = [c.strip() for c in left_cols.split(",")]
246
+ right_col_list = [c.strip() for c in right_cols.split(",")]
247
+ on_terms = [
248
+ f"{left_tbl}.{lc} = {right_tbl}.{rc}"
249
+ for lc, rc in zip(left_col_list, right_col_list, strict=False)
250
+ ]
251
+ if not on_terms:
252
+ continue
253
+ right_tbl_lower = right_tbl.lower()
254
+ left_tbl_lower = left_tbl.lower()
255
+ if right_tbl_lower in chain:
256
+ join_tbl = left_tbl
257
+ chain.add(left_tbl_lower)
258
+ else:
259
+ join_tbl = right_tbl
260
+ chain.add(right_tbl_lower)
261
+ parts.append(f" JOIN {join_tbl} ON " + " AND ".join(on_terms))
262
+ return "".join(parts)
263
+
264
+
265
def _orient_join_sig_for_from(
    sig: list[str],
    from_table: str,
) -> list[str]:
    """Reorient join segments so that no target duplicates the FROM table.

    When the right-hand (target) table of a segment equals the current FROM
    table, the segment is flipped so the other table becomes the JOIN target.
    This prevents ``FROM t JOIN t`` self-join artefacts that occur when
    ``tables[0]`` in the intent happens to sit on the target side of the join
    signature. The table comparison is case-insensitive.

    Args:

        sig: Join path signature segments (``"src.cols->dst.cols"``).

        from_table: Name of the FROM table; an empty value returns ``sig`` unchanged.

    Returns:

        A new list of segments (or ``sig`` itself when ``from_table`` is empty).
        Segments without ``->`` are passed through untouched.
    """
    if not from_table:
        return sig
    # Normalise both sides so mixed-case callers still match.
    target = from_table.strip().lower()
    oriented: list[str] = []
    for seg in sig:
        if "->" not in seg:
            oriented.append(seg)
            continue
        left, right = seg.split("->", 1)
        right_tbl = right.split(".")[0].strip().lower()
        if right_tbl == target:
            oriented.append(f"{right.strip()}->{left.strip()}")
        else:
            oriented.append(seg)
    return oriented
291
+
292
+
293
def inject_join_into_deterministic_sql(
    det_sql: str,
    join_sigs_ordered: list[list[str]],
) -> str:
    """Replace each JOIN placeholder in deterministic SQL with a JOIN clause from signatures.

    Placeholders are replaced in order: the first occurrence uses
    ``join_sigs_ordered[0]``, the second ``join_sigs_ordered[1]``, and so on.

    Before each JOIN clause is built, the signature is oriented against the FROM
    table that governs that placeholder, i.e. the last ``FROM <table>`` that
    precedes it. Placeholders inside later CTEs, or after a scalar subquery, are
    therefore oriented against their own FROM rather than the first FROM in the
    statement.

    Any placeholders left over once the signatures run out are removed when they
    sit on their own line.

    Args:

        det_sql: Deterministic SQL containing ``JOIN_PLACEHOLDER`` markers.

        join_sigs_ordered: One join signature per placeholder, in order.

    Returns:

        The SQL with placeholders replaced. ``det_sql`` is returned unchanged
        when ``join_sigs_ordered`` is empty.
    """
    if not join_sigs_ordered:
        return det_sql
    from_re = re.compile(r"\bFROM\s+(\w+)", re.IGNORECASE)
    result = det_sql
    for sig in join_sigs_ordered:
        placeholder_at = result.find(JOIN_PLACEHOLDER)
        if placeholder_at < 0:
            break
        # Only look before the placeholder: the placeholder text itself contains
        # "from join", and later FROMs belong to other query blocks.
        preceding_froms = list(from_re.finditer(result[:placeholder_at]))
        from_tbl = preceding_froms[-1].group(1) if preceding_froms else ""
        oriented = _orient_join_sig_for_from(sig, from_tbl.lower())
        join_clause = _join_clause_from_signature(oriented, from_tbl)
        result = result.replace(JOIN_PLACEHOLDER, join_clause.strip(), 1)
    result = re.sub(r"\n\s*-- <JOIN[^>]*>\s*", "\n", result)
    return result
317
+
318
+
319
def normalize_where_having_predicates(sql: str) -> str:
    """Normalize WHERE and HAVING predicates to put column references on the left.

    Rewrites ``:param op table.column`` (or a quoted or numeric literal on the
    left) into ``table.column op' :param``, so column references always appear
    on the left-hand side of comparisons.

    When the operands are swapped, ordering operators are mirrored so the
    predicate keeps its meaning (``<`` <-> ``>``, ``<=`` <-> ``>=``). The
    operators ``=``, ``!=`` and ``<>`` are symmetric and stay unchanged.

    Only the first WHERE clause and the first HAVING clause are rewritten.

    Args:

        sql: The SQL string to normalise.

    Returns:

        SQL string with swapped predicates in the WHERE and HAVING clauses.
    """
    # Swapping the operands of an ordering comparison requires mirroring the operator.
    mirrored_ops = {"<": ">", ">": "<", "<=": ">=", ">=": "<="}
    literal_starts = (":", "'", '"')

    def swap_predicate(match: re.Match[str]) -> str:
        full = match.group(0)
        left = match.group(1).strip()
        op = match.group(2).strip()
        right = match.group(3).strip()

        left_is_param = left.startswith(literal_starts) or left[0].isdigit()
        # Numeric literals such as ``1.5`` contain a dot but are not column references.
        right_is_col = "." in right and not right.startswith(literal_starts) and not right[0].isdigit()

        if left_is_param and right_is_col:
            return f"{right} {mirrored_ops.get(op, op)} {left}"
        return full

    pattern = r"([:\w.]+|'[^']*'|\"[^\"]*\")\s*(=|!=|<>|<=|>=|<|>)\s+([:\w.]+|'[^']*'|\"[^\"]*\")"

    def normalize_clause(text: str, keyword: str, terminators: str) -> str:
        """Rewrite predicates between ``keyword`` and the next terminator (or end of text)."""
        clause_match = re.search(keyword, text, re.IGNORECASE)
        if not clause_match:
            return text
        start = clause_match.end()
        next_clause = re.search(terminators, text[start:], re.IGNORECASE)
        end = start + next_clause.start() if next_clause else len(text)
        return text[:start] + re.sub(pattern, swap_predicate, text[start:end]) + text[end:]

    sql = normalize_clause(sql, r"\bWHERE\b", r"\b(GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT)\b")
    sql = normalize_clause(sql, r"\bHAVING\b", r"\b(ORDER\s+BY|LIMIT)\b")
    return sql
371
+
372
+
373
def normalize_cte_sql(sql: str, cte_join_sigs: dict[str, list[str]]) -> str:
    """Canonicalise each CTE body: reorder its joins and normalise its predicates.

    CTE bodies are obtained from the active dialect's ``extract_cte_bodies``.
    For each body:

    - its FROM/JOIN chain is reordered when ``cte_join_sigs`` has a signature for it;
    - its WHERE/HAVING predicates are normalised.

    A changed body is written back by textual replacement of
    ``"<name> AS (<body>)"``.

    Args:

        sql: The full SQL string that may contain WITH/CTE clauses.

        cte_join_sigs: Dict mapping CTE name to join path signature for reordering.

    Returns:

        SQL string with each CTE body's FROM clause reordered and WHERE/HAVING predicates normalised.
    """
    bodies = get_dialect().extract_cte_bodies(sql)
    if not bodies:
        return sql

    rewritten_sql = sql
    for name, body in bodies.items():
        rewritten_body = body
        signature = cte_join_sigs.get(name, [])
        if signature:
            rewritten_body = reorder_sql_joins_canonical(rewritten_body, signature)
        rewritten_body = normalize_where_having_predicates(rewritten_body)
        if rewritten_body == body:
            continue
        rewritten_sql = rewritten_sql.replace(f"{name} AS ({body})", f"{name} AS ({rewritten_body})")

    debug(f"[sql_gen.normalize_cte_sql] normalized {len(bodies)} CTEs")
    return rewritten_sql
404
+
405
+
406
def _join_path_signature_for_path(path: list[dict[str, Any]]) -> list[str]:
    """Render each join edge of ``path`` as a signature string.

    Args:

        path: List of edge dicts, each with ``src_table``, ``src_cols``, ``dst_table``, and ``dst_cols`` keys.

    Returns:

        One string per edge, in order, of the form ``"src_table.col1,col2->dst_table.col3,col4"``.
    """

    def render(edge: dict[str, Any]) -> str:
        source = f"{edge['src_table']}.{','.join(edge['src_cols'])}"
        target = f"{edge['dst_table']}.{','.join(edge['dst_cols'])}"
        return f"{source}->{target}"

    return [render(edge) for edge in path]
421
+
422
+
423
def _candidate_join_paths_for_tables(schema: SchemaGraph, tables: list[str]) -> list[list[dict[str, Any]]]:
    """Compute candidate join paths for a set of tables by trying every table as root.

    For each root, the shortest pre-computed path from the root to every other
    table is looked up in ``schema.join_paths_multi``. Those paths are merged
    into one edge list per root, so at most one candidate is produced per root.

    Direct paths (touching only the requested tables) are tried first. Paths
    through bridge tables are used only when no direct candidate covers all
    tables.

    Args:

        schema: The schema graph containing pre-computed join paths.

        tables: List of table names that must all be reachable in each candidate.

    Returns:

        List of join paths, each a list of edge dicts with source and destination
        table and column keys. Candidates are sorted by edge count, then by
        signature, for deterministic ordering. Returns ``[[]]`` for single-table
        queries.
    """
    # Deduplicate and sort so root iteration order is deterministic.
    tables = sorted(set(tables))
    if len(tables) < 2:
        return [[]]

    def uniq_edges(edges: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # Drop repeated edges, treating A->B and B->A on the same columns as one edge.
        seen: set = set()
        out: list[dict[str, Any]] = []
        for e in edges:
            pair = (
                (e["src_table"], tuple(e["src_cols"])),
                (e["dst_table"], tuple(e["dst_cols"])),
            )
            canonical = tuple(sorted(pair))
            if canonical in seen:
                continue
            seen.add(canonical)
            out.append(e)
        return out

    table_set = set(tables)

    def _edges_cover_tables(edges: list[dict[str, Any]], root: str) -> set[str]:
        # All tables touched by ``edges``, plus the root itself.
        covered = {root}
        for e in edges:
            covered.add(e["src_table"])
            covered.add(e["dst_table"])
        return covered

    def _merge_paths_minimal(
        root: str, others: list[str], allow_bridges: bool
    ) -> list[list[dict[str, Any]]]:
        # Greedily union the shortest root->target path for each uncovered target.
        # Targets already reached by earlier merged paths are skipped.
        covered: set[str] = {root}
        merged: list[dict[str, Any]] = []
        for target in others:
            if target in covered:
                continue
            # NOTE(review): assumes join_paths_multi is {src: {dst: [path, ...]}}, each path a list of edge dicts.
            paths = schema.join_paths_multi.get(root, {}).get(target, [])
            if not paths:
                continue
            best: list[dict[str, Any]] | None = None
            for p in paths:
                if not p:
                    continue
                path_tables = _edges_cover_tables(p, root)
                if target not in path_tables:
                    continue
                # Without bridges, reject paths that pass through tables outside the request.
                if not allow_bridges and not path_tables <= table_set:
                    continue
                # Keep the first shortest path (strict ``<`` preserves list order on ties).
                if best is None or len(p) < len(best):
                    best = p
            if best:
                for e in best:
                    if e not in merged:
                        merged.append(e)
                covered = _edges_cover_tables(merged, root)
        return [merged] if merged else []

    def _collect(allow_bridges: bool) -> dict[tuple, list[dict[str, Any]]]:
        # Map signature tuple -> edge list, keeping the first root that produced each signature.
        candidates: dict[tuple, list[dict[str, Any]]] = {}
        for root in tables:
            others = [t for t in tables if t != root]
            for merged in _merge_paths_minimal(root, others, allow_bridges):
                deduped = uniq_edges(merged)
                edge_tables = (
                    {root}
                    | {e["src_table"] for e in deduped}
                    | {e["dst_table"] for e in deduped}
                )
                # Every requested table must be reachable.
                if not table_set <= edge_tables:
                    continue
                if not allow_bridges and not edge_tables <= table_set:
                    continue
                sig = tuple(_join_path_signature_for_path(deduped))
                if sig not in candidates:
                    candidates[sig] = deduped
        return candidates

    all_candidates = _collect(allow_bridges=False)
    if not all_candidates:
        # Fall back to paths that route through intermediate (bridge) tables.
        all_candidates = _collect(allow_bridges=True)
        if all_candidates:
            debug(
                f"[sql_gen.candidate_join_paths_for_tables] no direct paths, found {len(all_candidates)} bridge paths"
            )

    res = list(all_candidates.values())
    # Fewest edges first; the signature tuple breaks ties deterministically.
    res.sort(key=lambda m: (len(m), tuple(_join_path_signature_for_path(m))))
    return res
526
+
527
+
528
def _score_join_path(edges: list[dict[str, Any]], intent: RuntimeIntent, schema: SchemaGraph) -> float:
    """Score a join path based on FK direction, intent alignment, and path characteristics.

    Higher scores indicate more semantically appropriate join paths. The score
    starts at ``max(20 - 3 * len(edges), 0)``, so shorter paths score higher.

    Then, for each edge whose two tables are both present in ``schema.tables``:

    - FK direction: +10 for a forward FK (declared on the source table), +5 otherwise.
    - Filter/group alignment on destination columns: +15 when a destination
      column is filtered, +10 when it is grouped.
    - Filter alignment on source columns: +12 when a filtered source column is
      part of a forward FK whose name matches the ``*_id`` heuristic.
    - Aggregation: +8 when the source table holds an aggregated select column.
    - Table roles: +3 for a DIMENSION destination, +5 extra for FACT->DIMENSION,
      -10 for FACT->FACT, -5 for a BRIDGE destination.

    Args:

        edges: List of join edge dicts for the candidate path.

        intent: The ``RuntimeIntent`` providing filter, group-by, and aggregation context.

        schema: The schema graph for table role and FK metadata.

    Returns:

        Float score; higher is better.
    """

    def _qualified_pairs(columns: Any) -> set[tuple[str, str]]:
        # Keep only exactly-qualified ``table.column`` references.
        pairs: set[tuple[str, str]] = set()
        for column in columns:
            col_parts = column.split(".")
            if len(col_parts) == 2:
                pairs.add((col_parts[0], col_parts[1]))
        return pairs

    filter_columns = _qualified_pairs(fp.left_expr.primary_column for fp in intent.filters_param or [])
    groupby_columns = _qualified_pairs(gb.primary_column for gb in intent.group_by_cols or [])
    agg_tables = {
        table
        for table, _ in _qualified_pairs(
            sc.expr.primary_column for sc in intent.select_cols or [] if sc.is_aggregated
        )
    }

    score = 0.0
    score += max(20 - (len(edges) * 3), 0)

    for edge in edges:
        src_table = edge["src_table"]
        dst_table = edge["dst_table"]
        src_cols = edge["src_cols"]
        dst_cols = edge["dst_cols"]

        src_meta = schema.tables.get(src_table)
        dst_meta = schema.tables.get(dst_table)
        if not src_meta or not dst_meta:
            continue

        # Forward means the source table declares an FK matching this edge exactly.
        is_forward = any(
            fk.dst_table == dst_table and set(fk.src_cols) == set(src_cols) and set(fk.dst_cols) == set(dst_cols)
            for fk in src_meta.foreign_keys
        )
        score += 10 if is_forward else 5

        for dst_col in dst_cols:
            if (dst_table, dst_col) in filter_columns:
                score += 15
            if (dst_table, dst_col) in groupby_columns:
                score += 10

        for src_col in src_cols:
            if (src_table, src_col) in filter_columns:
                # Heuristic: the filtered FK column names the referenced column/table.
                if is_forward and any(
                    dst_col == src_col.replace("_id", "") or src_col.endswith(f"_{dst_table}_id")
                    for dst_col in dst_cols
                ):
                    score += 12

        if src_table in agg_tables:
            score += 8

        src_role = src_meta.role or ""
        dst_role = dst_meta.role or ""
        if dst_role == "DIMENSION":
            score += 3
        if src_role == "FACT" and dst_role == "DIMENSION":
            score += 5
        if src_role == "FACT" and dst_role == "FACT":
            score -= 10
        if dst_role == "BRIDGE":
            score -= 5

    return score
628
+
629
+
630
def _format_join_candidate_semantic(
    candidate_id: str, edges: list[dict[str, Any]], schema: SchemaGraph, score: float
) -> str:
    """Describe a join candidate for the LLM, labelling each edge's FK direction.

    An edge is labelled ``Forward FK`` when its source table (if present in
    ``schema.tables``) declares a foreign key to the destination table over
    exactly the same column sets; otherwise it is labelled ``Reverse FK``.

    Args:

        candidate_id: The candidate identifier string (for example, ``"J01"``).

        edges: List of join edge dicts for this candidate.

        schema: The schema graph for FK direction lookup.

        score: The numeric score assigned to this candidate.

    Returns:

        A header line with the score (one decimal) followed by one line per edge.
        For an empty ``edges`` list, a single "Single table (no joins)" line.
    """
    if not edges:
        return f"{candidate_id}: Single table (no joins)"

    rendered = [f"{candidate_id} [Score: {score:.1f}]:"]
    for edge in edges:
        src_table, dst_table = edge["src_table"], edge["dst_table"]
        src_cols, dst_cols = edge["src_cols"], edge["dst_cols"]

        source_meta = schema.tables.get(src_table)
        declared_fks = source_meta.foreign_keys if source_meta else []
        forward = any(
            fk.dst_table == dst_table
            and set(fk.src_cols) == set(src_cols)
            and set(fk.dst_cols) == set(dst_cols)
            for fk in declared_fks
        )
        direction = "Forward FK" if forward else "Reverse FK"

        source_ref = f"{src_table}.{','.join(src_cols)}"
        target_ref = f"{dst_table}.{','.join(dst_cols)}"
        rendered.append(f" {source_ref} -> {target_ref} ({direction})")

    return "\n".join(rendered)
678
+
679
+
680
def _llm_rank_join_candidates(
    candidates: list[dict[str, Any]], intent: RuntimeIntent, schema: SchemaGraph
) -> list[int]:
    """Use LLM to rank ambiguous join candidates when scores are tied.

    Only the first three candidates are described to the LLM. The returned
    ranking is sanitised so that every index in ``range(len(candidates))``
    appears exactly once:

    - unknown, non-string, or repeated IDs from the LLM are ignored;
    - candidates the LLM omitted are appended in their original order.

    Args:

        candidates: List of scored candidate dicts (each with ``edges``, ``score``, and ``candidate_id`` keys).

        intent: The ``RuntimeIntent`` for query context.

        schema: The schema graph.

    Returns:

        List of integer indices into ``candidates`` in LLM-ranked order (best
        first). Falls back to the original order if the LLM call fails or returns
        a malformed result.
    """
    if len(candidates) <= 1:
        return list(range(len(candidates)))

    filter_desc = [f"{fp.left_expr.primary_column} {fp.op}" for fp in intent.filters_param or []]
    agg_desc = [sc.expr.primary_term for sc in intent.select_cols or [] if sc.is_aggregated]
    groupby_desc = ", ".join(g.primary_column for g in (intent.group_by_cols or []))

    intent_summary = (
        f"Tables: {intent.tables}\n"
        f"Filters: {', '.join(filter_desc) if filter_desc else 'none'}\n"
        f"Aggregations: {', '.join(agg_desc) if agg_desc else 'none'}\n"
        f"Group By: {groupby_desc if groupby_desc else 'none'}\n"
        f"Grain: {intent.grain}"
    )

    candidate_descriptions = []
    for idx, cand in enumerate(candidates[:3]):
        edges = cand.get("edges", [])
        cand_id = cand.get("candidate_id", f"J{idx + 1:02d}")
        score = cand.get("score", 0.0)
        candidate_descriptions.append(_format_join_candidate_semantic(cand_id, edges, schema, score))

    system_prompt = (
        "You are a SQL join path validator. Rank join paths by semantic correctness for the given query intent.\n\n"
        "Output Requirements:\n"
        "- Output ONLY valid JSON matching the specified output_schema.\n"
        "- Do NOT include markdown code blocks, explanations, or commentary.\n"
        "- Identical inputs must produce identical outputs.\n\n"
        "Ranking Criteria:\n"
        "1. FK direction: Forward FKs (natural flow) preferred over reverse FKs.\n"
        "2. Filter alignment: Joins connecting directly to filtered columns score higher.\n"
        "3. Semantic correctness: Does the join path match business intent?\n"
        "4. Simplicity: Shorter, more intuitive paths preferred."
    )

    user_prompt = stable_json(
        {
            "task": "Rank SQL join path candidates by semantic correctness for the given query intent.",
            "intent_summary": intent_summary,
            "join_path_candidates": candidate_descriptions,
            "output_schema": {
                "ranked_ids": ["J01", "J02", "J03"],
                "reasoning": "Brief explanation of ranking",
            },
            "instructions": "Rank the candidates from best to worst based on semantic fit for this query.",
        }
    )

    result = llm_json(system_prompt, user_prompt, task="sql")
    if not isinstance(result, dict) or "ranked_ids" not in result:
        debug("[sql_gen.llm_rank_join_candidates] LLM ranking failed, using original order")
        return list(range(len(candidates)))

    ranked_ids = result["ranked_ids"]
    id_to_idx = {cand.get("candidate_id", f"J{i + 1:02d}"): i for i, cand in enumerate(candidates)}

    ranked_indices: list[int] = []
    if isinstance(ranked_ids, (list, tuple)):
        for cand_id in ranked_ids:
            # Skip unhashable/unknown IDs and repeats: a duplicate index would make
            # the caller emit the same candidate twice.
            if not isinstance(cand_id, str):
                continue
            idx = id_to_idx.get(cand_id)
            if idx is not None and idx not in ranked_indices:
                ranked_indices.append(idx)

    for i in range(len(candidates)):
        if i not in ranked_indices:
            ranked_indices.append(i)

    debug(f"[sql_gen.llm_rank_join_candidates] LLM ranked: {ranked_ids}, reasoning: {result.get('reasoning', 'none')}")
    return ranked_indices
772
+
773
+
774
def _rank_join_candidates(
    candidates: list[list[dict[str, Any]]], intent: RuntimeIntent, schema: SchemaGraph
) -> list[list[dict[str, Any]]]:
    """Order join candidates by deterministic score, with optional LLM tie-breaking.

    Each candidate is scored with ``_score_join_path`` and sorted by descending
    score; ties keep their input order.

    When the best two scores are within 5 points, the top three are labelled
    ``J01``–``J03`` and handed to ``_llm_rank_join_candidates``. Its ordering
    replaces the head of the list, and all other candidates follow in score order.

    Args:

        candidates: List of join paths (each a list of edge dicts).

        intent: The ``RuntimeIntent`` for scoring context.

        schema: The schema graph.

    Returns:

        List of join paths sorted from most to least preferred.
    """
    if len(candidates) <= 1:
        return candidates

    scored = sorted(
        ({"edges": path, "score": _score_join_path(path, intent, schema)} for path in candidates),
        key=lambda entry: entry["score"],
        reverse=True,
    )

    leader_score, runner_up_score = scored[0]["score"], scored[1]["score"]
    if abs(leader_score - runner_up_score) <= 5.0:
        debug(
            f"[sql_gen.rank_join_candidates] top scores within threshold: {leader_score:.1f} vs {runner_up_score:.1f}, invoking LLM"
        )
        head = scored[:3]
        for position, entry in enumerate(head, start=1):
            entry["candidate_id"] = f"J{position:02d}"

        llm_order = _llm_rank_join_candidates(scored[:3], intent, schema)
        promoted = [scored[i] for i in llm_order if i < len(scored)]
        consumed = set(llm_order[: len(promoted)])
        scored = promoted + [entry for i, entry in enumerate(scored) if i not in consumed]

    ranked = [entry["edges"] for entry in scored]
    best_score = scored[0]["score"] if scored else 0.0
    debug(f"[sql_gen.rank_join_candidates] ranked {len(ranked)} candidates, top_score={best_score:.1f}")
    return ranked
827
+
828
+
829
def physical_tables_for_join_hints(
    tables: list[str] | None,
    schema: SchemaGraph,
) -> list[str]:
    """Return physical table names from ``tables`` that exist in ``schema``.

    Matching is case-insensitive against the keys of ``schema.tables``, and the
    canonical key spelling is returned. Empty or ``None`` entries, CTE aliases and
    unknown names are dropped. First-seen order is preserved, without duplicates.

    Args:

        tables: Declared table list, possibly mixing CTE names and bases.

        schema: Loaded schema graph.

    Returns:

        Deduped list of canonical table keys from ``schema.tables``.
    """
    if not tables:
        return []
    canonical_by_lower = {name.lower(): name for name in schema.tables}
    resolved = (canonical_by_lower.get(name.lower()) for name in tables if name)
    # dict.fromkeys keeps insertion order, so this dedupes while preserving first-seen order.
    return list(dict.fromkeys(key for key in resolved if key is not None))
860
+
861
+
862
def join_hints_multi(schema: SchemaGraph, tables: list[str], intent: RuntimeIntent | None = None) -> dict[str, Any]:
    """Generate join hint candidates for SQL generation with deterministic ranking.

    A request naming at most one distinct table returns the single ``J00``
    candidate. This also covers repeated names such as ``["a", "a"]`` and a
    ``None``/empty list, and candidate path enumeration is skipped entirely.

    Otherwise candidates come from ``_candidate_join_paths_for_tables`` and are
    ranked with ``_rank_join_candidates`` when ``intent`` is given. They are then
    numbered ``J01``, ``J02``, … in ranked order.

    Args:

        schema: The schema graph.

        tables: List of table names to join.

        intent: Optional ``RuntimeIntent`` used for score-based ranking.

    Returns:

        Dict with a ``"candidates"`` list, each entry containing ``candidate_id``, ``join_path_signature``, and ``edge_count``.
    """
    # Count distinct names so duplicated single tables still take the J00 path,
    # matching _candidate_join_paths_for_tables, which also deduplicates.
    if len(set(tables or [])) <= 1:
        debug("[sql_gen.join_hints_multi] single table, returning J00")
        return {
            "candidates": [
                {
                    "candidate_id": "J00",
                    "join_path_signature": [],
                    "edge_count": 0,
                }
            ]
        }

    candidates = _candidate_join_paths_for_tables(schema, tables)
    debug(f"[sql_gen.join_hints_multi] tables={tables}, raw_candidates={len(candidates)}")

    if intent:
        debug(f"[sql_gen.join_hints_multi] ranking {len(candidates)} candidates with intent context")
        ranked_candidates = _rank_join_candidates(candidates, intent, schema)
    else:
        debug("[sql_gen.join_hints_multi] no intent provided, using original order")
        ranked_candidates = candidates

    out = [
        {
            "candidate_id": f"J{idx + 1:02d}",
            "join_path_signature": _join_path_signature_for_path(edges),
            "edge_count": len(edges),
        }
        for idx, edges in enumerate(ranked_candidates)
    ]
    debug(f"[sql_gen.join_hints_multi] generated {len(out)} candidates")
    return {"candidates": out}
910
+
911
+
912
+ def _format_sql_arg(k: str, v: Any) -> str:
913
+ """Format a scalar function argument for SQL expression guide.
914
+
915
+ Args:
916
+
917
+ k: Parameter key (used as ``:k`` placeholder); empty string means use literal.
918
+
919
+ v: Argument value used when ``k`` is empty.
920
+
921
+ Returns:
922
+
923
+ Parameter placeholder string (for example, ``:key``) or a quoted or numeric literal.
924
+ """
925
+ if k:
926
+ return f":{k}"
927
+ if isinstance(v, str):
928
+ return f"'{v}'"
929
+ return str(v)
930
+
931
+
932
def _render_group_sql(g: MulGroup) -> str:
    """Render one ``MulGroup`` as a SQL fragment for the expression guide.

    The product of ``multiply`` (optionally divided by the product of
    ``divide``) is scaled by the coefficient placeholder or a non-unit numeric
    coefficient. It is then wrapped, innermost first, by the inner scalar
    function, the aggregate, and the outer scalar function. An outer
    ``EXTRACT`` with arguments uses the ``EXTRACT(arg FROM expr)`` form.
    Functions listed in ``SCALAR_FUNCTIONS_LEADING_ARG`` take their arguments
    before the operand.

    Args:

        g: The ``MulGroup`` to render.

    Returns:

        SQL fragment string, for example ``"ROUND(SUM(:coeff * table.col), 2)"``;
        ``"1"`` when the group has no multiply terms.
    """

    def _arg_list(keys: list | None, values: list | None) -> str:
        # Values paired with a non-empty key become placeholders; the rest are literals.
        key_seq = list(keys or [])
        return ", ".join(
            _format_sql_arg(key_seq[pos] if pos < len(key_seq) else "", value)
            for pos, value in enumerate(values or [])
        )

    def _call(func: str, operand: str, args_str: str, *, extract_aware: bool) -> str:
        lowered = func.lower()
        if extract_aware and args_str and lowered == "extract":
            return f"EXTRACT({args_str} FROM {operand})"
        if args_str and lowered in SCALAR_FUNCTIONS_LEADING_ARG:
            return f"{func.upper()}({args_str}, {operand})"
        if args_str:
            return f"{func.upper()}({operand}, {args_str})"
        return f"{func.upper()}({operand})"

    if not g.multiply:
        return "1"

    sql = " * ".join(g.multiply)
    if g.divide:
        sql = f"({sql}) / ({' * '.join(g.divide)})"

    if g.coeff_param_key:
        sql = f":{g.coeff_param_key} * {sql}"
    elif g.coefficient != 1.0:
        sql = f"{g.coefficient} * {sql}"

    if g.inner_scalar_func:
        sql = _call(
            g.inner_scalar_func,
            sql,
            _arg_list(g.isarg_param_keys, g.inner_scalar_func_args),
            extract_aware=False,
        )

    if g.agg_func:
        sql = f"{g.agg_func.upper()}({sql})"

    if g.scalar_func:
        sql = _call(
            g.scalar_func,
            sql,
            _arg_list(g.sarg_param_keys, g.scalar_func_args),
            extract_aware=True,
        )

    return sql
979
+
980
+
981
def _render_expr_sql(expr: NormalizedExpr) -> str:
    """Render a ``NormalizedExpr`` as the SQL fragment shown in the expression guide.

    Additive groups and values are joined with ``+`` (``"0"`` when there are
    none); subtractive groups and values follow, each prefixed by ``-``.
    Expression-level functions are then applied in this order: inner scalar
    function, aggregate, outer scalar function. Each is applied only when no
    additive group already carries a function of the same kind. An outer
    ``EXTRACT`` with arguments uses the ``EXTRACT(arg FROM expr)`` form.

    Args:

        expr: The ``NormalizedExpr`` to render.

    Returns:

        SQL fragment string that the LLM should produce for this expression.
    """

    def _term(value: Any) -> str:
        # Parameterised constants render as placeholders; others via str().
        return f":{value.param_key}" if value.param_key else str(value.value)

    def _arg_list(keys: list | None, values: list | None) -> str:
        key_seq = list(keys or [])
        return ", ".join(
            _format_sql_arg(key_seq[pos] if pos < len(key_seq) else "", value)
            for pos, value in enumerate(values or [])
        )

    def _call(func: str, operand: str, args_str: str, *, extract_aware: bool) -> str:
        lowered = func.lower()
        if extract_aware and args_str and lowered == "extract":
            return f"EXTRACT({args_str} FROM {operand})"
        if args_str and lowered in SCALAR_FUNCTIONS_LEADING_ARG:
            return f"{func.upper()}({args_str}, {operand})"
        if args_str:
            return f"{func.upper()}({operand}, {args_str})"
        return f"{func.upper()}({operand})"

    added = [_render_group_sql(g) for g in expr.add_groups] + [_term(v) for v in expr.add_values]
    subtracted = [_render_group_sql(g) for g in expr.sub_groups] + [_term(v) for v in expr.sub_values]

    sql = " + ".join(added) if added else "0"
    if subtracted:
        sql = f"{sql} - {' - '.join(subtracted)}"

    if expr.inner_scalar_func and not any(g.inner_scalar_func for g in expr.add_groups):
        sql = _call(
            expr.inner_scalar_func,
            sql,
            _arg_list(expr.isarg_param_keys, expr.inner_scalar_func_args),
            extract_aware=False,
        )

    if expr.agg_func and not any(g.agg_func for g in expr.add_groups):
        sql = f"{expr.agg_func.upper()}({sql})"

    if expr.scalar_func and not any(g.scalar_func for g in expr.add_groups):
        sql = _call(
            expr.scalar_func,
            sql,
            _arg_list(expr.sarg_param_keys, expr.scalar_func_args),
            extract_aware=True,
        )

    return sql
1037
+
1038
+
1039
def build_deterministic_sql(
    intent: RuntimeIntent,
    cte_join_hints: dict[str, dict[str, Any]] | None = None,
) -> str:
    """Build a rough deterministic SQL from a RuntimeIntent.

    The output is structurally correct but may lack JOIN clauses and
    dialect-specific syntax. It serves as a constrained template for the SQL
    LLM and as a reference for post-generation validation.

    Each CTE step becomes a ``name AS (...)`` entry of a single ``WITH``
    clause. Its SELECT expressions are aliased positionally with the step's
    ``output_columns``. The main query follows. All blocks are rendered by
    ``_build_deterministic_select_block``, which inserts a ``-- <JOIN ...>``
    placeholder where the LLM should integrate the chosen join predicates.

    Args:
        intent: The runtime intent to render.
        cte_join_hints: Accepted for interface compatibility; not read by this function.

    Returns:
        The deterministic SQL text.
    """
    sections: list[str] = []

    steps = intent.cte_steps or []
    if steps:
        cte_definitions: list[str] = []
        for step in steps:
            step_sql = _build_deterministic_select_block(
                select_cols=step.select_cols or [],
                tables=step.tables or [],
                group_by_cols=step.group_by_cols or [],
                order_by_cols=step.order_by_cols or [],
                filters_param=step.filters_param or [],
                having_param=step.having_param or [],
                limit=step.limit,
                grain=step.grain or "row_level",
                output_aliases=step.output_columns or [],
            )
            cte_definitions.append(f"{step.cte_name} AS (\n{step_sql}\n)")
        sections.append("WITH " + ",\n".join(cte_definitions))

    sections.append(
        _build_deterministic_select_block(
            select_cols=intent.select_cols or [],
            tables=intent.tables or [],
            group_by_cols=intent.group_by_cols or [],
            order_by_cols=intent.order_by_cols or [],
            filters_param=intent.filters_param or [],
            having_param=intent.having_param or [],
            limit=intent.limit,
            grain=intent.grain or "row_level",
        )
    )

    return "\n".join(sections)
1082
+
1083
+
1084
+ def _join_clause_parts_with_bool_op(
1085
+ parts: list[tuple[str, str]],
1086
+ ) -> str:
1087
+ """Chain SQL clause fragments using their positional boolean operators.
1088
+
1089
+ Each element's ``bool_op`` is the connector between that element and
1090
+ the next. The last element's ``bool_op`` is unused. Fragments are
1091
+ joined sequentially to preserve the canonical ordering established
1092
+ by ``_canonicalize_condition_order``.
1093
+
1094
+ When any ``OR`` connector is present the entire expression is wrapped
1095
+ in parentheses to maintain correct SQL precedence in outer contexts.
1096
+
1097
+ Args:
1098
+ parts: List of ``(sql_fragment, bool_op)`` tuples where
1099
+ ``bool_op`` is ``"AND"`` or ``"OR"``.
1100
+
1101
+ Returns:
1102
+ Combined SQL predicate string.
1103
+ """
1104
+ if not parts:
1105
+ return ""
1106
+
1107
+ result = parts[0][0]
1108
+ for i in range(1, len(parts)):
1109
+ connector = parts[i - 1][1]
1110
+ result = f"{result} {connector} {parts[i][0]}"
1111
+
1112
+ has_or = any(op == "OR" for _, op in parts[:-1])
1113
+ if has_or and len(parts) > 1:
1114
+ result = f"({result})"
1115
+
1116
+ return result
1117
+
1118
+
1119
def _build_deterministic_select_block(
    select_cols: list[SelectCol],
    tables: list[str],
    group_by_cols: list[NormalizedExpr],
    order_by_cols: list,
    filters_param: list,
    having_param: list,
    limit: int | None,
    grain: str,
    output_aliases: list[str] | None = None,
) -> str:
    """Build a single SELECT block from structured intent clauses.

    Renders SELECT, FROM (followed by a ``-- <JOIN ...>`` placeholder when more
    than one table is listed), WHERE, GROUP BY, HAVING, ORDER BY, and LIMIT
    clauses, one clause per line.

    Args:
        select_cols: SELECT columns; each ``expr`` is rendered via ``_render_expr_sql``.
        tables: Tables of the block; the first one becomes the FROM table.
        group_by_cols: GROUP BY expressions.
        order_by_cols: ORDER BY items (``expr`` plus optional ``direction``; default ``ASC``).
        filters_param: WHERE filters; consecutive predicates are chained by each
            filter's ``bool_op`` (default ``AND``).
        having_param: HAVING predicates, chained the same way.
        limit: Row limit; omitted when falsy.
        grain: Intent grain; accepted for call-site symmetry but not read here.
        output_aliases: Optional aliases applied positionally to SELECT expressions.

    Returns:
        The SELECT block as newline-separated clauses.
    """
    lines: list[str] = []

    select_exprs: list[str] = []
    for idx, sc in enumerate(select_cols):
        rendered = _render_expr_sql(sc.expr)
        if output_aliases and idx < len(output_aliases):
            rendered = f"{rendered} AS {output_aliases[idx]}"
        select_exprs.append(rendered)

    lines.append("SELECT " + ", ".join(select_exprs))

    if tables:
        lines.append(f"FROM {tables[0]}")
        if len(tables) > 1:
            lines.append("-- <JOIN: integrate from join candidates>")

    where_parts: list[tuple[str, str]] = []
    dialect_type = EngineConfig.TYPE or "postgresql"
    for fp in filters_param:
        left = _render_expr_sql(fp.left_expr)
        op = fp.op or "="
        op_lc = op.lower()
        case_insensitive = fp.value_type == "string" and op_lc not in (
            "is null", "is not null", "ilike", "not ilike",
        )
        if case_insensitive:
            left = _wrap_for_case_insensitive(left, dialect_type)
        bool_op = getattr(fp, "bool_op", "AND") or "AND"
        val_needs_lower = case_insensitive and op_lc in ("like", "not like")
        if op_lc in ("is null", "is not null"):
            where_parts.append((f"{left} {op.upper()}", bool_op))
        elif fp.value_type == "date_window" and isinstance(fp.raw_value, dict):
            dw_frags = _render_date_window_where(fp, left, dialect_type)
            last_idx = len(dw_frags) - 1
            for frag_idx, dw_frag in enumerate(dw_frags):
                # Fragments of one window are ANDed together; only the last one
                # carries this filter's own connector to the next predicate, so
                # an OR on a date filter is no longer silently turned into AND.
                where_parts.append((dw_frag, bool_op if frag_idx == last_idx else "AND"))
        elif fp.value_type == "date_diff" and isinstance(fp.raw_value, dict):
            rv = fp.raw_value
            unit = rv.get("unit", "day")
            amount = int(rv.get("amount", 0)) if rv.get("amount") is not None else 0
            diff_op = fp.op or ">"
            frag = render_date_diff_expr(dialect_type, left, diff_op, unit, amount)
            # Preserve the filter's declared connector (previously forced to AND).
            where_parts.append((frag, bool_op))
        elif fp.right_expr:
            right = _render_expr_sql(fp.right_expr)
            if case_insensitive:
                right = _wrap_for_case_insensitive(right, dialect_type)
            where_parts.append((f"{left} {op} {right}", bool_op))
        elif fp.param_key:
            val_ref = f"LOWER(:{fp.param_key})" if val_needs_lower else f":{fp.param_key}"
            where_parts.append((f"{left} {op} {val_ref}", bool_op))
        elif fp.raw_value is not None:
            # param_key is falsy on this branch, so the placeholder is always ``:p?``.
            val_ref = "LOWER(:p?)" if val_needs_lower else ":p?"
            where_parts.append((f"{left} {op} {val_ref}", bool_op))
    if where_parts:
        lines.append("WHERE " + _join_clause_parts_with_bool_op(where_parts))

    if group_by_cols:
        gb_exprs = [_render_expr_sql(g) for g in group_by_cols]
        lines.append("GROUP BY " + ", ".join(gb_exprs))

    having_parts: list[tuple[str, str]] = []
    for hp in having_param:
        left = _render_expr_sql(hp.left_expr)
        op = hp.op or ">"
        bool_op = getattr(hp, "bool_op", "AND") or "AND"
        if hp.right_expr:
            right = _render_expr_sql(hp.right_expr)
            having_parts.append((f"{left} {op} {right}", bool_op))
        elif hp.param_key:
            having_parts.append((f"{left} {op} :{hp.param_key}", bool_op))
        else:
            having_parts.append((f"{left} {op} ?", bool_op))
    if having_parts:
        lines.append("HAVING " + _join_clause_parts_with_bool_op(having_parts))

    if order_by_cols:
        ob_exprs = []
        for obc in order_by_cols:
            rendered = _render_expr_sql(obc.expr)
            direction = obc.direction.upper() if obc.direction else "ASC"
            ob_exprs.append(f"{rendered} {direction}")
        lines.append("ORDER BY " + ", ".join(ob_exprs))

    if limit:
        lines.append(f"LIMIT {limit}")

    return "\n".join(lines)
1221
+
1222
+
1223
+ def _generate_col_alias(sc: SelectCol) -> str:
1224
+ """Build a deterministic display alias from a SelectCol's expression metadata.
1225
+
1226
+ Rules:
1227
+ * Plain column ``table.col`` → ``col``.
1228
+ * Aggregate ``COUNT(table.col)`` → ``count_col``.
1229
+ * Distinct aggregate ``COUNT(DISTINCT table.col)`` → ``count_distinct_col``.
1230
+ * Scalar wrapper ``ROUND(SUM(table.col), 2)`` → ``round_sum_col``.
1231
+ * Arithmetic ``table.a * table.b`` → ``a_times_b``.
1232
+ * Fallback: ``col_<idx>`` assigned by the caller.
1233
+
1234
+ Args:
1235
+ sc: The ``SelectCol`` to derive an alias for.
1236
+
1237
+ Returns:
1238
+ A lowercase alias string safe for SQL ``AS`` usage.
1239
+ """
1240
+ expr = sc.expr
1241
+ col = expr.primary_column
1242
+ if col:
1243
+ col_clean = col.rsplit(".", 1)[-1].lower()
1244
+ else:
1245
+ col_clean = ""
1246
+
1247
+ groups = expr.add_groups or []
1248
+ if len(groups) >= 2 and not expr.agg_func and not expr.scalar_func:
1249
+ parts = [g.multiply[0].rsplit(".", 1)[-1].lower() if g.multiply else "x" for g in groups]
1250
+ alias = "_times_".join(parts)
1251
+ elif expr.sub_groups and groups:
1252
+ plus_part = groups[0].multiply[0].rsplit(".", 1)[-1].lower() if groups[0].multiply else "x"
1253
+ minus_part = (
1254
+ expr.sub_groups[0].multiply[0].rsplit(".", 1)[-1].lower()
1255
+ if expr.sub_groups[0].multiply
1256
+ else "y"
1257
+ )
1258
+ alias = f"{plus_part}_minus_{minus_part}"
1259
+ elif col_clean:
1260
+ alias = col_clean
1261
+ else:
1262
+ return ""
1263
+
1264
+ distinct_prefix = ""
1265
+ if groups and groups[0].multiply:
1266
+ term = groups[0].multiply[0].upper()
1267
+ if "DISTINCT " in term:
1268
+ distinct_prefix = "distinct_"
1269
+
1270
+ if expr.agg_func:
1271
+ alias = f"{expr.agg_func}_{distinct_prefix}{alias}"
1272
+ if expr.inner_scalar_func:
1273
+ alias = f"{expr.inner_scalar_func}_{alias}"
1274
+ if expr.scalar_func:
1275
+ alias = f"{expr.scalar_func}_{alias}"
1276
+
1277
+ return alias.lower()
1278
+
1279
+
1280
def _sql_top_level_mask(sql: str) -> list[bool]:
    """Flag each character of ``sql`` that lies at paren depth 0 outside quotes/comments.

    Quoted runs (``'...'``, ``"..."``, `` `...` ``) and ``--`` line comments
    are skipped entirely. A doubled ``''`` inside a string literal is handled
    naturally: the string closes and immediately reopens. Parentheses
    themselves are never flagged.

    Args:
        sql: SQL text to scan.

    Returns:
        A list of booleans, one per character of ``sql``.
    """
    mask = [False] * len(sql)
    depth = 0
    pos = 0
    length = len(sql)
    while pos < length:
        ch = sql[pos]
        if ch in ("'", '"', "`"):
            close = sql.find(ch, pos + 1)
            pos = length if close == -1 else close + 1
            continue
        if ch == "-" and sql.startswith("--", pos):
            newline = sql.find("\n", pos)
            pos = length if newline == -1 else newline
            continue
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(depth - 1, 0)
        else:
            mask[pos] = depth == 0
        pos += 1
    return mask


def deterministic_alias_sql(sql_param: str, intent: RuntimeIntent) -> str:
    """Add deterministic display aliases to each expression of the main SELECT.

    Locates the outermost ``SELECT`` (so the SELECTs inside a leading ``WITH``
    clause are skipped) and its matching outermost ``FROM``. A ``FROM`` nested
    in parentheses, such as ``EXTRACT(year FROM col)``, is therefore not
    mistaken for the clause boundary. The expressions between them are split
    on top-level commas, ignoring commas and parentheses inside quoted
    literals. Each expression is matched positionally with the intent's
    ``select_cols`` and gets an ``AS alias`` clause from
    ``_generate_col_alias`` (falling back to ``col_<n>``, with ``_2``, ``_3``
    ... suffixes to keep aliases unique).

    Args:
        sql_param: Parameterized SQL string produced by ``build_deterministic_sql``.
        intent: The ``RuntimeIntent`` whose ``select_cols`` drive aliasing.

    Returns:
        SQL string with ``AS`` aliases on every main SELECT expression. Returns
        the original SQL unchanged when the SELECT clause cannot be located
        or the expression count does not match ``select_cols``.
    """
    mask = _sql_top_level_mask(sql_param)

    select_match = next(
        (m for m in re.finditer(r"(?i)\bSELECT\s+", sql_param) if mask[m.start()]),
        None,
    )
    if select_match is None:
        return sql_param
    select_start = select_match.end()

    from_match = next(
        (m for m in re.compile(r"(?i)\bFROM\b").finditer(sql_param, select_start) if mask[m.start()]),
        None,
    )
    if from_match is None:
        return sql_param
    from_start = from_match.start()

    parts: list[str] = []
    piece_start = select_start
    for pos in range(select_start, from_start):
        if sql_param[pos] == "," and mask[pos]:
            parts.append(sql_param[piece_start:pos].strip())
            piece_start = pos + 1
    tail = sql_param[piece_start:from_start].strip()
    if tail:
        parts.append(tail)

    cols = intent.select_cols or []
    if len(parts) != len(cols):
        return sql_param

    aliased: list[str] = []
    seen_aliases: set[str] = set()
    for idx, (expr_str, sc) in enumerate(zip(parts, cols, strict=False)):
        alias = _generate_col_alias(sc)
        if not alias:
            alias = f"col_{idx + 1}"
        base = alias
        counter = 2
        while alias in seen_aliases:
            alias = f"{base}_{counter}"
            counter += 1
        seen_aliases.add(alias)
        aliased.append(f"{expr_str} AS {alias}")

    return sql_param[:select_start] + ", ".join(aliased) + " " + sql_param[from_start:]
1347
+
1348
+
1349
+ def _render_date_window_where(
1350
+ fp: FilterParam, left_rendered: str, dialect_type: str
1351
+ ) -> list[str]:
1352
+ """Render WHERE clause part(s) for a date_window filter.
1353
+
1354
+ For raw_value with start/end keys emits two predicates (>= start AND <= end).
1355
+ For unit/offset uses render_date_window_expr. Returns a list of one or two
1356
+ fragments to AND together.
1357
+ """
1358
+ rv = fp.raw_value if isinstance(fp.raw_value, dict) else {}
1359
+ if "start" in rv and "end" in rv:
1360
+ start_val = rv["start"]
1361
+ end_val = rv["end"]
1362
+ if isinstance(start_val, str) and isinstance(end_val, str):
1363
+ return [
1364
+ f"{left_rendered} >= '{start_val}'",
1365
+ f"{left_rendered} <= '{end_val}'",
1366
+ ]
1367
+ unit = rv.get("unit", "day")
1368
+ offset = int(rv.get("offset", 0)) if rv.get("offset") is not None else 0
1369
+ op = fp.op or ">="
1370
+ return [render_date_window_expr(dialect_type, left_rendered, op, unit, offset)]
1371
+
1372
+
1373
def build_join_choice_prompt(
    q_norm: str,
    deterministic_sql: str,
    join_candidates: dict[str, Any],
    cte_join_hints: dict[str, dict[str, Any]] | None = None,
) -> tuple[str, str]:
    """Build the minimal prompt asking the LLM only for join candidate IDs.

    Candidate entries are reduced to ``candidate_id`` and
    ``join_path_signature``. Per-CTE candidates are included under
    ``cte_join_candidates`` when ``cte_join_hints`` is non-empty (``None``
    otherwise).

    Args:
        q_norm: Normalised user question.
        deterministic_sql: Template SQL from ``build_deterministic_sql``.
        join_candidates: Output of ``join_hints_multi`` for the main query.
        cte_join_hints: Optional mapping of CTE name to its join hints.

    Returns:
        ``(system_prompt, user_prompt)``. The expected response is JSON with
        ``chosen_join_candidate_id`` and optionally ``chosen_cte_join_candidate_ids``.
    """

    def _slim(entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
        return [
            {"candidate_id": entry.get("candidate_id"), "join_path_signature": entry.get("join_path_signature")}
            for entry in entries
        ]

    system = (
        "You are a join selector for text-to-SQL. Output ONLY valid JSON. "
        "Return chosen_join_candidate_id and, if the query has CTEs that need joins, "
        "chosen_cte_join_candidate_ids mapping each CTE name to its candidate_id."
    )

    cte_payload = (
        {name: _slim(hints.get("candidates", []) or []) for name, hints in cte_join_hints.items()}
        if cte_join_hints
        else None
    )

    payload = {
        "task": (
            "Given the question and the deterministic SQL template, choose the join candidate "
            "that correctly connects the tables. Return only the IDs; do not modify the SQL."
        ),
        "question": q_norm,
        "deterministic_sql": deterministic_sql,
        "join_candidates": _slim(join_candidates.get("candidates", [])),
        "cte_join_candidates": cte_payload,
        "output_format": {
            "chosen_join_candidate_id": "J00 or J01, J02, ...",
            "chosen_cte_join_candidate_ids": "Optional dict: cte_name -> candidate_id",
        },
    }
    return system, stable_json(payload)
1420
+
1421
+
1422
def get_join_choice_from_llm(
    q_norm: str,
    deterministic_sql: str,
    join_candidates: dict[str, Any],
    cte_join_hints: dict[str, dict[str, Any]] | None,
) -> tuple[str, dict[str, str]]:
    """Ask the LLM for the chosen join candidate ID and per-CTE candidate IDs.

    Args:
        q_norm: Normalised user question.
        deterministic_sql: Template SQL from ``build_deterministic_sql``.
        join_candidates: Output of ``join_hints_multi`` for the main query.
        cte_join_hints: Optional mapping of CTE name to its join hints.

    Returns:
        ``(chosen_join_candidate_id, chosen_cte_join_candidate_ids)``. The ID
        falls back to ``"J00"`` when the response is not a dict or the ID is not
        a string. The CTE mapping keeps only string-to-string entries and is
        empty when missing or malformed.
    """
    system, user = build_join_choice_prompt(q_norm, deterministic_sql, join_candidates, cte_join_hints)
    parsed = llm_json(system, user, retries=1, task="sql")
    if not isinstance(parsed, dict):
        return "J00", {}

    raw_choice = parsed.get("chosen_join_candidate_id")
    chosen = raw_choice if isinstance(raw_choice, str) else "J00"

    raw_cte_ids = parsed.get("chosen_cte_join_candidate_ids")
    cte_ids = raw_cte_ids if isinstance(raw_cte_ids, dict) else {}
    clean_cte_ids = {
        name: cand_id
        for name, cand_id in cte_ids.items()
        if isinstance(name, str) and isinstance(cand_id, str)
    }
    return chosen, clean_cte_ids
1446
+
1447
+
1448
+ def build_repair_prompt(
1449
+ schema: SchemaGraph,
1450
+ q_norm: str,
1451
+ prev_sql: str,
1452
+ db_error: str,
1453
+ nl_error: str,
1454
+ join_hints: dict[str, Any],
1455
+ cte_join_hints: dict[str, dict[str, Any]] | None = None,
1456
+ ) -> tuple[str, str]:
1457
+ """Build system and user prompts for SQL repair.
1458
+
1459
+ Args:
1460
+
1461
+ schema: The schema graph.
1462
+
1463
+ q_norm: The normalised user question string.
1464
+
1465
+ prev_sql: The previously generated SQL that failed.
1466
+
1467
+ db_error: The raw database error message.
1468
+
1469
+ nl_error: The human-readable explanation of the error.
1470
+
1471
+ join_hints: Join hint candidates from ``join_hints_multi``.
1472
+
1473
+ cte_join_hints: Optional dict mapping CTE name to join hints for CTE steps.
1474
+
1475
+ Returns:
1476
+
1477
+ Tuple of ``(system_prompt, user_prompt)`` strings ready for the LLM.
1478
+ """
1479
+ dialect_type = EngineConfig.TYPE
1480
+ if dialect_type == "databricks":
1481
+ dialect_name = "Spark"
1482
+ elif dialect_type == "postgresql":
1483
+ dialect_name = "PostgreSQL"
1484
+ else:
1485
+ dialect_name = "SQL"
1486
+
1487
+ hard_constraints = [
1488
+ f"Dialect: {dialect_name}. Use {dialect_name} syntax ONLY.",
1489
+ "Output ONLY valid JSON. No markdown. No explanations.",
1490
+ "SYNTAX-ONLY REPAIR: fix syntax errors, column qualification, keyword casing, operator typos, parameter placeholder format.",
1491
+ "Do NOT add, remove, or reorder tables in FROM/JOIN clauses.",
1492
+ "Do NOT change SELECT columns, aggregation functions, or expressions.",
1493
+ "Do NOT alter GROUP BY, ORDER BY, or HAVING structure.",
1494
+ "Do NOT change join conditions, join types, or join order.",
1495
+ "Do NOT add or remove WHERE/HAVING predicates.",
1496
+ "ALWAYS qualify column names with table name (for example, <table_1>.<column_1>).",
1497
+ "Do not alias SELECT expressions. No AS clauses on SELECT columns or aggregates.",
1498
+ "HAVING clause: Use full aggregation expression, NOT an alias.",
1499
+ "Use :p1, :p2 parameter placeholders for filter/having values, NOT literal values.",
1500
+ "WHERE/HAVING predicates: column reference on LEFT side of comparison.",
1501
+ "ORDER BY must include explicit ASC or DESC direction.",
1502
+ "chosen_join_candidate_id must match the SAME candidate as the original SQL.",
1503
+ ]
1504
+
1505
+ if cte_join_hints:
1506
+ cte_names = list(cte_join_hints.keys())
1507
+ hard_constraints.append(
1508
+ f"CTE join predicates MUST exactly match chosen CTE join candidate for: {', '.join(cte_names)}"
1509
+ )
1510
+ hard_constraints.append(
1511
+ "chosen_cte_join_candidate_ids must specify join candidate for each CTE with multi-table joins."
1512
+ )
1513
+
1514
+ output_format: dict[str, Any] = {"sql": "...", "chosen_join_candidate_id": "J01"}
1515
+ if cte_join_hints:
1516
+ output_format["chosen_cte_join_candidate_ids"] = {"cte_name": "J01"}
1517
+
1518
+ prompt_data: dict[str, Any] = {
1519
+ "task": "Fix ONLY syntax errors in the SQL query. Do not change query structure, tables, joins, or logic.",
1520
+ "error_info": {
1521
+ "db_error": db_error,
1522
+ "explanation": nl_error,
1523
+ },
1524
+ "hard_constraints": hard_constraints,
1525
+ "output_schema": output_format,
1526
+ "schema": schema.schema_literal_text,
1527
+ "join_candidates": join_hints,
1528
+ "question": q_norm,
1529
+ "previous_sql": prev_sql,
1530
+ "db_error": db_error,
1531
+ "error_explanation": nl_error,
1532
+ }
1533
+
1534
+ if cte_join_hints:
1535
+ prompt_data["cte_join_candidates"] = cte_join_hints
1536
+
1537
+ return (SQL_REPAIR_SYSTEM_PROMPT, stable_json(prompt_data))