aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/sql_gen.py
ADDED
|
@@ -0,0 +1,1537 @@
|
|
|
1
|
+
"""SQL generation, join path resolution, and canonical normalization.
|
|
2
|
+
|
|
3
|
+
Responsible for building the SQL prompt and repair prompt sent to the LLM, enumerating and ranking all valid FK-based join paths between a set of tables, normalising generated SQL to a canonical JOIN order and predicate form, and validating that the LLM-chosen join candidate matches the actual SQL structure. Also provides utilities for rendering intent expressions as SQL fragments used in the generation prompt and for post-hoc alias injection.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from .config import SCALAR_FUNCTIONS_LEADING_ARG, EngineConfig, PolicyConfig
|
|
12
|
+
from .contracts_base import SchemaGraph
|
|
13
|
+
from .contracts_core import FilterParam, MulGroup, NormalizedExpr, RuntimeCteStep, RuntimeIntent, SelectCol
|
|
14
|
+
from .core_utils import debug, llm_json, llm_sql_with_join, stable_json
|
|
15
|
+
from .dialect import get_dialect, render_date_diff_expr, render_date_window_expr
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def cte_to_intent_for_ranking(cte: RuntimeCteStep) -> RuntimeIntent:
    """Promote a ``RuntimeCteStep`` into a standalone ``RuntimeIntent`` for join ranking.

    The CTE's tables, grain, select/group/order columns, filter and having
    parameters, parameter values, column map and limit are copied verbatim; the
    resulting intent carries no nested CTE steps.

    Args:
        cte: The CTE step whose fields are promoted.

    Returns:
        A ``RuntimeIntent`` mirroring ``cte`` with ``cte_steps=[]``, suitable for
        scoring candidate join paths.
    """
    promoted_fields = (
        "tables",
        "grain",
        "select_cols",
        "group_by_cols",
        "order_by_cols",
        "filters_param",
        "having_param",
        "param_values",
        "column_map",
        "limit",
    )
    copied = {name: getattr(cte, name) for name in promoted_fields}
    return RuntimeIntent(**copied, cte_steps=[])
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
SQL_REPAIR_SYSTEM_PROMPT = """You are a deterministic SQL repair assistant.
|
|
45
|
+
|
|
46
|
+
Output requirements:
|
|
47
|
+
- Output ONLY valid JSON that matches the specified output_schema.
|
|
48
|
+
- Do NOT include markdown, explanations, or commentary.
|
|
49
|
+
- Identical inputs must produce identical outputs.
|
|
50
|
+
|
|
51
|
+
Repair guidelines:
|
|
52
|
+
- Make minimal changes to fix the specific error.
|
|
53
|
+
- Do not change query logic unless required to fix the error.
|
|
54
|
+
- Preserve JOIN structure unless the error requires changing it.
|
|
55
|
+
- Maintain existing filters, aggregations, and ordering when possible.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def join_candidate_map(join_hints: dict[str, Any]) -> dict[str, list[str]]:
    """Index the join candidates in ``join_hints`` by their candidate ID.

    Candidates whose ``candidate_id`` is not a string, or whose
    ``join_path_signature`` is not a list, are ignored. Signature entries are
    coerced with ``str``. A repeated ID keeps the last candidate seen.

    Args:
        join_hints: The join hints dict produced by ``join_hints_multi``.

    Returns:
        Mapping from ``candidate_id`` to its list of join path signature strings.
    """
    mapping: dict[str, list[str]] = {}
    for candidate in join_hints.get("candidates", []):
        candidate_id = candidate.get("candidate_id")
        signature = candidate.get("join_path_signature")
        if not isinstance(candidate_id, str) or not isinstance(signature, list):
            continue
        mapping[candidate_id] = list(map(str, signature))
    return mapping
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _analyze_join_topology(sig: list[str]) -> tuple[str, str, list[str]]:
|
|
80
|
+
"""Analyze join signature to determine topology type, hub table, and leaf tables.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
|
|
84
|
+
sig: List of join path signature strings.
|
|
85
|
+
|
|
86
|
+
Returns:
|
|
87
|
+
|
|
88
|
+
Tuple of ``(topology_type, anchor_table, leaf_tables)`` where ``topology_type`` is one of ``"none"``, ``"linear"``, ``"star"``, or ``"tree"``; ``anchor_table`` is the canonical root; and ``leaf_tables`` is the list of endpoint tables.
|
|
89
|
+
"""
|
|
90
|
+
if not sig:
|
|
91
|
+
return ("none", "", [])
|
|
92
|
+
table_counts: dict[str, int] = {}
|
|
93
|
+
for item in sig:
|
|
94
|
+
if "->" not in item:
|
|
95
|
+
continue
|
|
96
|
+
left, right = item.split("->", 1)
|
|
97
|
+
left_table = left.split(".")[0].strip()
|
|
98
|
+
right_table = right.split(".")[0].strip()
|
|
99
|
+
table_counts[left_table] = table_counts.get(left_table, 0) + 1
|
|
100
|
+
table_counts[right_table] = table_counts.get(right_table, 0) + 1
|
|
101
|
+
if not table_counts:
|
|
102
|
+
return ("none", "", [])
|
|
103
|
+
leaves = sorted([t for t, c in table_counts.items() if c == 1])
|
|
104
|
+
hubs = sorted(
|
|
105
|
+
[t for t, c in table_counts.items() if c > 1],
|
|
106
|
+
key=lambda t: (-table_counts[t], t),
|
|
107
|
+
)
|
|
108
|
+
if len(leaves) == 2 and len(hubs) == len(table_counts) - 2:
|
|
109
|
+
return ("linear", min(leaves), leaves)
|
|
110
|
+
if len(hubs) == 1:
|
|
111
|
+
return ("star", hubs[0], leaves)
|
|
112
|
+
if hubs:
|
|
113
|
+
return ("tree", hubs[0], leaves)
|
|
114
|
+
return ("linear", min(table_counts.keys()), list(table_counts.keys()))
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Lookahead that ends one ``ON`` condition. Every keyword carries ``\b`` so that
# identifiers such as ``orders``, ``group_id`` or ``joined_at`` are not mistaken
# for ``ORDER BY`` / ``GROUP BY`` / ``JOIN``. Outer/cross join keywords also end
# the condition, so they are never swallowed into the ON text.
_ON_CLAUSE_END = (
    r"(?=\s+(?:INNER\s+)?JOIN\b"
    r"|\s+(?:LEFT|RIGHT|FULL|CROSS|NATURAL|OUTER)\b"
    r"|\s+(?:WHERE|GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT|QUALIFY|UNION|INTERSECT|EXCEPT)\b"
    r"|\s*;|\s*$)"
)

# ``FROM <table>`` followed by one or more unaliased ``[INNER] JOIN <table> ON ...``
# clauses. Each repetition consumes its own leading whitespace; otherwise the
# repetition could never reach a second JOIN and only the first join would be
# captured. ``[^)]`` keeps parenthesised ON conditions out (parse fails instead).
_FROM_INNER_JOINS_RE = re.compile(
    r"\bFROM\s+(\w+)((?:\s+(?:INNER\s+)?JOIN\s+\w+\s+ON\s+[^)]+?" + _ON_CLAUSE_END + r")+)",
    re.IGNORECASE | re.DOTALL,
)

# One ``[INNER] JOIN <table> ON <condition>`` inside the block captured above.
_INNER_JOIN_CLAUSE_RE = re.compile(
    r"((?:INNER\s+)?JOIN\s+(\w+)\s+ON\s+([^)]+?)(?=\s+(?:INNER\s+)?JOIN\b|\s*$))",
    re.IGNORECASE | re.DOTALL,
)

# A further JOIN right after the parsed run that the parser could not consume
# (aliased, outer, cross, natural ...). Reordering only part of the chain would
# be unsafe, so its presence aborts the rewrite.
_UNPARSED_JOIN_RE = re.compile(
    r"\s*(?:(?:INNER|LEFT|RIGHT|FULL|CROSS|NATURAL|OUTER)\s+)*JOIN\b",
    re.IGNORECASE,
)

# ``table.column`` references inside an ON condition (captures the table name).
_ON_TABLE_REF_RE = re.compile(r"\b([A-Za-z_]\w*)\.\w")


def _parse_inner_join_chain(sql: str, start: int) -> tuple[tuple[int, int], str, list[tuple[str, str, str]]] | None:
    """Parse the ``FROM t JOIN ... ON ...`` run beginning at ``sql[start]``.

    Returns ``((span_start, span_end), first_table, join_clauses)`` where each join
    clause is ``(full_text, table, on_condition)``, or ``None`` when the run cannot
    be parsed completely.
    """
    match = _FROM_INNER_JOINS_RE.match(sql, start)
    if not match:
        return None
    if _UNPARSED_JOIN_RE.match(sql, match.end()):
        return None
    join_clauses = _INNER_JOIN_CLAUSE_RE.findall(match.group(2))
    if not join_clauses:
        return None
    return (match.start(), match.end()), match.group(1), join_clauses


def _order_branch_joins(
    anchor: str, join_clauses: list[tuple[str, str, str]]
) -> list[tuple[str, str, str]] | None:
    """Order branch joins alphabetically while respecting ON-clause dependencies.

    A join is emitted only once every other table its ON condition references
    is already in scope; among the eligible joins the alphabetically first table
    wins. For a star this equals plain alphabetical order. Returns ``None`` if
    some join references a table that never comes into scope.
    """
    placed = {anchor.lower()}
    pending = sorted(join_clauses, key=lambda jc: jc[1].lower())
    ordered: list[tuple[str, str, str]] = []
    while pending:
        for idx, clause in enumerate(pending):
            refs = {name.lower() for name in _ON_TABLE_REF_RE.findall(clause[2])}
            refs.discard(clause[1].lower())
            if refs <= placed:
                break
        else:
            return None
        ordered.append(pending.pop(idx))
        placed.add(clause[1].lower())
    return ordered


def reorder_sql_joins_canonical(sql: str, join_sig: list[str]) -> str:
    """Reorder the leading FROM/JOIN chain of ``sql`` into canonical form.

    Only the run of unaliased ``[INNER] JOIN <table> ON <condition>`` clauses
    directly after the first ``FROM <table>`` is rewritten; everything else in the
    SQL is kept verbatim.

    * Linear topology: when the SQL starts from the non-anchor endpoint of the
      chain, the chain is reversed so it starts from the anchor (the
      alphabetically smaller endpoint). ON conditions move with their edges.
    * Star/tree topology: when the SQL starts from the anchor hub, branch joins
      are emitted alphabetically by joined table, except that a join is never
      placed before the tables its ON condition references.

    Args:
        sql: The SQL string whose FROM/JOIN clause should be reordered.
        join_sig: The join path signature that defines the canonical order.

    Returns:
        The reordered SQL. The original SQL is returned unchanged when the
        signature is empty, the topology is ``"none"``, there is no ``FROM``, the
        order is already canonical or not recognised, the JOIN run cannot be parsed
        completely (aliases, outer joins, parenthesised ON conditions), or branch
        dependencies cannot be satisfied.
    """
    if not join_sig:
        return sql
    topology_type, anchor, leaves = _analyze_join_topology(join_sig)
    if topology_type == "none":
        return sql
    from_match = re.search(r"\bFROM\s+(\w+)", sql, re.IGNORECASE)
    if not from_match:
        return sql
    current_first_table = from_match.group(1).lower()
    anchor_lower = anchor.lower()

    if topology_type == "linear":
        if current_first_table == anchor_lower:
            return sql
        other_endpoint = [leaf for leaf in leaves if leaf.lower() != anchor_lower]
        if not other_endpoint or current_first_table != other_endpoint[0].lower():
            return sql
        # Parse the same FROM that was inspected above, not some later one.
        parsed = _parse_inner_join_chain(sql, from_match.start())
        if parsed is None:
            return sql
        (span_start, span_end), first_table, join_clauses = parsed
        # A chain joined from one endpoint links each JOIN to the previous table,
        # so reversing the tables and their ON conditions yields a valid chain.
        tables = [first_table, *(jc[1] for jc in join_clauses)][::-1]
        on_clauses = [jc[2].strip() for jc in reversed(join_clauses)]
        new_from = f"FROM {tables[0]}" + "".join(
            f" JOIN {tbl} ON {on}" for tbl, on in zip(tables[1:], on_clauses)
        )
        debug(f"[sql_gen.reorder_sql_joins_canonical] linear reordered: {current_first_table} -> {tables[0]}")
        return sql[:span_start] + new_from + sql[span_end:]

    if current_first_table != anchor_lower:
        debug(
            f"[sql_gen.reorder_sql_joins_canonical] star/tree topology but FROM table {current_first_table} != anchor {anchor}"
        )
        return sql
    parsed = _parse_inner_join_chain(sql, from_match.start())
    if parsed is None:
        return sql
    (span_start, span_end), _first_table, join_clauses = parsed
    ordered = _order_branch_joins(anchor_lower, join_clauses)
    if ordered is None:
        return sql
    new_from = f"FROM {anchor}" + "".join(f" JOIN {jc[1]} ON {jc[2].strip()}" for jc in ordered)
    debug("[sql_gen.reorder_sql_joins_canonical] star/tree branches sorted alphabetically")
    return sql[:span_start] + new_from + sql[span_end:]
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# Marker line in deterministic SQL that ``inject_join_into_deterministic_sql``
# replaces with a JOIN clause built from a join signature. It is written as a
# SQL line comment (``--``), so SQL that still contains it stays parseable;
# any leftover markers are stripped by that function's cleanup regex.
JOIN_PLACEHOLDER = "-- <JOIN: integrate from join candidates>"
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _wrap_for_case_insensitive(expr: str, dialect_type: str) -> str:
|
|
212
|
+
"""Wrap expression for case-insensitive string comparison.
|
|
213
|
+
|
|
214
|
+
On Databricks, uses LOWER(TRIM(expr)) to handle whitespace and
|
|
215
|
+
collation. On other dialects, uses LOWER(expr).
|
|
216
|
+
"""
|
|
217
|
+
if dialect_type == "databricks":
|
|
218
|
+
return f"LOWER(TRIM({expr}))"
|
|
219
|
+
return f"LOWER({expr})"
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _join_clause_from_signature(signature: list[str], from_table: str = "") -> str:
|
|
223
|
+
"""Build JOIN clause text from a join path signature.
|
|
224
|
+
|
|
225
|
+
Each segment is "src_tbl.col->dst_tbl.col". Tracks tables already in
|
|
226
|
+
the chain to avoid duplicate JOINs (e.g. when two edges target the
|
|
227
|
+
same table). When the target is already present, adds the source
|
|
228
|
+
table instead.
|
|
229
|
+
"""
|
|
230
|
+
if not signature:
|
|
231
|
+
return ""
|
|
232
|
+
chain: set[str] = {from_table.lower()} if from_table else set()
|
|
233
|
+
parts: list[str] = []
|
|
234
|
+
for seg in signature:
|
|
235
|
+
seg = seg.strip()
|
|
236
|
+
if "->" not in seg:
|
|
237
|
+
continue
|
|
238
|
+
left_part, right_part = seg.split("->", 1)
|
|
239
|
+
left_part = left_part.strip()
|
|
240
|
+
right_part = right_part.strip()
|
|
241
|
+
if "." not in left_part or "." not in right_part:
|
|
242
|
+
continue
|
|
243
|
+
left_tbl, left_cols = left_part.split(".", 1)
|
|
244
|
+
right_tbl, right_cols = right_part.split(".", 1)
|
|
245
|
+
left_col_list = [c.strip() for c in left_cols.split(",")]
|
|
246
|
+
right_col_list = [c.strip() for c in right_cols.split(",")]
|
|
247
|
+
on_terms = [
|
|
248
|
+
f"{left_tbl}.{lc} = {right_tbl}.{rc}"
|
|
249
|
+
for lc, rc in zip(left_col_list, right_col_list, strict=False)
|
|
250
|
+
]
|
|
251
|
+
if not on_terms:
|
|
252
|
+
continue
|
|
253
|
+
right_tbl_lower = right_tbl.lower()
|
|
254
|
+
left_tbl_lower = left_tbl.lower()
|
|
255
|
+
if right_tbl_lower in chain:
|
|
256
|
+
join_tbl = left_tbl
|
|
257
|
+
chain.add(left_tbl_lower)
|
|
258
|
+
else:
|
|
259
|
+
join_tbl = right_tbl
|
|
260
|
+
chain.add(right_tbl_lower)
|
|
261
|
+
parts.append(f" JOIN {join_tbl} ON " + " AND ".join(on_terms))
|
|
262
|
+
return "".join(parts)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _orient_join_sig_for_from(
|
|
266
|
+
sig: list[str],
|
|
267
|
+
from_table: str,
|
|
268
|
+
) -> list[str]:
|
|
269
|
+
"""Reorient join segments so that no target duplicates the FROM table.
|
|
270
|
+
|
|
271
|
+
When the right-hand (target) table of a segment equals the current
|
|
272
|
+
FROM table, the segment is flipped so the other table becomes the
|
|
273
|
+
JOIN target instead. This prevents ``FROM t JOIN t`` self-join
|
|
274
|
+
artefacts that occur when ``tables[0]`` in the intent happens to
|
|
275
|
+
sit on the target side of the join signature.
|
|
276
|
+
"""
|
|
277
|
+
if not from_table:
|
|
278
|
+
return sig
|
|
279
|
+
oriented: list[str] = []
|
|
280
|
+
for seg in sig:
|
|
281
|
+
if "->" not in seg:
|
|
282
|
+
oriented.append(seg)
|
|
283
|
+
continue
|
|
284
|
+
left, right = seg.split("->", 1)
|
|
285
|
+
right_tbl = right.split(".")[0].strip().lower()
|
|
286
|
+
if right_tbl == from_table:
|
|
287
|
+
oriented.append(f"{right.strip()}->{left.strip()}")
|
|
288
|
+
else:
|
|
289
|
+
oriented.append(seg)
|
|
290
|
+
return oriented
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def inject_join_into_deterministic_sql(
    det_sql: str,
    join_sigs_ordered: list[list[str]],
) -> str:
    """Replace each JOIN placeholder in deterministic SQL with a JOIN clause.

    Placeholders are filled in order: the first occurrence uses
    ``join_sigs_ordered[0]``, the next ``join_sigs_ordered[1]``, and so on. Each
    signature is first oriented against the FROM table that owns its placeholder
    (the last ``FROM <table>`` appearing before it; if there is none, the first
    ``FROM`` in the SQL). This ensures the joined table never duplicates that FROM
    table, even when the SQL holds several FROM clauses such as CTE bodies.
    Placeholders left over because there are fewer signatures than placeholders
    are stripped.

    Args:
        det_sql: SQL containing ``JOIN_PLACEHOLDER`` lines.
        join_sigs_ordered: One join signature per placeholder, in order.

    Returns:
        The SQL with placeholders replaced; ``det_sql`` unchanged when
        ``join_sigs_ordered`` is empty.
    """
    if not join_sigs_ordered:
        return det_sql
    from_re = re.compile(r"\bFROM\s+(\w+)", re.IGNORECASE)
    result = det_sql
    for sig in join_sigs_ordered:
        pos = result.find(JOIN_PLACEHOLDER)
        if pos < 0:
            break
        # The placeholder belongs to the nearest preceding FROM; using the first
        # FROM of the whole statement mis-orients later CTE placeholders.
        from_tbl = ""
        for from_match in from_re.finditer(result, 0, pos):
            from_tbl = from_match.group(1)
        if not from_tbl:
            fallback = from_re.search(result)
            from_tbl = fallback.group(1) if fallback else ""
        oriented = _orient_join_sig_for_from(sig, from_tbl.lower())
        join_clause = _join_clause_from_signature(oriented, from_tbl)
        result = result[:pos] + join_clause.strip() + result[pos + len(JOIN_PLACEHOLDER) :]
    result = re.sub(r"\n\s*-- <JOIN[^>]*>\s*", "\n", result)
    return result
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def normalize_where_having_predicates(sql: str) -> str:
    """Normalize WHERE and HAVING predicates to put column references on the left.

    Rewrites predicates of the form ``:param op table.column`` (or a string/numeric
    literal instead of ``:param``) to ``table.column op' :param``. ``op'`` is the
    mirrored operator (``<`` ↔ ``>``, ``<=`` ↔ ``>=``; ``=``, ``!=`` and ``<>`` are
    unchanged), so the predicate keeps its meaning. A predicate is left alone when
    an operand is only part of a larger expression, i.e. it is directly adjacent
    to an arithmetic/concatenation operator or the column is followed by ``(``.

    Args:
        sql: The SQL string to normalise.

    Returns:
        SQL string with swapped predicates in the first WHERE clause and the first
        HAVING clause.
    """
    # Swapping operands of an ordering comparison must mirror the operator:
    # ``:p < t.x`` is ``t.x > :p``, not ``t.x < :p``.
    mirrored_ops = {"<": ">", ">": "<", "<=": ">=", ">=": "<="}
    # Characters that bind tighter than a comparison; if one touches the match,
    # the literal/column is only a piece of an expression and must not move.
    binding_chars = set("+-*/%|")
    quote_or_param = (":", "'", '"')

    def swap_predicate(match: re.Match[str]) -> str:
        full = match.group(0)
        left = match.group(1).strip()
        op = match.group(2).strip()
        right = match.group(3).strip()

        left_is_param = left.startswith(quote_or_param) or left[0].isdigit()
        right_is_col = "." in right and not right.startswith(quote_or_param)
        if not (left_is_param and right_is_col):
            return full

        text = match.string
        before = text[: match.start()].rstrip()[-1:]
        after = text[match.end() :].lstrip()[:1]
        if before in binding_chars or after in binding_chars or after == "(":
            return full
        return f"{right} {mirrored_ops.get(op, op)} {left}"

    pattern = r"([:\w.]+|'[^']*'|\"[^\"]*\")\s*(=|!=|<>|<=|>=|<|>)\s+([:\w.]+|'[^']*'|\"[^\"]*\")"

    where_match = re.search(r"\bWHERE\b", sql, re.IGNORECASE)
    if where_match:
        where_start = where_match.end()
        next_clause = re.search(
            r"\b(GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT)\b",
            sql[where_start:],
            re.IGNORECASE,
        )
        where_end = where_start + next_clause.start() if next_clause else len(sql)
        where_clause = sql[where_start:where_end]
        normalized_where = re.sub(pattern, swap_predicate, where_clause)
        sql = sql[:where_start] + normalized_where + sql[where_end:]

    having_match = re.search(r"\bHAVING\b", sql, re.IGNORECASE)
    if having_match:
        having_start = having_match.end()
        next_clause = re.search(r"\b(ORDER\s+BY|LIMIT)\b", sql[having_start:], re.IGNORECASE)
        having_end = having_start + next_clause.start() if next_clause else len(sql)
        having_clause = sql[having_start:having_end]
        normalized_having = re.sub(pattern, swap_predicate, having_clause)
        sql = sql[:having_start] + normalized_having + sql[having_end:]

    return sql
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def normalize_cte_sql(sql: str, cte_join_sigs: dict[str, list[str]]) -> str:
    """Canonicalise every CTE body in ``sql``.

    Each body returned by the active dialect's ``extract_cte_bodies`` has its
    FROM/JOIN chain reordered (when ``cte_join_sigs`` has a non-empty signature
    for that CTE) and its WHERE/HAVING predicates normalised. Changed bodies are
    substituted back by textual replacement of ``"<name> AS (<body>)"``.

    Args:
        sql: The full SQL string that may contain WITH/CTE clauses.
        cte_join_sigs: Mapping from CTE name to join path signature.

    Returns:
        The SQL with normalised CTE bodies, or ``sql`` unchanged when no CTE
        bodies are found.
    """
    bodies = get_dialect().extract_cte_bodies(sql)
    if not bodies:
        return sql

    for name, body in bodies.items():
        signature = cte_join_sigs.get(name, [])
        rewritten = reorder_sql_joins_canonical(body, signature) if signature else body
        rewritten = normalize_where_having_predicates(rewritten)
        if rewritten == body:
            continue
        sql = sql.replace(f"{name} AS ({body})", f"{name} AS ({rewritten})")

    debug(f"[sql_gen.normalize_cte_sql] normalized {len(bodies)} CTEs")
    return sql
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def _join_path_signature_for_path(path: list[dict[str, Any]]) -> list[str]:
|
|
407
|
+
"""Generate signature strings for a join path.
|
|
408
|
+
|
|
409
|
+
Args:
|
|
410
|
+
|
|
411
|
+
path: List of edge dicts, each with ``src_table``, ``src_cols``, ``dst_table``, and ``dst_cols`` keys.
|
|
412
|
+
|
|
413
|
+
Returns:
|
|
414
|
+
|
|
415
|
+
List of strings in the form ``"src_table.col1,col2->dst_table.col3,col4"``.
|
|
416
|
+
"""
|
|
417
|
+
sig = []
|
|
418
|
+
for e in path:
|
|
419
|
+
sig.append(f"{e['src_table']}.{','.join(e['src_cols'])}->{e['dst_table']}.{','.join(e['dst_cols'])}")
|
|
420
|
+
return sig
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def _candidate_join_paths_for_tables(schema: SchemaGraph, tables: list[str]) -> list[list[dict[str, Any]]]:
    """Compute all candidate join paths for a set of tables by trying every table as root.

    First attempts direct paths (no bridge tables). Falls back to bridge-table paths if no direct paths are found.

    For each root, at most one candidate is produced: the union of the shortest
    stored path from the root to every other table (see ``_merge_paths_minimal``).
    Candidates are de-duplicated by their join path signature.

    Args:
        schema: The schema graph containing pre-computed join paths.
        tables: List of table names that must all be reachable in each candidate.

    Returns:
        List of join paths, each a list of edge dicts with source and destination table and column keys, sorted by
        edge count and then by signature. Returns ``[[]]`` for single-table queries and ``[]`` when no candidate
        covers every table.
    """
    tables = sorted(set(tables))
    if len(tables) < 2:
        return [[]]

    def uniq_edges(edges: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # Drop repeated edges, keeping first occurrences. The endpoint pair is
        # sorted, so an edge and its reverse (same columns swapped) count as one.
        seen: set = set()
        out: list[dict[str, Any]] = []
        for e in edges:
            pair = (
                (e["src_table"], tuple(e["src_cols"])),
                (e["dst_table"], tuple(e["dst_cols"])),
            )
            canonical = tuple(sorted(pair))
            if canonical in seen:
                continue
            seen.add(canonical)
            out.append(e)
        return out

    table_set = set(tables)

    def _edges_cover_tables(edges: list[dict[str, Any]], root: str) -> set[str]:
        # Set of tables touched by ``edges``, always including ``root``.
        covered = {root}
        for e in edges:
            covered.add(e["src_table"])
            covered.add(e["dst_table"])
        return covered

    def _merge_paths_minimal(
        root: str, others: list[str], allow_bridges: bool
    ) -> list[list[dict[str, Any]]]:
        # Greedily union, for each target not yet reachable from the edges merged
        # so far, the shortest stored root->target path (first one wins ties).
        # Without ``allow_bridges`` a path may only touch the requested tables.
        # Returns ``[merged_edges]`` or ``[]`` if nothing could be merged.
        covered: set[str] = {root}
        merged: list[dict[str, Any]] = []
        for target in others:
            if target in covered:
                continue
            paths = schema.join_paths_multi.get(root, {}).get(target, [])
            if not paths:
                continue
            best: list[dict[str, Any]] | None = None
            for p in paths:
                if not p:
                    continue
                path_tables = _edges_cover_tables(p, root)
                if target not in path_tables:
                    continue
                if not allow_bridges and not path_tables <= table_set:
                    continue
                if best is None or len(p) < len(best):
                    best = p
            if best:
                for e in best:
                    if e not in merged:
                        merged.append(e)
                covered = _edges_cover_tables(merged, root)
        return [merged] if merged else []

    def _collect(allow_bridges: bool) -> dict[tuple, list[dict[str, Any]]]:
        # Build one candidate per root, keep only those covering every requested
        # table (and, without bridges, no other table), keyed by signature.
        candidates: dict[tuple, list[dict[str, Any]]] = {}
        for root in tables:
            others = [t for t in tables if t != root]
            for merged in _merge_paths_minimal(root, others, allow_bridges):
                deduped = uniq_edges(merged)
                edge_tables = (
                    {root}
                    | {e["src_table"] for e in deduped}
                    | {e["dst_table"] for e in deduped}
                )
                if not table_set <= edge_tables:
                    continue
                if not allow_bridges and not edge_tables <= table_set:
                    continue
                sig = tuple(_join_path_signature_for_path(deduped))
                if sig not in candidates:
                    candidates[sig] = deduped
        return candidates

    all_candidates = _collect(allow_bridges=False)
    if not all_candidates:
        # Only when no direct (bridge-free) candidate exists are paths through
        # extra tables considered.
        all_candidates = _collect(allow_bridges=True)
        if all_candidates:
            debug(
                f"[sql_gen.candidate_join_paths_for_tables] no direct paths, found {len(all_candidates)} bridge paths"
            )

    res = list(all_candidates.values())
    # Deterministic order: fewest edges first, then lexicographic signature.
    res.sort(key=lambda m: (len(m), tuple(_join_path_signature_for_path(m))))
    return res
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def _score_join_path(edges: list[dict[str, Any]], intent: RuntimeIntent, schema: SchemaGraph) -> float:
|
|
529
|
+
"""Score a join path based on FK direction, intent alignment, and path characteristics.
|
|
530
|
+
|
|
531
|
+
Higher scores indicate more semantically appropriate join paths. Points are awarded for forward FK direction, shorter paths, joins that connect to filtered or grouped columns, and FACT→DIMENSION relationships.
|
|
532
|
+
|
|
533
|
+
Args:
|
|
534
|
+
|
|
535
|
+
edges: List of join edge dicts for the candidate path.
|
|
536
|
+
|
|
537
|
+
intent: The ``RuntimeIntent`` providing filter, group-by, and aggregation context.
|
|
538
|
+
|
|
539
|
+
schema: The schema graph for table role and FK metadata.
|
|
540
|
+
|
|
541
|
+
Returns:
|
|
542
|
+
|
|
543
|
+
Float score; higher is better.
|
|
544
|
+
"""
|
|
545
|
+
score = 0.0
|
|
546
|
+
|
|
547
|
+
filter_columns = set()
|
|
548
|
+
for fp in intent.filters_param or []:
|
|
549
|
+
pcol = fp.left_expr.primary_column
|
|
550
|
+
col_parts = pcol.split(".")
|
|
551
|
+
if len(col_parts) == 2:
|
|
552
|
+
filter_columns.add((col_parts[0], col_parts[1]))
|
|
553
|
+
|
|
554
|
+
groupby_columns = set()
|
|
555
|
+
for gb in intent.group_by_cols or []:
|
|
556
|
+
col_parts = gb.primary_column.split(".")
|
|
557
|
+
if len(col_parts) == 2:
|
|
558
|
+
groupby_columns.add((col_parts[0], col_parts[1]))
|
|
559
|
+
|
|
560
|
+
agg_tables = set()
|
|
561
|
+
for sc in intent.select_cols or []:
|
|
562
|
+
if sc.is_aggregated:
|
|
563
|
+
pcol = sc.expr.primary_column
|
|
564
|
+
col_parts = pcol.split(".")
|
|
565
|
+
if len(col_parts) == 2:
|
|
566
|
+
agg_tables.add(col_parts[0])
|
|
567
|
+
|
|
568
|
+
path_tables = set()
|
|
569
|
+
for edge in edges:
|
|
570
|
+
path_tables.add(edge["src_table"])
|
|
571
|
+
path_tables.add(edge["dst_table"])
|
|
572
|
+
|
|
573
|
+
score += max(20 - (len(edges) * 3), 0)
|
|
574
|
+
|
|
575
|
+
for edge in edges:
|
|
576
|
+
src_table = edge["src_table"]
|
|
577
|
+
dst_table = edge["dst_table"]
|
|
578
|
+
src_cols = edge["src_cols"]
|
|
579
|
+
dst_cols = edge["dst_cols"]
|
|
580
|
+
|
|
581
|
+
src_meta = schema.tables.get(src_table)
|
|
582
|
+
dst_meta = schema.tables.get(dst_table)
|
|
583
|
+
|
|
584
|
+
if not src_meta or not dst_meta:
|
|
585
|
+
continue
|
|
586
|
+
|
|
587
|
+
is_forward = False
|
|
588
|
+
for fk in src_meta.foreign_keys:
|
|
589
|
+
if fk.dst_table == dst_table and set(fk.src_cols) == set(src_cols) and set(fk.dst_cols) == set(dst_cols):
|
|
590
|
+
is_forward = True
|
|
591
|
+
break
|
|
592
|
+
|
|
593
|
+
if is_forward:
|
|
594
|
+
score += 10
|
|
595
|
+
else:
|
|
596
|
+
score += 5
|
|
597
|
+
|
|
598
|
+
for dst_col in dst_cols:
|
|
599
|
+
if (dst_table, dst_col) in filter_columns:
|
|
600
|
+
score += 15
|
|
601
|
+
if (dst_table, dst_col) in groupby_columns:
|
|
602
|
+
score += 10
|
|
603
|
+
|
|
604
|
+
for src_col in src_cols:
|
|
605
|
+
if (src_table, src_col) in filter_columns:
|
|
606
|
+
if is_forward and any(
|
|
607
|
+
dst_col == src_col.replace("_id", "") or src_col.endswith(f"_{dst_table}_id")
|
|
608
|
+
for dst_col in dst_cols
|
|
609
|
+
):
|
|
610
|
+
score += 12
|
|
611
|
+
|
|
612
|
+
if src_table in agg_tables:
|
|
613
|
+
score += 8
|
|
614
|
+
|
|
615
|
+
src_role = src_meta.role or ""
|
|
616
|
+
dst_role = dst_meta.role or ""
|
|
617
|
+
|
|
618
|
+
if dst_role == "DIMENSION":
|
|
619
|
+
score += 3
|
|
620
|
+
if src_role == "FACT" and dst_role == "DIMENSION":
|
|
621
|
+
score += 5
|
|
622
|
+
if src_role == "FACT" and dst_role == "FACT":
|
|
623
|
+
score -= 10
|
|
624
|
+
if dst_role == "BRIDGE":
|
|
625
|
+
score -= 5
|
|
626
|
+
|
|
627
|
+
return score
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
def _format_join_candidate_semantic(
|
|
631
|
+
candidate_id: str, edges: list[dict[str, Any]], schema: SchemaGraph, score: float
|
|
632
|
+
) -> str:
|
|
633
|
+
"""Format join candidate with semantic labels and FK direction indicators.
|
|
634
|
+
|
|
635
|
+
Args:
|
|
636
|
+
|
|
637
|
+
candidate_id: The candidate identifier string (for example, ``"J01"``).
|
|
638
|
+
|
|
639
|
+
edges: List of join edge dicts for this candidate.
|
|
640
|
+
|
|
641
|
+
schema: The schema graph for FK direction lookup.
|
|
642
|
+
|
|
643
|
+
score: The numeric score assigned to this candidate.
|
|
644
|
+
|
|
645
|
+
Returns:
|
|
646
|
+
|
|
647
|
+
Multi-line string describing each join edge with FK direction label and overall score, suitable for inclusion in an LLM prompt.
|
|
648
|
+
"""
|
|
649
|
+
if not edges:
|
|
650
|
+
return f"{candidate_id}: Single table (no joins)"
|
|
651
|
+
|
|
652
|
+
lines = [f"{candidate_id} [Score: {score:.1f}]:"]
|
|
653
|
+
|
|
654
|
+
for edge in edges:
|
|
655
|
+
src_table = edge["src_table"]
|
|
656
|
+
dst_table = edge["dst_table"]
|
|
657
|
+
src_cols = edge["src_cols"]
|
|
658
|
+
dst_cols = edge["dst_cols"]
|
|
659
|
+
|
|
660
|
+
src_meta = schema.tables.get(src_table)
|
|
661
|
+
|
|
662
|
+
fk_direction = "Reverse FK"
|
|
663
|
+
if src_meta:
|
|
664
|
+
for fk in src_meta.foreign_keys:
|
|
665
|
+
if (
|
|
666
|
+
fk.dst_table == dst_table
|
|
667
|
+
and set(fk.src_cols) == set(src_cols)
|
|
668
|
+
and set(fk.dst_cols) == set(dst_cols)
|
|
669
|
+
):
|
|
670
|
+
fk_direction = "Forward FK"
|
|
671
|
+
break
|
|
672
|
+
|
|
673
|
+
src_col_str = ",".join(src_cols)
|
|
674
|
+
dst_col_str = ",".join(dst_cols)
|
|
675
|
+
lines.append(f" {src_table}.{src_col_str} -> {dst_table}.{dst_col_str} ({fk_direction})")
|
|
676
|
+
|
|
677
|
+
return "\n".join(lines)
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
def _llm_rank_join_candidates(
    candidates: list[dict[str, Any]], intent: RuntimeIntent, schema: SchemaGraph
) -> list[int]:
    """Use LLM to rank ambiguous join candidates when scores are tied.

    Only the first three candidates are described in the prompt. The LLM reply
    is sanitised before use:

    * A non-dict reply, or a ``ranked_ids`` value that is not a list, counts as
      a failure.
    * IDs that are not strings, match no candidate, or repeat an earlier ID are
      ignored.
    * IDs are compared after stripping whitespace and upper-casing.

    Args:
        candidates: List of scored candidate dicts (each with ``edges``, ``score``, and ``candidate_id`` keys).
        intent: The ``RuntimeIntent`` for query context.
        schema: The schema graph.

    Returns:
        A permutation of ``range(len(candidates))`` in LLM-ranked order (best
        first). Candidates the LLM did not rank follow in their original order.
        Falls back to the original order when the LLM reply is unusable.
    """
    original_order = list(range(len(candidates)))
    if len(candidates) <= 1:
        return original_order

    filter_desc = [f"{fp.left_expr.primary_column} {fp.op}" for fp in intent.filters_param or []]
    agg_desc = [sc.expr.primary_term for sc in intent.select_cols or [] if sc.is_aggregated]
    groupby_desc = ", ".join(g.primary_column for g in (intent.group_by_cols or []))

    intent_summary = (
        f"Tables: {intent.tables}\n"
        f"Filters: {', '.join(filter_desc) if filter_desc else 'none'}\n"
        f"Aggregations: {', '.join(agg_desc) if agg_desc else 'none'}\n"
        f"Group By: {groupby_desc if groupby_desc else 'none'}\n"
        f"Grain: {intent.grain}"
    )

    candidate_descriptions = []
    for idx, cand in enumerate(candidates[:3]):
        edges = cand.get("edges", [])
        cand_id = cand.get("candidate_id", f"J{idx + 1:02d}")
        score = cand.get("score", 0.0)
        candidate_descriptions.append(_format_join_candidate_semantic(cand_id, edges, schema, score))

    system_prompt = (
        "You are a SQL join path validator. Rank join paths by semantic correctness for the given query intent.\n\n"
        "Output Requirements:\n"
        "- Output ONLY valid JSON matching the specified output_schema.\n"
        "- Do NOT include markdown code blocks, explanations, or commentary.\n"
        "- Identical inputs must produce identical outputs.\n\n"
        "Ranking Criteria:\n"
        "1. FK direction: Forward FKs (natural flow) preferred over reverse FKs.\n"
        "2. Filter alignment: Joins connecting directly to filtered columns score higher.\n"
        "3. Semantic correctness: Does the join path match business intent?\n"
        "4. Simplicity: Shorter, more intuitive paths preferred."
    )

    user_prompt = stable_json(
        {
            "task": "Rank SQL join path candidates by semantic correctness for the given query intent.",
            "intent_summary": intent_summary,
            "join_path_candidates": candidate_descriptions,
            "output_schema": {
                "ranked_ids": ["J01", "J02", "J03"],
                "reasoning": "Brief explanation of ranking",
            },
            "instructions": "Rank the candidates from best to worst based on semantic fit for this query.",
        }
    )

    result = llm_json(system_prompt, user_prompt, task="sql")
    ranked_ids = result.get("ranked_ids") if isinstance(result, dict) else None
    if not isinstance(ranked_ids, list):
        debug("[sql_gen.llm_rank_join_candidates] LLM ranking failed, using original order")
        return original_order

    # Normalise IDs on both sides so " j02 " from the LLM still matches "J02".
    id_to_idx = {
        str(cand.get("candidate_id", f"J{i + 1:02d}")).strip().upper(): i for i, cand in enumerate(candidates)
    }

    ranked_indices: list[int] = []
    for cand_id in ranked_ids:
        if not isinstance(cand_id, str):
            continue
        idx = id_to_idx.get(cand_id.strip().upper())
        # First mention wins; repeats would otherwise duplicate a candidate downstream.
        if idx is not None and idx not in ranked_indices:
            ranked_indices.append(idx)
    ranked_indices += [i for i in original_order if i not in ranked_indices]

    debug(f"[sql_gen.llm_rank_join_candidates] LLM ranked: {ranked_ids}, reasoning: {result.get('reasoning', 'none')}")
    return ranked_indices
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
def _rank_join_candidates(
    candidates: list[list[dict[str, Any]]],
    intent: RuntimeIntent,
    schema: SchemaGraph,
    tie_threshold: float = 5.0,
) -> list[list[dict[str, Any]]]:
    """Rank join candidates deterministically by score with optional LLM tie-breaking.

    Candidates are sorted by ``_score_join_path``, highest first. The sort is
    stable, so equal scores keep their enumeration order.

    When the top two scores differ by at most ``tie_threshold``, the top three
    candidates are tagged ``J01``-``J03`` and reordered by
    ``_llm_rank_join_candidates``. That LLM ordering is sanitised so each of the
    three appears exactly once. Candidates beyond the top three keep their score
    order.

    Args:
        candidates: List of join paths (each a list of edge dicts).
        intent: The ``RuntimeIntent`` for scoring context.
        schema: The schema graph.
        tie_threshold: Largest gap between the top two scores that still triggers
            LLM tie-breaking. Defaults to ``5.0``.

    Returns:
        List of join paths sorted from most to least preferred; always a
        permutation of ``candidates``.
    """
    if len(candidates) <= 1:
        return candidates

    scored = [{"edges": edges, "score": _score_join_path(edges, intent, schema)} for edges in candidates]
    scored.sort(key=lambda item: item["score"], reverse=True)

    top_score = scored[0]["score"]
    second_score = scored[1]["score"]
    if abs(top_score - second_score) <= tie_threshold:
        debug(
            f"[sql_gen.rank_join_candidates] top scores within threshold: {top_score:.1f} vs {second_score:.1f}, invoking LLM"
        )
        head = scored[:3]
        for idx, item in enumerate(head):
            item["candidate_id"] = f"J{idx + 1:02d}"

        llm_ranking = _llm_rank_join_candidates(head, intent, schema)

        # Accept each valid head index once (first mention wins), then append any
        # the ranking omitted, so a malformed ranking can never duplicate or drop
        # a candidate.
        order: list[int] = []
        for i in llm_ranking:
            if isinstance(i, int) and 0 <= i < len(head) and i not in order:
                order.append(i)
        order += [i for i in range(len(head)) if i not in order]
        scored = [head[i] for i in order] + scored[len(head) :]

    ranked = [item["edges"] for item in scored]
    debug(f"[sql_gen.rank_join_candidates] ranked {len(ranked)} candidates, top_score={scored[0]['score']:.1f}")
    return ranked
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
def physical_tables_for_join_hints(
    tables: list[str] | None,
    schema: SchemaGraph,
) -> list[str]:
    """Map declared table names onto the canonical keys of ``schema.tables``.

    Matching is case-insensitive. Names with no schema match (for example CTE
    aliases) and empty or ``None`` entries are dropped. The result keeps the
    first-seen order with duplicates removed, so join-path lookup only uses keys
    that exist in ``schema.tables``.

    Args:
        tables: Declared table list, possibly mixing CTE names and base tables.
        schema: Loaded schema graph.

    Returns:
        Deduplicated list of canonical table keys from ``schema.tables``.
    """
    if not tables:
        return []
    canonical_by_folded = {name.lower(): name for name in schema.tables}
    matches = (canonical_by_folded.get(raw.lower()) for raw in tables if raw)
    # dict.fromkeys keeps insertion order, giving an order-preserving dedupe.
    return list(dict.fromkeys(key for key in matches if key is not None))
|
|
860
|
+
|
|
861
|
+
|
|
862
|
+
def join_hints_multi(schema: SchemaGraph, tables: list[str], intent: RuntimeIntent | None = None) -> dict[str, Any]:
    """Generate join hint candidates for SQL generation with deterministic ranking.

    Zero or one table short-circuits to the single ``J00`` (no-join) candidate
    without enumerating join paths.

    Otherwise every join path from ``_candidate_join_paths_for_tables`` is
    listed. The list is ranked by ``_rank_join_candidates`` when ``intent`` is
    supplied and kept in enumeration order when it is not, then numbered
    ``J01``, ``J02``, ... in that order.

    Args:
        schema: The schema graph.
        tables: List of table names to join.
        intent: Optional ``RuntimeIntent`` used for score-based ranking.

    Returns:
        Dict with a ``"candidates"`` list, each entry containing ``candidate_id``,
        ``join_path_signature``, and ``edge_count``.
    """
    # Check the trivial case first: path enumeration is pointless for zero or one table.
    if not tables or len(tables) <= 1:
        debug("[sql_gen.join_hints_multi] single table, returning J00")
        return {
            "candidates": [
                {
                    "candidate_id": "J00",
                    "join_path_signature": [],
                    "edge_count": 0,
                }
            ]
        }

    candidates = _candidate_join_paths_for_tables(schema, tables)
    debug(f"[sql_gen.join_hints_multi] tables={tables}, raw_candidates={len(candidates)}")

    if intent:
        debug(f"[sql_gen.join_hints_multi] ranking {len(candidates)} candidates with intent context")
        ranked_candidates = _rank_join_candidates(candidates, intent, schema)
    else:
        debug("[sql_gen.join_hints_multi] no intent provided, using original order")
        ranked_candidates = candidates

    out = [
        {
            "candidate_id": f"J{idx + 1:02d}",
            "join_path_signature": _join_path_signature_for_path(edges),
            "edge_count": len(edges),
        }
        for idx, edges in enumerate(ranked_candidates)
    ]
    debug(f"[sql_gen.join_hints_multi] generated {len(out)} candidates")
    return {"candidates": out}
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
def _format_sql_arg(k: str, v: Any) -> str:
|
|
913
|
+
"""Format a scalar function argument for SQL expression guide.
|
|
914
|
+
|
|
915
|
+
Args:
|
|
916
|
+
|
|
917
|
+
k: Parameter key (used as ``:k`` placeholder); empty string means use literal.
|
|
918
|
+
|
|
919
|
+
v: Argument value used when ``k`` is empty.
|
|
920
|
+
|
|
921
|
+
Returns:
|
|
922
|
+
|
|
923
|
+
Parameter placeholder string (for example, ``:key``) or a quoted or numeric literal.
|
|
924
|
+
"""
|
|
925
|
+
if k:
|
|
926
|
+
return f":{k}"
|
|
927
|
+
if isinstance(v, str):
|
|
928
|
+
return f"'{v}'"
|
|
929
|
+
return str(v)
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
def _render_group_sql(g: MulGroup) -> str:
|
|
933
|
+
"""Render a MulGroup as a SQL fragment for expression guide.
|
|
934
|
+
|
|
935
|
+
Args:
|
|
936
|
+
|
|
937
|
+
g: The ``MulGroup`` containing multiply or divide columns, aggregation function, coefficient, and optional scalar or inner-scalar function wrappers.
|
|
938
|
+
|
|
939
|
+
Returns:
|
|
940
|
+
|
|
941
|
+
SQL fragment string representing the group, for example ``"ROUND(SUM(:coeff * table.col), 2)"``.
|
|
942
|
+
"""
|
|
943
|
+
if not g.multiply:
|
|
944
|
+
return "1"
|
|
945
|
+
base = " * ".join(g.multiply)
|
|
946
|
+
if g.divide:
|
|
947
|
+
base = f"({base}) / ({' * '.join(g.divide)})"
|
|
948
|
+
if g.coeff_param_key:
|
|
949
|
+
base = f":{g.coeff_param_key} * {base}"
|
|
950
|
+
elif g.coefficient != 1.0:
|
|
951
|
+
base = f"{g.coefficient} * {base}"
|
|
952
|
+
if g.inner_scalar_func:
|
|
953
|
+
iargs = [
|
|
954
|
+
_format_sql_arg(k, v)
|
|
955
|
+
for k, v in zip(g.isarg_param_keys or [], g.inner_scalar_func_args or [], strict=False)
|
|
956
|
+
]
|
|
957
|
+
iargs += [_format_sql_arg("", v) for v in (g.inner_scalar_func_args or [])[len(g.isarg_param_keys or []) :]]
|
|
958
|
+
args_str = ", ".join(iargs)
|
|
959
|
+
if g.inner_scalar_func.lower() in SCALAR_FUNCTIONS_LEADING_ARG and args_str:
|
|
960
|
+
inner = f"{g.inner_scalar_func.upper()}({args_str}, {base})"
|
|
961
|
+
else:
|
|
962
|
+
inner = f"{g.inner_scalar_func.upper()}({base}{', ' + args_str if args_str else ''})"
|
|
963
|
+
else:
|
|
964
|
+
inner = base
|
|
965
|
+
if g.agg_func:
|
|
966
|
+
mid = f"{g.agg_func.upper()}({inner})"
|
|
967
|
+
else:
|
|
968
|
+
mid = inner
|
|
969
|
+
if g.scalar_func:
|
|
970
|
+
sargs = [_format_sql_arg(k, v) for k, v in zip(g.sarg_param_keys or [], g.scalar_func_args or [], strict=False)]
|
|
971
|
+
sargs += [_format_sql_arg("", v) for v in (g.scalar_func_args or [])[len(g.sarg_param_keys or []) :]]
|
|
972
|
+
args_str = ", ".join(sargs)
|
|
973
|
+
if g.scalar_func.lower() == "extract" and args_str:
|
|
974
|
+
return f"EXTRACT({args_str} FROM {mid})"
|
|
975
|
+
if g.scalar_func.lower() in SCALAR_FUNCTIONS_LEADING_ARG and args_str:
|
|
976
|
+
return f"{g.scalar_func.upper()}({args_str}, {mid})"
|
|
977
|
+
return f"{g.scalar_func.upper()}({mid}{', ' + args_str if args_str else ''})"
|
|
978
|
+
return mid
|
|
979
|
+
|
|
980
|
+
|
|
981
|
+
def _render_expr_sql(expr: NormalizedExpr) -> str:
|
|
982
|
+
"""Render a NormalizedExpr as a SQL fragment for expression guide.
|
|
983
|
+
|
|
984
|
+
Args:
|
|
985
|
+
|
|
986
|
+
expr: The ``NormalizedExpr`` to render, potentially containing multiple additive or subtractive groups and optional outer scalar wrapping.
|
|
987
|
+
|
|
988
|
+
Returns:
|
|
989
|
+
|
|
990
|
+
SQL fragment string that the LLM should produce for this expression.
|
|
991
|
+
"""
|
|
992
|
+
parts: list[str] = []
|
|
993
|
+
for g in expr.add_groups:
|
|
994
|
+
parts.append(_render_group_sql(g))
|
|
995
|
+
for v in expr.add_values:
|
|
996
|
+
parts.append(f":{v.param_key}" if v.param_key else str(v.value))
|
|
997
|
+
sub_parts: list[str] = []
|
|
998
|
+
for g in expr.sub_groups:
|
|
999
|
+
sub_parts.append(_render_group_sql(g))
|
|
1000
|
+
for v in expr.sub_values:
|
|
1001
|
+
sub_parts.append(f":{v.param_key}" if v.param_key else str(v.value))
|
|
1002
|
+
result = " + ".join(parts) if parts else "0"
|
|
1003
|
+
if sub_parts:
|
|
1004
|
+
result = f"{result} - {' - '.join(sub_parts)}"
|
|
1005
|
+
if expr.inner_scalar_func and not any(g.inner_scalar_func for g in expr.add_groups):
|
|
1006
|
+
iargs = [
|
|
1007
|
+
_format_sql_arg(k, v)
|
|
1008
|
+
for k, v in zip(
|
|
1009
|
+
expr.isarg_param_keys or [],
|
|
1010
|
+
expr.inner_scalar_func_args or [],
|
|
1011
|
+
strict=False,
|
|
1012
|
+
)
|
|
1013
|
+
]
|
|
1014
|
+
iargs += [
|
|
1015
|
+
_format_sql_arg("", v) for v in (expr.inner_scalar_func_args or [])[len(expr.isarg_param_keys or []) :]
|
|
1016
|
+
]
|
|
1017
|
+
args_str = ", ".join(iargs)
|
|
1018
|
+
if expr.inner_scalar_func.lower() in SCALAR_FUNCTIONS_LEADING_ARG and args_str:
|
|
1019
|
+
result = f"{expr.inner_scalar_func.upper()}({args_str}, {result})"
|
|
1020
|
+
else:
|
|
1021
|
+
result = f"{expr.inner_scalar_func.upper()}({result}{', ' + args_str if args_str else ''})"
|
|
1022
|
+
if expr.agg_func and not any(g.agg_func for g in expr.add_groups):
|
|
1023
|
+
result = f"{expr.agg_func.upper()}({result})"
|
|
1024
|
+
if expr.scalar_func and not any(g.scalar_func for g in expr.add_groups):
|
|
1025
|
+
sargs = [
|
|
1026
|
+
_format_sql_arg(k, v) for k, v in zip(expr.sarg_param_keys or [], expr.scalar_func_args or [], strict=False)
|
|
1027
|
+
]
|
|
1028
|
+
sargs += [_format_sql_arg("", v) for v in (expr.scalar_func_args or [])[len(expr.sarg_param_keys or []) :]]
|
|
1029
|
+
args_str = ", ".join(sargs)
|
|
1030
|
+
if expr.scalar_func.lower() == "extract" and args_str:
|
|
1031
|
+
result = f"EXTRACT({args_str} FROM {result})"
|
|
1032
|
+
elif expr.scalar_func.lower() in SCALAR_FUNCTIONS_LEADING_ARG and args_str:
|
|
1033
|
+
result = f"{expr.scalar_func.upper()}({args_str}, {result})"
|
|
1034
|
+
else:
|
|
1035
|
+
result = f"{expr.scalar_func.upper()}({result}{', ' + args_str if args_str else ''})"
|
|
1036
|
+
return result
|
|
1037
|
+
|
|
1038
|
+
|
|
1039
|
+
def build_deterministic_sql(
    intent: RuntimeIntent,
    cte_join_hints: dict[str, dict[str, Any]] | None = None,
) -> str:
    """Build a rough deterministic SQL from a RuntimeIntent.

    The output is structurally correct but may lack JOIN clauses and
    dialect-specific syntax. It serves as a constrained template for the SQL
    LLM and as a reference for post-generation validation.

    Each CTE step becomes a ``name AS (...)`` entry of a leading ``WITH``
    clause. A CTE's select expressions are aliased with its ``output_columns``;
    the main SELECT is emitted unaliased. Every block is produced by
    ``_build_deterministic_select_block``, which inserts a ``-- <JOIN: ...>``
    placeholder where the LLM should integrate the chosen join predicates.

    Args:
        intent: Runtime intent supplying the main query clauses and optional ``cte_steps``.
        cte_join_hints: Accepted for interface compatibility; not read here.

    Returns:
        The template SQL, with the ``WITH`` section (if any) and the main block
        separated by a newline.
    """
    sections: list[str] = []

    steps = intent.cte_steps or []
    if steps:
        rendered_ctes: list[str] = []
        for step in steps:
            body = _build_deterministic_select_block(
                select_cols=step.select_cols or [],
                tables=step.tables or [],
                group_by_cols=step.group_by_cols or [],
                order_by_cols=step.order_by_cols or [],
                filters_param=step.filters_param or [],
                having_param=step.having_param or [],
                limit=step.limit,
                grain=step.grain or "row_level",
                output_aliases=step.output_columns or [],
            )
            rendered_ctes.append(f"{step.cte_name} AS (\n{body}\n)")
        sections.append("WITH " + ",\n".join(rendered_ctes))

    sections.append(
        _build_deterministic_select_block(
            select_cols=intent.select_cols or [],
            tables=intent.tables or [],
            group_by_cols=intent.group_by_cols or [],
            order_by_cols=intent.order_by_cols or [],
            filters_param=intent.filters_param or [],
            having_param=intent.having_param or [],
            limit=intent.limit,
            grain=intent.grain or "row_level",
        )
    )
    return "\n".join(sections)
|
|
1082
|
+
|
|
1083
|
+
|
|
1084
|
+
def _join_clause_parts_with_bool_op(
|
|
1085
|
+
parts: list[tuple[str, str]],
|
|
1086
|
+
) -> str:
|
|
1087
|
+
"""Chain SQL clause fragments using their positional boolean operators.
|
|
1088
|
+
|
|
1089
|
+
Each element's ``bool_op`` is the connector between that element and
|
|
1090
|
+
the next. The last element's ``bool_op`` is unused. Fragments are
|
|
1091
|
+
joined sequentially to preserve the canonical ordering established
|
|
1092
|
+
by ``_canonicalize_condition_order``.
|
|
1093
|
+
|
|
1094
|
+
When any ``OR`` connector is present the entire expression is wrapped
|
|
1095
|
+
in parentheses to maintain correct SQL precedence in outer contexts.
|
|
1096
|
+
|
|
1097
|
+
Args:
|
|
1098
|
+
parts: List of ``(sql_fragment, bool_op)`` tuples where
|
|
1099
|
+
``bool_op`` is ``"AND"`` or ``"OR"``.
|
|
1100
|
+
|
|
1101
|
+
Returns:
|
|
1102
|
+
Combined SQL predicate string.
|
|
1103
|
+
"""
|
|
1104
|
+
if not parts:
|
|
1105
|
+
return ""
|
|
1106
|
+
|
|
1107
|
+
result = parts[0][0]
|
|
1108
|
+
for i in range(1, len(parts)):
|
|
1109
|
+
connector = parts[i - 1][1]
|
|
1110
|
+
result = f"{result} {connector} {parts[i][0]}"
|
|
1111
|
+
|
|
1112
|
+
has_or = any(op == "OR" for _, op in parts[:-1])
|
|
1113
|
+
if has_or and len(parts) > 1:
|
|
1114
|
+
result = f"({result})"
|
|
1115
|
+
|
|
1116
|
+
return result
|
|
1117
|
+
|
|
1118
|
+
|
|
1119
|
+
def _build_deterministic_select_block(
    select_cols: list[SelectCol],
    tables: list[str],
    group_by_cols: list[NormalizedExpr],
    order_by_cols: list,
    filters_param: list,
    having_param: list,
    limit: int | None,
    grain: str,
    output_aliases: list[str] | None = None,
) -> str:
    """Build a single SELECT block from structured intent clauses.

    Emits SELECT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, and LIMIT, one clause
    per line. FROM gets a ``-- <JOIN: ...>`` placeholder when more than one
    table is listed.

    WHERE and HAVING fragments are chained by ``_join_clause_parts_with_bool_op``
    using each filter's ``bool_op``, defaulting to ``AND``. A date-window filter
    that expands to two fragments ANDs them together, and its own ``bool_op``
    links the last fragment to the next filter.

    Args:
        select_cols: Select columns; each ``sc.expr`` is rendered with ``_render_expr_sql``.
        tables: Table names; only the first appears in ``FROM``.
        group_by_cols: Expressions for ``GROUP BY``.
        order_by_cols: Items exposing ``expr`` and ``direction`` (default ``ASC``).
        filters_param: WHERE filters (``left_expr``, ``op``, ``value_type``,
            ``raw_value``, ``right_expr``, ``param_key``, optional ``bool_op``).
        having_param: HAVING filters (``left_expr``, ``op``, ``right_expr``,
            ``param_key``, optional ``bool_op``).
        limit: Row limit; falsy values emit no ``LIMIT``.
        grain: Accepted for interface compatibility; not read here.
        output_aliases: Optional aliases applied positionally as ``AS alias``.

    Returns:
        Newline-joined SQL text for the block.
    """
    lines: list[str] = []

    select_exprs: list[str] = []
    for idx, sc in enumerate(select_cols):
        rendered = _render_expr_sql(sc.expr)
        if output_aliases and idx < len(output_aliases):
            rendered = f"{rendered} AS {output_aliases[idx]}"
        select_exprs.append(rendered)
    lines.append("SELECT " + ", ".join(select_exprs))

    if tables:
        lines.append(f"FROM {tables[0]}")
        if len(tables) > 1:
            lines.append("-- <JOIN: integrate from join candidates>")

    where_parts: list[tuple[str, str]] = []
    dialect_type = EngineConfig.TYPE or "postgresql"
    for fp in filters_param:
        left = _render_expr_sql(fp.left_expr)
        op = fp.op or "="
        case_insensitive = fp.value_type == "string" and op.lower() not in (
            "is null",
            "is not null",
            "ilike",
            "not ilike",
        )
        if case_insensitive:
            left = _wrap_for_case_insensitive(left, dialect_type)
        bool_op = getattr(fp, "bool_op", "AND") or "AND"

        if op.lower() in ("is null", "is not null"):
            where_parts.append((f"{left} {op.upper()}", bool_op))
        elif fp.value_type == "date_window" and isinstance(fp.raw_value, dict):
            window_frags = _render_date_window_where(fp, left, dialect_type)
            last = len(window_frags) - 1
            # Bounds of one window are always ANDed; the filter's own connector
            # must still link the window to the next filter instead of being dropped.
            for pos, dw_frag in enumerate(window_frags):
                where_parts.append((dw_frag, bool_op if pos == last else "AND"))
        elif fp.value_type == "date_diff" and isinstance(fp.raw_value, dict):
            rv = fp.raw_value
            unit = rv.get("unit", "day")
            amount = int(rv.get("amount", 0)) if rv.get("amount") is not None else 0
            frag = render_date_diff_expr(dialect_type, left, fp.op or ">", unit, amount)
            where_parts.append((frag, bool_op))
        elif fp.right_expr:
            right = _render_expr_sql(fp.right_expr)
            if case_insensitive:
                right = _wrap_for_case_insensitive(right, dialect_type)
            where_parts.append((f"{left} {op} {right}", bool_op))
        elif fp.param_key or fp.raw_value is not None:
            # Without a param key, a raw value still gets a visible "p?" placeholder.
            pkey = fp.param_key or "p?"
            val_needs_lower = case_insensitive and op.lower() in ("like", "not like")
            val_ref = f"LOWER(:{pkey})" if val_needs_lower else f":{pkey}"
            where_parts.append((f"{left} {op} {val_ref}", bool_op))
    if where_parts:
        lines.append("WHERE " + _join_clause_parts_with_bool_op(where_parts))

    if group_by_cols:
        lines.append("GROUP BY " + ", ".join(_render_expr_sql(g) for g in group_by_cols))

    having_parts: list[tuple[str, str]] = []
    for hp in having_param:
        left = _render_expr_sql(hp.left_expr)
        op = hp.op or ">"
        bool_op = getattr(hp, "bool_op", "AND") or "AND"
        if hp.right_expr:
            right_sql = _render_expr_sql(hp.right_expr)
        elif hp.param_key:
            right_sql = f":{hp.param_key}"
        else:
            right_sql = "?"
        having_parts.append((f"{left} {op} {right_sql}", bool_op))
    if having_parts:
        lines.append("HAVING " + _join_clause_parts_with_bool_op(having_parts))

    if order_by_cols:
        ob_exprs = []
        for obc in order_by_cols:
            direction = obc.direction.upper() if obc.direction else "ASC"
            ob_exprs.append(f"{_render_expr_sql(obc.expr)} {direction}")
        lines.append("ORDER BY " + ", ".join(ob_exprs))

    if limit:
        lines.append(f"LIMIT {limit}")

    return "\n".join(lines)
|
|
1221
|
+
|
|
1222
|
+
|
|
1223
|
+
def _generate_col_alias(sc: SelectCol) -> str:
|
|
1224
|
+
"""Build a deterministic display alias from a SelectCol's expression metadata.
|
|
1225
|
+
|
|
1226
|
+
Rules:
|
|
1227
|
+
* Plain column ``table.col`` → ``col``.
|
|
1228
|
+
* Aggregate ``COUNT(table.col)`` → ``count_col``.
|
|
1229
|
+
* Distinct aggregate ``COUNT(DISTINCT table.col)`` → ``count_distinct_col``.
|
|
1230
|
+
* Scalar wrapper ``ROUND(SUM(table.col), 2)`` → ``round_sum_col``.
|
|
1231
|
+
* Arithmetic ``table.a * table.b`` → ``a_times_b``.
|
|
1232
|
+
* Fallback: ``col_<idx>`` assigned by the caller.
|
|
1233
|
+
|
|
1234
|
+
Args:
|
|
1235
|
+
sc: The ``SelectCol`` to derive an alias for.
|
|
1236
|
+
|
|
1237
|
+
Returns:
|
|
1238
|
+
A lowercase alias string safe for SQL ``AS`` usage.
|
|
1239
|
+
"""
|
|
1240
|
+
expr = sc.expr
|
|
1241
|
+
col = expr.primary_column
|
|
1242
|
+
if col:
|
|
1243
|
+
col_clean = col.rsplit(".", 1)[-1].lower()
|
|
1244
|
+
else:
|
|
1245
|
+
col_clean = ""
|
|
1246
|
+
|
|
1247
|
+
groups = expr.add_groups or []
|
|
1248
|
+
if len(groups) >= 2 and not expr.agg_func and not expr.scalar_func:
|
|
1249
|
+
parts = [g.multiply[0].rsplit(".", 1)[-1].lower() if g.multiply else "x" for g in groups]
|
|
1250
|
+
alias = "_times_".join(parts)
|
|
1251
|
+
elif expr.sub_groups and groups:
|
|
1252
|
+
plus_part = groups[0].multiply[0].rsplit(".", 1)[-1].lower() if groups[0].multiply else "x"
|
|
1253
|
+
minus_part = (
|
|
1254
|
+
expr.sub_groups[0].multiply[0].rsplit(".", 1)[-1].lower()
|
|
1255
|
+
if expr.sub_groups[0].multiply
|
|
1256
|
+
else "y"
|
|
1257
|
+
)
|
|
1258
|
+
alias = f"{plus_part}_minus_{minus_part}"
|
|
1259
|
+
elif col_clean:
|
|
1260
|
+
alias = col_clean
|
|
1261
|
+
else:
|
|
1262
|
+
return ""
|
|
1263
|
+
|
|
1264
|
+
distinct_prefix = ""
|
|
1265
|
+
if groups and groups[0].multiply:
|
|
1266
|
+
term = groups[0].multiply[0].upper()
|
|
1267
|
+
if "DISTINCT " in term:
|
|
1268
|
+
distinct_prefix = "distinct_"
|
|
1269
|
+
|
|
1270
|
+
if expr.agg_func:
|
|
1271
|
+
alias = f"{expr.agg_func}_{distinct_prefix}{alias}"
|
|
1272
|
+
if expr.inner_scalar_func:
|
|
1273
|
+
alias = f"{expr.inner_scalar_func}_{alias}"
|
|
1274
|
+
if expr.scalar_func:
|
|
1275
|
+
alias = f"{expr.scalar_func}_{alias}"
|
|
1276
|
+
|
|
1277
|
+
return alias.lower()
|
|
1278
|
+
|
|
1279
|
+
|
|
1280
|
+
def deterministic_alias_sql(sql_param: str, intent: RuntimeIntent) -> str:
|
|
1281
|
+
"""Add deterministic display aliases to each SELECT expression.
|
|
1282
|
+
|
|
1283
|
+
Parses the ``SELECT ... FROM`` portion of the parameterized SQL,
|
|
1284
|
+
matches each comma-separated expression positionally with the
|
|
1285
|
+
intent's ``select_cols``, and appends an ``AS alias`` clause derived
|
|
1286
|
+
from column metadata.
|
|
1287
|
+
|
|
1288
|
+
Args:
|
|
1289
|
+
sql_param: Parameterized SQL string produced by ``build_deterministic_sql``.
|
|
1290
|
+
intent: The ``RuntimeIntent`` whose ``select_cols`` drive aliasing.
|
|
1291
|
+
|
|
1292
|
+
Returns:
|
|
1293
|
+
SQL string with ``AS`` aliases on every SELECT expression. Returns
|
|
1294
|
+
the original SQL unchanged when the SELECT clause cannot be parsed
|
|
1295
|
+
or the column count does not match.
|
|
1296
|
+
"""
|
|
1297
|
+
import re
|
|
1298
|
+
|
|
1299
|
+
match = re.search(r"(?i)\bSELECT\s+", sql_param)
|
|
1300
|
+
if not match:
|
|
1301
|
+
return sql_param
|
|
1302
|
+
|
|
1303
|
+
select_start = match.end()
|
|
1304
|
+
from_match = re.search(r"(?i)\bFROM\b", sql_param[select_start:])
|
|
1305
|
+
if not from_match:
|
|
1306
|
+
return sql_param
|
|
1307
|
+
|
|
1308
|
+
select_body = sql_param[select_start : select_start + from_match.start()].strip()
|
|
1309
|
+
rest = sql_param[select_start + from_match.start() :]
|
|
1310
|
+
|
|
1311
|
+
depth = 0
|
|
1312
|
+
parts: list[str] = []
|
|
1313
|
+
current: list[str] = []
|
|
1314
|
+
for ch in select_body:
|
|
1315
|
+
if ch == "(":
|
|
1316
|
+
depth += 1
|
|
1317
|
+
elif ch == ")":
|
|
1318
|
+
depth -= 1
|
|
1319
|
+
if ch == "," and depth == 0:
|
|
1320
|
+
parts.append("".join(current).strip())
|
|
1321
|
+
current = []
|
|
1322
|
+
else:
|
|
1323
|
+
current.append(ch)
|
|
1324
|
+
if current:
|
|
1325
|
+
parts.append("".join(current).strip())
|
|
1326
|
+
|
|
1327
|
+
cols = intent.select_cols or []
|
|
1328
|
+
if len(parts) != len(cols):
|
|
1329
|
+
return sql_param
|
|
1330
|
+
|
|
1331
|
+
aliased: list[str] = []
|
|
1332
|
+
seen_aliases: set[str] = set()
|
|
1333
|
+
for idx, (expr_str, sc) in enumerate(zip(parts, cols, strict=False)):
|
|
1334
|
+
alias = _generate_col_alias(sc)
|
|
1335
|
+
if not alias:
|
|
1336
|
+
alias = f"col_{idx + 1}"
|
|
1337
|
+
base = alias
|
|
1338
|
+
counter = 2
|
|
1339
|
+
while alias in seen_aliases:
|
|
1340
|
+
alias = f"{base}_{counter}"
|
|
1341
|
+
counter += 1
|
|
1342
|
+
seen_aliases.add(alias)
|
|
1343
|
+
aliased.append(f"{expr_str} AS {alias}")
|
|
1344
|
+
|
|
1345
|
+
prefix = sql_param[: match.start()] + match.group()
|
|
1346
|
+
return prefix + ", ".join(aliased) + " " + rest
|
|
1347
|
+
|
|
1348
|
+
|
|
1349
|
+
def _render_date_window_where(
|
|
1350
|
+
fp: FilterParam, left_rendered: str, dialect_type: str
|
|
1351
|
+
) -> list[str]:
|
|
1352
|
+
"""Render WHERE clause part(s) for a date_window filter.
|
|
1353
|
+
|
|
1354
|
+
For raw_value with start/end keys emits two predicates (>= start AND <= end).
|
|
1355
|
+
For unit/offset uses render_date_window_expr. Returns a list of one or two
|
|
1356
|
+
fragments to AND together.
|
|
1357
|
+
"""
|
|
1358
|
+
rv = fp.raw_value if isinstance(fp.raw_value, dict) else {}
|
|
1359
|
+
if "start" in rv and "end" in rv:
|
|
1360
|
+
start_val = rv["start"]
|
|
1361
|
+
end_val = rv["end"]
|
|
1362
|
+
if isinstance(start_val, str) and isinstance(end_val, str):
|
|
1363
|
+
return [
|
|
1364
|
+
f"{left_rendered} >= '{start_val}'",
|
|
1365
|
+
f"{left_rendered} <= '{end_val}'",
|
|
1366
|
+
]
|
|
1367
|
+
unit = rv.get("unit", "day")
|
|
1368
|
+
offset = int(rv.get("offset", 0)) if rv.get("offset") is not None else 0
|
|
1369
|
+
op = fp.op or ">="
|
|
1370
|
+
return [render_date_window_expr(dialect_type, left_rendered, op, unit, offset)]
|
|
1371
|
+
|
|
1372
|
+
|
|
1373
|
+
def build_join_choice_prompt(
    q_norm: str,
    deterministic_sql: str,
    join_candidates: dict[str, Any],
    cte_join_hints: dict[str, dict[str, Any]] | None = None,
) -> tuple[str, str]:
    """Build the minimal prompt asking the LLM for join candidate IDs only.

    Only ``candidate_id`` and ``join_path_signature`` of each candidate are sent.
    ``cte_join_candidates`` is ``None`` unless ``cte_join_hints`` is non-empty,
    in which case it maps each CTE name to its summarised candidates.

    Returns:
        ``(system_prompt, user_prompt)``, where the user prompt is the
        ``stable_json`` encoding of the payload. The expected reply is JSON with
        ``chosen_join_candidate_id`` and, optionally,
        ``chosen_cte_join_candidate_ids``.
    """
    def _summarize(entries) -> list[dict[str, Any]]:
        return [
            {"candidate_id": entry.get("candidate_id"), "join_path_signature": entry.get("join_path_signature")}
            for entry in entries
        ]

    system = (
        "You are a join selector for text-to-SQL. Output ONLY valid JSON. "
        "Return chosen_join_candidate_id and, if the query has CTEs that need joins, "
        "chosen_cte_join_candidate_ids mapping each CTE name to its candidate_id."
    )

    main_summary = _summarize(join_candidates.get("candidates", []))
    cte_payload = (
        {cte: _summarize(hint.get("candidates", []) or []) for cte, hint in cte_join_hints.items()}
        if cte_join_hints
        else None
    )

    payload = {
        "task": (
            "Given the question and the deterministic SQL template, choose the join candidate "
            "that correctly connects the tables. Return only the IDs; do not modify the SQL."
        ),
        "question": q_norm,
        "deterministic_sql": deterministic_sql,
        "join_candidates": main_summary,
        "cte_join_candidates": cte_payload,
        "output_format": {
            "chosen_join_candidate_id": "J00 or J01, J02, ...",
            "chosen_cte_join_candidate_ids": "Optional dict: cte_name -> candidate_id",
        },
    }
    return system, stable_json(payload)
|
|
1420
|
+
|
|
1421
|
+
|
|
1422
|
+
def get_join_choice_from_llm(
    q_norm: str,
    deterministic_sql: str,
    join_candidates: dict[str, Any],
    cte_join_hints: dict[str, dict[str, Any]] | None,
) -> tuple[str, dict[str, str]]:
    """Call LLM to get only chosen join candidate ID and per-CTE IDs.

    The prompt comes from ``build_join_choice_prompt`` and is sent through
    ``llm_json`` with a single retry. The raw response is untrusted, so it is
    sanitised before being returned.

    Args:
        q_norm: The normalised user question string.
        deterministic_sql: The deterministic SQL template shown to the LLM.
        join_candidates: Join hints for the main query; the ``candidate_id``
            values in its ``"candidates"`` list are the IDs the LLM may pick.
        cte_join_hints: Optional mapping of CTE name to join hints of the
            same shape, or ``None`` when there are no CTE joins.

    Returns:
        ``(chosen_join_candidate_id, chosen_cte_join_candidate_ids)``.

        * If the response is not a JSON object: ``("J00", {})``.
        * The main ID falls back to ``"J00"`` when it is missing, not a
          string, blank, or (when candidates were offered) not one of the
          offered ``candidate_id`` values.
        * Per-CTE entries are kept only when key and value are non-blank
          strings; when CTE hints were supplied, entries naming an unknown
          CTE, or a candidate ID not offered for that CTE, are dropped.

        Surrounding whitespace is stripped from returned names and IDs.
    """
    system, user = build_join_choice_prompt(
        q_norm, deterministic_sql, join_candidates, cte_join_hints
    )
    parsed = llm_json(system, user, retries=1, task="sql")
    if not isinstance(parsed, dict):
        return "J00", {}

    def offered_ids(hints: Any) -> set[str]:
        # The candidate_id strings that were actually presented for one hint.
        if not isinstance(hints, dict):
            return set()
        return {
            c["candidate_id"]
            for c in (hints.get("candidates") or [])
            if isinstance(c, dict) and isinstance(c.get("candidate_id"), str)
        }

    chosen = parsed.get("chosen_join_candidate_id")
    chosen = chosen.strip() if isinstance(chosen, str) else ""
    main_ids = offered_ids(join_candidates)
    # LLMs can hallucinate IDs; only accept one that was actually offered.
    # With no offered candidates there is nothing to validate against.
    if not chosen or (main_ids and chosen not in main_ids):
        chosen = "J00"

    raw_cte_ids = parsed.get("chosen_cte_join_candidate_ids")
    if not isinstance(raw_cte_ids, dict):
        return chosen, {}

    cte_ids: dict[str, str] = {}
    for cte_name, cand_id in raw_cte_ids.items():
        if not (isinstance(cte_name, str) and isinstance(cand_id, str)):
            continue
        cte_name, cand_id = cte_name.strip(), cand_id.strip()
        if not cte_name or not cand_id:
            continue
        if cte_join_hints:
            # Ignore CTEs the prompt never mentioned.
            if cte_name not in cte_join_hints:
                continue
            cte_offered = offered_ids(cte_join_hints[cte_name])
            if cte_offered and cand_id not in cte_offered:
                continue
        cte_ids[cte_name] = cand_id
    return chosen, cte_ids
|
|
1446
|
+
|
|
1447
|
+
|
|
1448
|
+
def build_repair_prompt(
    schema: SchemaGraph,
    q_norm: str,
    prev_sql: str,
    db_error: str,
    nl_error: str,
    join_hints: dict[str, Any],
    cte_join_hints: dict[str, dict[str, Any]] | None = None,
) -> tuple[str, str]:
    """Assemble the (system, user) prompt pair for a syntax-only SQL repair.

    The user prompt is a ``stable_json`` payload that carries several things:
    the failing SQL, the database error and its plain-language explanation,
    the schema literal text, the join candidates, and a list of hard
    constraints. The constraints forbid structural changes such as tables,
    joins, predicates, grouping and ordering. When CTE join hints are given,
    extra CTE constraints, an extra output-schema field and the CTE
    candidates are included.

    Args:
        schema: The schema graph; its ``schema_literal_text`` is embedded.
        q_norm: The normalised user question string.
        prev_sql: The previously generated SQL that failed.
        db_error: The raw database error message.
        nl_error: The human-readable explanation of the error.
        join_hints: Join hint candidates from ``join_hints_multi``.
        cte_join_hints: Optional dict mapping CTE name to join hints for CTE
            steps. It is ignored when ``None`` or empty.

    Returns:
        ``(SQL_REPAIR_SYSTEM_PROMPT, user_prompt_json)``.
    """
    # Engine type -> dialect label used in the prompt. Unknown engines
    # get the generic label.
    dialect_name = {
        "databricks": "Spark",
        "postgresql": "PostgreSQL",
    }.get(EngineConfig.TYPE, "SQL")

    constraints: list[str] = [
        f"Dialect: {dialect_name}. Use {dialect_name} syntax ONLY.",
        "Output ONLY valid JSON. No markdown. No explanations.",
        "SYNTAX-ONLY REPAIR: fix syntax errors, column qualification, "
        "keyword casing, operator typos, parameter placeholder format.",
        "Do NOT add, remove, or reorder tables in FROM/JOIN clauses.",
        "Do NOT change SELECT columns, aggregation functions, or expressions.",
        "Do NOT alter GROUP BY, ORDER BY, or HAVING structure.",
        "Do NOT change join conditions, join types, or join order.",
        "Do NOT add or remove WHERE/HAVING predicates.",
        "ALWAYS qualify column names with table name (for example, <table_1>.<column_1>).",
        "Do not alias SELECT expressions. No AS clauses on SELECT columns or aggregates.",
        "HAVING clause: Use full aggregation expression, NOT an alias.",
        "Use :p1, :p2 parameter placeholders for filter/having values, NOT literal values.",
        "WHERE/HAVING predicates: column reference on LEFT side of comparison.",
        "ORDER BY must include explicit ASC or DESC direction.",
        "chosen_join_candidate_id must match the SAME candidate as the original SQL.",
    ]
    output_schema: dict[str, Any] = {"sql": "...", "chosen_join_candidate_id": "J01"}

    if cte_join_hints:
        constraints.extend(
            [
                "CTE join predicates MUST exactly match chosen CTE join candidate for: "
                + ", ".join(cte_join_hints),
                "chosen_cte_join_candidate_ids must specify join candidate for each CTE "
                "with multi-table joins.",
            ]
        )
        output_schema["chosen_cte_join_candidate_ids"] = {"cte_name": "J01"}

    payload: dict[str, Any] = {
        "task": "Fix ONLY syntax errors in the SQL query. "
        "Do not change query structure, tables, joins, or logic.",
        "error_info": {"db_error": db_error, "explanation": nl_error},
        "hard_constraints": constraints,
        "output_schema": output_schema,
        "schema": schema.schema_literal_text,
        "join_candidates": join_hints,
        "question": q_norm,
        "previous_sql": prev_sql,
        # NOTE(review): the same error text is also under "error_info".
        # Confirm whether SQL_REPAIR_SYSTEM_PROMPT references either key
        # before dropping the duplicate.
        "db_error": db_error,
        "error_explanation": nl_error,
    }
    if cte_join_hints:
        payload["cte_join_candidates"] = cte_join_hints

    return SQL_REPAIR_SYSTEM_PROMPT, stable_json(payload)
|