aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/utils.py
ADDED
|
@@ -0,0 +1,973 @@
|
|
|
1
|
+
"""Shared utility functions for intent normalisation, question validation, fingerprint generation, and natural language question generation.
|
|
2
|
+
|
|
3
|
+
Provides SHA-256 intent fingerprinting over canonical stable JSON, fuzzy question matching with Levenshtein distance, and filter/HAVING/CTE normalisation helpers used by both the main pipeline and the query simulator. Also exposes the ``generate_question`` LLM helper used by QSim and Simulator to produce human-readable questions from structured query intent.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import random
|
|
9
|
+
import re
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from .config import (
|
|
13
|
+
AGG_PATTERN,
|
|
14
|
+
QUESTION_STARTS_AGG,
|
|
15
|
+
QUESTION_STARTS_GROUP,
|
|
16
|
+
QUESTION_STARTS_LIST,
|
|
17
|
+
VALID_AGGREGATION_FUNCTIONS,
|
|
18
|
+
VALID_EXPECTED_ROWS,
|
|
19
|
+
VALID_FILTER_OPS,
|
|
20
|
+
VALID_GRAINS,
|
|
21
|
+
VALID_HAVING_OPS,
|
|
22
|
+
PolicyConfig,
|
|
23
|
+
)
|
|
24
|
+
from .contracts_base import CteOutputColumnMeta, SchemaGraph, SQLShape
|
|
25
|
+
from .contracts_core import (
|
|
26
|
+
FilterParam,
|
|
27
|
+
HavingParam,
|
|
28
|
+
NormalizedExpr,
|
|
29
|
+
OrderByCol,
|
|
30
|
+
RuntimeCteStep,
|
|
31
|
+
RuntimeIntent,
|
|
32
|
+
SelectCol,
|
|
33
|
+
)
|
|
34
|
+
from .core_utils import debug, llm_json, normalize_question, sha256, stable_json
|
|
35
|
+
from .intent_resolve import sort_filters, sort_having
|
|
36
|
+
|
|
37
|
+
# System prompt sent by validate_question() through llm_json(). It asks the model to
# classify the input as a database query request (erring towards "valid" when in doubt),
# to label it "allowed" (read/SELECT), "restricted" (write/destructive statement) or
# "unspecified", and to return a spelling-only corrected copy of the input, all as a
# single three-field JSON object.
_QUESTION_VALIDATION_SYSTEM = (
    "You decide if user input is a database query request or not.\n\n"
    "A database query request asks to retrieve, count, sum, filter, sort, or analyze data.\n"
    "When in doubt, treat it as a valid query request.\n\n"
    "Only mark as invalid if it is clearly:\n"
    '- Chitchat (e.g. "hello", "thanks", "who are you")\n'
    '- A help or instruction request (e.g. "how do I query", "explain SQL")\n'
    '- General knowledge unrelated to data (e.g. "what is the capital of France")\n\n'
    "Respond with JSON containing exactly three fields:\n"
    '- "valid_database_question": "yes" or "no"\n'
    '- "query_type": "allowed" if read/SELECT operation, "restricted" if write/destructive operation (UPDATE/DELETE/INSERT/DROP/ALTER/CREATE/TRUNCATE), "unspecified" if unclear\n'
    '- "corrected": the input with spelling typos fixed only. Do NOT remove, reorder, or rephrase any words. If no corrections needed, return the original text unchanged.\n\n'
    "Respond ONLY with valid JSON, no explanation."
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _interpret_validation_response(result: Any, question: str) -> tuple[bool, str, str]:
    """Turn the raw ``llm_json`` payload for question validation into a verdict tuple.

    Tolerates the shapes an LLM realistically returns: a missing, empty or non-dict payload,
    ``null`` field values, a JSON boolean instead of ``"yes"``/``"no"``, stray whitespace and
    arbitrary casing.

    Args:
        result: The parsed JSON response (expected to be a dict).
        question: The original user input, used as the fallback corrected text.

    Returns:
        Tuple of ``(is_valid, query_type, corrected_text)``; see ``validate_question``.
    """
    if not isinstance(result, dict) or not result:
        return False, "invalid", question

    verdict = result.get("valid_database_question")
    if isinstance(verdict, bool):
        # Some models answer with a JSON boolean rather than the requested "yes"/"no".
        is_valid = verdict
    else:
        is_valid = str(verdict or "").strip().lower() == "yes"

    # A null query_type is treated exactly like a missing one.
    query_type = str(result.get("query_type") or "unspecified").strip().lower()

    corrected = result.get("corrected")
    if not isinstance(corrected, str) or not corrected.strip():
        # Never hand back a blank or non-string correction; fall back to the raw input.
        corrected = question

    if not is_valid:
        return False, "invalid", corrected
    if query_type == "restricted":
        return False, "restricted", corrected
    return True, "allowed", corrected


def validate_question(question: str) -> tuple[bool, str, str]:
    """Validate a question and return typo-corrected text via LLM.

    Args:
        question: The raw user input string.

    Returns:
        Tuple of ``(is_valid, query_type, corrected_text)``. ``is_valid`` is True only for valid,
        non-restricted database questions. ``query_type`` is ``'allowed'`` for accepted requests,
        ``'restricted'`` when the LLM labels the request as a write/destructive operation, and
        ``'invalid'`` when the input is not a database question or the LLM returned nothing usable.
        ``corrected_text`` is the LLM's typo-corrected input, or the original input when no usable
        correction was returned.
    """
    result = llm_json(_QUESTION_VALIDATION_SYSTEM, question, task="default")
    return _interpret_validation_response(result, question)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def sql_shape(sql: str, intent: RuntimeIntent) -> SQLShape:
    """Summarise the structure of a generated query for shape comparison.

    Keyword features (JOINs, GROUP BY, aggregate calls, SELECT DISTINCT) are detected with
    word-boundary regexes over the lower-cased SQL text. Filter and HAVING counts come from the
    intent itself, summed over the main query and every CTE step.

    Args:
        sql: The generated SQL string.
        intent: The RuntimeIntent supplying filter/HAVING params and CTE steps.

    Returns:
        SQLShape with join count, GROUP BY flag, aggregation flag, CTE count, filter count,
        HAVING count and DISTINCT flag.
    """
    text = sql.lower()
    steps = intent.cte_steps or []
    filter_total = len(intent.filters_param or []) + sum(len(step.filters_param or []) for step in steps)
    having_total = len(intent.having_param or []) + sum(len(step.having_param or []) for step in steps)
    return SQLShape(
        num_joins=len(re.findall(r"\bjoin\b", text)),
        has_group_by=re.search(r"\bgroup\s+by\b", text) is not None,
        has_agg=re.search(r"\b(count|sum|avg|min|max)\s*\(", text) is not None,
        num_cte=len(steps),
        num_filters=filter_total,
        num_having=having_total,
        has_distinct=re.search(r"\bselect\s+distinct\b", text) is not None,
    )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein edit distance between ``s1`` and ``s2``.

    Uses the classic two-row dynamic-programming table, iterating over the longer string so the
    row length is bounded by the shorter one.

    Args:
        s1: The first string.
        s2: The second string.

    Returns:
        Minimum number of single-character insertions, deletions or substitutions needed to turn
        ``s1`` into ``s2``.
    """
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    if not shorter:
        return len(longer)

    prev = list(range(len(shorter) + 1))
    for row, ch_long in enumerate(longer, start=1):
        curr = [row]
        for col, ch_short in enumerate(shorter, start=1):
            cost = 0 if ch_long == ch_short else 1
            curr.append(min(prev[col] + 1, curr[col - 1] + 1, prev[col - 1] + cost))
        prev = curr
    return prev[-1]
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def exact_question_match(
    q1: str,
    q2: str,
    max_distance: int = PolicyConfig.FUZZY_MATCH_MAX_DISTANCE,
    label: str = "",
) -> bool:
    """Check if two questions are identical after normalisation, within a fuzzy tolerance.

    Both questions go through ``normalize_question`` and ``_tokenize`` (sorted, stopword-free
    tokens). The two token lists must be non-empty and of equal length. Tokens are then compared
    position by position and their Levenshtein distances summed.

    Args:
        q1: The first question string.
        q2: The second question string.
        max_distance: Maximum total Levenshtein distance across all token pairs (inclusive).
        label: Optional debug label appended to log messages.

    Returns:
        True if the questions match within the specified fuzzy tolerance.
    """
    tag = f"[utils.exact_question_match] {label}" if label else "[utils.exact_question_match]"
    tokens1 = _tokenize(normalize_question(q1))
    tokens2 = _tokenize(normalize_question(q2))
    if not tokens1 or not tokens2:
        debug(f"{tag} FAIL empty_tokens t1={len(tokens1)} t2={len(tokens2)}")
        return False
    if len(tokens1) != len(tokens2):
        debug(f"{tag} FAIL token_count t1={len(tokens1)} t2={len(tokens2)}")
        return False

    total_dist = 0
    worst_dist = 0
    worst_pair = ("", "")
    # Lengths are known to be equal here; strict=True documents and enforces lock-step pairing.
    for t1, t2 in zip(tokens1, tokens2, strict=True):
        dist = _levenshtein_distance(t1, t2)
        total_dist += dist
        # Record the pair with the largest individual distance, which is what "worst_pair" in the
        # log message promises (not merely the last pair that differed).
        if dist > worst_dist:
            worst_dist = dist
            worst_pair = (t1, t2)

    if total_dist > max_distance:
        debug(f"{tag} FAIL total_dist={total_dist} worst_pair='{worst_pair[0]}'→'{worst_pair[1]}'")
        return False
    debug(f"{tag} MATCH tokens={len(tokens1)} total_dist={total_dist}")
    return True
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _normalize_filters(filters: list[Any]) -> list[FilterParam]:
    """Normalise a filter list into sorted FilterParam objects for intent-key hashing.

    Accepts both FilterParam instances and raw dicts (as returned by the LLM); other entries are
    skipped, as are filters whose left expression is empty. Operators and value types are stripped
    and lower-cased. After stripping, an empty or non-string operator becomes ``'='``, and a
    non-string value type becomes ``'unknown'``. Operators are NOT checked against
    ``VALID_FILTER_OPS`` here. ``param_key`` is blanked so it never influences the key. A dict
    ``filter_group`` that cannot be converted to ``int`` is treated as ``None``.

    Args:
        filters: List of FilterParam instances or raw dicts.

    Returns:
        Normalised FilterParam instances ordered by ``sort_filters``.
    """
    if not filters:
        return []

    def _clean_op(raw: Any) -> str:
        # Accept only strings; LLM dicts may carry ``"op": null``.
        op = raw.strip().lower() if isinstance(raw, str) else ""
        return op or "="

    def _clean_value_type(raw: Any) -> str:
        return raw.strip().lower() if isinstance(raw, str) else "unknown"

    out = []
    for f in filters:
        if isinstance(f, FilterParam):
            left_expr = f.left_expr
            op = _clean_op(f.op)
            vtype = _clean_value_type(f.value_type)
            right_expr = f.right_expr
            bool_op = f.bool_op
            filter_group = f.filter_group
        elif isinstance(f, dict):
            left_raw = f.get("left_expr")
            if isinstance(left_raw, dict):
                left_expr = NormalizedExpr.from_dict(left_raw)
            else:
                # Legacy shape: a bare "column" name instead of a structured expression.
                col = f.get("column", "")
                left_expr = NormalizedExpr.from_column(col.strip().lower() if isinstance(col, str) else "")
            op = _clean_op(f.get("op", "="))
            vtype = _clean_value_type(f.get("value_type", "unknown"))
            right_raw = f.get("right_expr")
            right_expr = NormalizedExpr.from_dict(right_raw) if isinstance(right_raw, dict) and right_raw else None
            bool_op = f.get("bool_op", "AND")
            fg_raw = f.get("filter_group")
            try:
                filter_group = int(fg_raw) if fg_raw is not None else None
            except (TypeError, ValueError):
                filter_group = None
        else:
            continue
        if not left_expr:
            continue
        out.append(
            FilterParam(
                left_expr=left_expr,
                op=op,
                value_type=vtype,
                param_key="",
                right_expr=right_expr,
                bool_op=bool_op,
                filter_group=filter_group,
            )
        )
    return sort_filters(out)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _normalize_having_conditions(conditions: list[Any]) -> list[HavingParam]:
    """Build a sorted list of canonical HavingParam objects from mixed input.

    Each entry may be a HavingParam or a raw dict; anything else is ignored. Entries whose left
    expression is empty or lacks a primary term are dropped. Operators are lower-cased and
    replaced by ``'='`` when they are not strings or not in ``VALID_HAVING_OPS``. ``param_key`` is
    always blanked.

    Args:
        conditions: List of HavingParam instances or raw dicts.

    Returns:
        Normalised HavingParam instances ordered by ``sort_having`` (``[]`` for empty input).
    """
    if not conditions:
        return []

    normalized: list[HavingParam] = []
    for cond in conditions:
        if isinstance(cond, HavingParam):
            lhs = cond.left_expr
            raw_op = cond.op.strip().lower() if cond.op else "="
            raw_vtype = cond.value_type.strip().lower() if cond.value_type else "number"
            rhs = cond.right_expr
            connector = cond.bool_op
            group = cond.filter_group
        elif isinstance(cond, dict):
            left_src = cond.get("left_expr")
            if isinstance(left_src, dict):
                lhs = NormalizedExpr.from_dict(left_src)
            else:
                # Legacy shape: an "aggregation" string instead of a structured expression.
                agg_name = cond.get("aggregation", "")
                lhs = NormalizedExpr.from_column(str(agg_name)) if agg_name else NormalizedExpr()
            raw_op = cond.get("op", "=")
            raw_vtype = cond.get("value_type", "number")
            right_src = cond.get("right_expr")
            rhs = NormalizedExpr.from_dict(right_src) if isinstance(right_src, dict) and right_src else None
            connector = cond.get("bool_op", "AND")
            group_src = cond.get("filter_group")
            group = None if group_src is None else int(group_src)
        else:
            continue

        if not lhs or not lhs.primary_term:
            continue

        op = raw_op.strip().lower() if isinstance(raw_op, str) else "="
        if op not in VALID_HAVING_OPS:
            op = "="

        normalized.append(
            HavingParam(
                left_expr=lhs,
                op=op,
                value_type=str(raw_vtype),
                param_key="",
                right_expr=rhs,
                bool_op=connector,
                filter_group=group,
            )
        )
    return sort_having(normalized)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _cte_condition_sort_key(cond: Any) -> tuple[Any, ...]:
    """Ordering key for CTE filter/HAVING lists: filter group (None as -1), then expression signatures."""
    return (
        cond.filter_group if cond.filter_group is not None else -1,
        cond.left_expr.signature_key,
        cond.op,
        cond.right_expr.signature_key if cond.right_expr else "",
        cond.value_type,
    )


def _normalize_cte_filter(f: FilterParam) -> FilterParam:
    """Copy ``f`` with a validated, lower-cased operator and value type.

    ``bool_op`` and ``filter_group`` are carried over so OR-groups inside a CTE survive
    normalisation; both are part of the CTE signature built by ``_normalize_cte_steps_for_key``.
    """
    op = f.op.strip().lower() if f.op else "="
    if op not in VALID_FILTER_OPS:
        op = "="
    return FilterParam(
        left_expr=f.left_expr,
        op=op,
        value_type=f.value_type.strip().lower() if f.value_type else "unknown",
        param_key=f.param_key or "",
        right_expr=f.right_expr,
        bool_op=f.bool_op,
        filter_group=f.filter_group,
    )


def _normalize_cte_having(h: HavingParam) -> HavingParam:
    """Copy ``h`` with a validated, lower-cased operator and value type, keeping bool_op/filter_group."""
    op = h.op.strip().lower() if h.op else "="
    if op not in VALID_HAVING_OPS:
        op = "="
    return HavingParam(
        left_expr=h.left_expr,
        op=op,
        value_type=h.value_type.strip().lower() if h.value_type else "number",
        param_key=h.param_key or "",
        right_expr=h.right_expr,
        bool_op=h.bool_op,
        filter_group=h.filter_group,
    )


def _infer_cte_output_meta(out_col_lower: str) -> CteOutputColumnMeta:
    """Guess metadata for a CTE output column from its (lower-cased) name.

    A name of the form ``<agg>_<rest>``, where ``<agg>`` is in VALID_AGGREGATION_FUNCTIONS, is
    treated as an aggregation. Its base column is ``<rest>`` with every ``_`` replaced by ``.``.
    Any other name is a passthrough column. COUNT outputs are typed ``integer``; everything else
    is ``unknown``.
    """
    matches = [agg for agg in VALID_AGGREGATION_FUNCTIONS if out_col_lower.startswith(f"{agg}_")]
    # Prefer the longest matching prefix so the result does not depend on iteration order in case
    # one function name is a prefix of another.
    agg_func = max(matches, key=len) if matches else ""
    is_agg = bool(matches)
    base_col = out_col_lower[len(agg_func) + 1 :].replace("_", ".") if is_agg else ""
    return CteOutputColumnMeta(
        source="aggregation" if is_agg else "passthrough",
        base_column=base_col,
        agg_func=agg_func,
        filterable=True,
        aggregatable=True,
        data_type=("integer" if is_agg and agg_func == "count" else "unknown"),
    )


def _normalize_cte_steps(steps: Any, available_ctes: dict[str, list[str]] | None = None) -> list[RuntimeCteStep]:
    """Normalise a CTE steps list into RuntimeCteStep objects with resolved column maps.

    Handles both RuntimeCteStep instances and raw dicts; ``null`` list/dict fields in dicts are
    treated as empty. For each step:

    * grain falls back to ``'row_level'`` when not in VALID_GRAINS;
    * filters and HAVING conditions get validated, lower-cased operators while keeping their
      ``bool_op`` and ``filter_group``, and are sorted by filter group then expression signatures;
    * string select / order-by columns are wrapped in SelectCol / OrderByCol;
    * a column map is inferred from ``qualifier.column`` references whose qualifier is one of the
      step's tables or an already-processed CTE, and explicit ``column_map`` entries override it;
    * output-column metadata is taken from the step when present, otherwise inferred from the
      column name.

    Args:
        steps: List of RuntimeCteStep instances or raw dicts. A non-list yields ``[]``. Entries of
            other types, or without a ``cte_name``, are skipped.
        available_ctes: Dict of already-processed CTE name -> output columns. It is mutated in
            place: each processed step's name is added.

    Returns:
        List of normalised RuntimeCteStep instances in input order.
    """
    if not isinstance(steps, list):
        return []
    if available_ctes is None:
        available_ctes = {}
    out = []
    for s in steps:
        if isinstance(s, RuntimeCteStep):
            cte_name = s.cte_name
            tables = s.tables or []
            grain = s.grain or "row_level"
            select_cols = s.select_cols or []
            group_by_cols = s.group_by_cols or []
            output_columns = s.output_columns or []
            filters_param = s.filters_param or []
            having_param = s.having_param or []
            param_values = s.param_values or {}
            order_by_cols = s.order_by_cols or []
            limit = s.limit
            column_map = s.column_map or {}
            output_column_metadata = s.output_column_metadata or {}
            chosen_join_candidate_id = s.chosen_join_candidate_id or ""
            chosen_join_path_signature = s.chosen_join_path_signature or []
        elif isinstance(s, dict):
            cte_name = s.get("cte_name", "")
            # ``or`` fallbacks: LLM-produced dicts may carry explicit nulls.
            tables = s.get("tables") or []
            grain = s.get("grain") or "row_level"
            select_cols = [
                SelectCol.from_dict(sc)
                if isinstance(sc, dict)
                else (SelectCol(expr=NormalizedExpr.from_column(sc)) if isinstance(sc, str) else sc)
                for sc in (s.get("select_cols") or [])
            ]
            group_by_cols = [
                NormalizedExpr.from_dict(g)
                if isinstance(g, dict)
                else (NormalizedExpr.from_column(g) if isinstance(g, str) else g)
                for g in (s.get("group_by_cols") or [])
            ]
            output_columns = s.get("output_columns") or []
            filters_param = [
                FilterParam.from_dict(f) if isinstance(f, dict) else f for f in (s.get("filters_param") or [])
            ]
            having_param = [
                HavingParam.from_dict(h) if isinstance(h, dict) else h for h in (s.get("having_param") or [])
            ]
            param_values = s.get("param_values") or {}
            order_by_cols = [
                OrderByCol.from_dict(o)
                if isinstance(o, dict)
                else (OrderByCol(expr=NormalizedExpr.from_column(o)) if isinstance(o, str) else o)
                for o in (s.get("order_by_cols") or [])
            ]
            limit = s.get("limit")
            column_map = s.get("column_map") or {}
            output_column_metadata = {
                k: CteOutputColumnMeta.from_dict(v) if isinstance(v, dict) else v
                for k, v in (s.get("output_column_metadata") or {}).items()
            }
            chosen_join_candidate_id = s.get("chosen_join_candidate_id") or ""
            chosen_join_path_signature = s.get("chosen_join_path_signature") or []
        else:
            continue

        if not cte_name:
            continue
        if grain not in VALID_GRAINS:
            grain = "row_level"

        normalized_fp = [_normalize_cte_filter(f) for f in filters_param if isinstance(f, FilterParam)]
        normalized_hp = [_normalize_cte_having(h) for h in having_param if isinstance(h, HavingParam)]

        # Qualifiers this step may reference: its own tables plus every CTE processed before it.
        visible_refs = {str(t).lower() for t in tables} | {str(c).lower() for c in available_ctes}
        referenced = [g.primary_column for g in group_by_cols] + [f.left_expr.primary_column for f in normalized_fp]
        for sc in select_cols:
            if isinstance(sc, SelectCol):
                referenced.append(sc.expr.primary_column)
            elif isinstance(sc, str):
                referenced.append(sc)

        cte_column_map: dict[str, Any] = {}
        for col in referenced:
            if not isinstance(col, str) or "." not in col:
                continue
            table_ref, col_name = col.split(".", 1)
            table_ref, col_name = table_ref.lower(), col_name.lower()
            if table_ref in visible_refs:
                cte_column_map[col_name] = table_ref
        if column_map:
            # Explicit mappings supplied with the step win over inferred ones.
            cte_column_map.update(column_map)

        ocm = {}
        for out_col in output_columns:
            out_col_lower = str(out_col).lower()
            if out_col_lower in output_column_metadata:
                ocm[out_col_lower] = output_column_metadata[out_col_lower]
            else:
                ocm[out_col_lower] = _infer_cte_output_meta(out_col_lower)

        normalized_select_cols = (
            select_cols
            if select_cols and isinstance(select_cols[0], SelectCol)
            else ([SelectCol(expr=NormalizedExpr.from_column(c)) for c in select_cols] if select_cols else [])
        )
        normalized_order_by = (
            order_by_cols
            if order_by_cols and isinstance(order_by_cols[0], OrderByCol)
            else [(OrderByCol(expr=NormalizedExpr.from_column(c)) if isinstance(c, str) else c) for c in order_by_cols]
        )

        out.append(
            RuntimeCteStep(
                cte_name=str(cte_name),
                tables=sorted({str(t) for t in tables}),
                grain=grain,
                select_cols=normalized_select_cols,
                group_by_cols=group_by_cols,
                output_columns=[str(c) for c in output_columns],
                filters_param=sorted(normalized_fp, key=_cte_condition_sort_key),
                having_param=sorted(normalized_hp, key=_cte_condition_sort_key),
                param_values=param_values,
                order_by_cols=normalized_order_by,
                limit=limit,
                column_map=cte_column_map,
                output_column_metadata=ocm,
                chosen_join_candidate_id=chosen_join_candidate_id,
                chosen_join_path_signature=chosen_join_path_signature,
            )
        )
        available_ctes[cte_name] = output_columns
    return out
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def _normalize_cte_steps_for_key(steps: list[RuntimeCteStep]) -> list[dict[str, Any]]:
    """Reduce normalised CTE steps to order-independent dicts for intent-key hashing.

    Every list-valued field is sorted. Select and order-by entries are reduced to their
    ``signature_key``; raw strings are kept as-is and other types are dropped. Filter and HAVING
    conditions are rendered as ``"<signature_key>|<bool_op>|<filter_group>"``, ordered by filter
    group (``None`` mapped to -1), left expression, operator, right expression and value type.

    Args:
        steps: List of normalised RuntimeCteStep instances.

    Returns:
        One dict per step, in input order, suitable for stable JSON serialisation.
    """

    def condition_order(cond: Any) -> tuple[Any, ...]:
        return (
            cond.filter_group if cond.filter_group is not None else -1,
            cond.left_expr.signature_key,
            cond.op,
            cond.right_expr.signature_key if cond.right_expr else "",
            cond.value_type,
        )

    def render_conditions(conds: Any) -> list[str]:
        ordered = sorted(conds or [], key=condition_order)
        return [f"{c.signature_key}|{c.bool_op}|{c.filter_group}" for c in ordered]

    def signatures(items: Any, expected_type: type) -> list[str]:
        keys = []
        for item in items or []:
            if isinstance(item, expected_type):
                keys.append(item.signature_key)
            elif isinstance(item, str):
                keys.append(item)
        return sorted(keys)

    return [
        {
            "cte_name": step.cte_name,
            "tables": sorted(step.tables or []),
            "select_cols": signatures(step.select_cols, SelectCol),
            "group_by_cols": sorted(g.signature_key for g in (step.group_by_cols or [])),
            "output_columns": sorted(step.output_columns or []),
            "filters_param": render_conditions(step.filters_param),
            "having_param": render_conditions(step.having_param),
            "order_by_cols": signatures(step.order_by_cols, OrderByCol),
        }
        for step in steps
    ]
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def flatten_param_values(intent: RuntimeIntent) -> dict[str, Any]:
    """Collapse the parameter values of every CTE step and the main query into one dict.

    CTE steps are merged in order and the main query last. On a key collision, a later CTE step
    overrides an earlier one and the main query overrides every CTE step. ``None`` value maps are
    treated as empty.

    Args:
        intent: The RuntimeIntent whose parameter values are merged.

    Returns:
        A new dict holding all parameter values.
    """
    layers = [step.param_values for step in (intent.cte_steps or [])]
    layers.append(intent.param_values)
    merged: dict[str, Any] = {}
    for layer in layers:
        merged.update(layer or {})
    return merged
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def intent_key(intent: RuntimeIntent) -> str:
    """Fingerprint a RuntimeIntent as a stable SHA-256 hex digest.

    The hashed payload contains sorted tables, select-column signatures, normalised filters,
    GROUP BY signatures, ORDER BY signatures, normalised HAVING conditions and normalised CTE
    steps. Semantically equivalent intents therefore produce the same key regardless of the order
    in which those items were listed.

    Args:
        intent: The RuntimeIntent to fingerprint.

    Returns:
        Hex-encoded SHA-256 hash string.
    """
    # NOTE(review): grain and expected_rows are validated and logged below but are NOT part of
    # the hashed payload, so intents that differ only in grain share a key. Confirm this is
    # intended before relying on it.
    grain = intent.grain or "row_level"
    if grain not in VALID_GRAINS:
        debug(f"[utils.intent_key] invalid_grain: '{grain}' defaulting to 'row_level'")
        grain = "row_level"

    expected_rows = intent.expected_rows or "many"
    if expected_rows not in VALID_EXPECTED_ROWS:
        expected_rows = {"scalar": "one", "grouped": "few"}.get(grain, "many")
        debug(f"[utils.intent_key] inferred_expected_rows: grain={grain} -> expected_rows={expected_rows}")

    filters = _normalize_filters(intent.filters_param or [])
    havings = _normalize_having_conditions(intent.having_param or [])
    cte_payload = _normalize_cte_steps_for_key(_normalize_cte_steps(intent.cte_steps or []))
    select_sigs = sorted(col.signature_key for col in (intent.select_cols or []))
    order_sigs = sorted(o.signature_key for o in (intent.order_by_cols or []))

    payload = {
        "tables": sorted(intent.tables or []),
        "select_cols": select_sigs,
        "filters": [fp.to_dict() for fp in filters],
        "group_by_cols": sorted(g.signature_key for g in (intent.group_by_cols or [])),
        "order_by_cols": order_sigs,
        "having_conditions": [hp.to_dict() for hp in havings],
        "cte_steps": cte_payload,
    }
    digest = sha256(stable_json(payload))
    debug(f"[utils.intent_key] computed: tables={payload['tables']} key={digest[:16]}...")
    return digest
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def _tokenize(q: str) -> list[str]:
    """Split a normalised question into sorted word tokens, excluding stopwords.

    Tokens are maximal runs of ``[a-z0-9_]``. Uppercase letters are not matched, so the input
    should already be lower-cased; ``exact_question_match`` passes the output of
    ``normalize_question``. Duplicates are KEPT: the result is sorted but not de-duplicated,
    which matters to callers that compare token counts.

    Args:
        q: The normalised question string.

    Returns:
        Sorted list of tokens not present in ``PolicyConfig.STOPWORDS``.
    """
    # The regex cannot yield empty matches (``+`` quantifier), so no emptiness check is needed.
    return sorted(token for token in re.findall(r"[a-z0-9_]+", q) if token not in PolicyConfig.STOPWORDS)
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def extract_tables_from_sql(sql: str, known_tables: list[str]) -> list[str]:
    """Return the known base tables that appear in ``sql``, ignoring CTE names.

    Matching is case-insensitive and word-bounded. When the SQL contains a ``WITH`` keyword, any
    identifier directly followed by ``AS (`` is treated as a CTE name and excluded.

    Args:
        sql: The SQL string to scan.
        known_tables: List of base table names to match against.

    Returns:
        Sorted, de-duplicated table names (in their ``known_tables`` spelling) referenced by the
        SQL and not shadowed by a CTE name.
    """
    lowered = sql.lower()
    cte_aliases = (
        {match.group(1).lower() for match in re.finditer(r"\b(\w+)\s+AS\s*\(", lowered, re.IGNORECASE)}
        if re.search(r"\bwith\b", lowered)
        else set()
    )
    matched = {
        name
        for name in known_tables
        if name.lower() not in cte_aliases and re.search(rf"\b{re.escape(name.lower())}\b", lowered)
    }
    return sorted(matched)
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
def _describe_operation(select_terms: list[str]) -> str:
    """Name the primary operation of a query from its SELECT expressions.

    Args:
        select_terms: SELECT column/expression strings (e.g. ``["COUNT(t.id)"]``).

    Returns:
        The lower-cased function name captured by ``AGG_PATTERN`` for the first matching term
        (e.g. ``'count'``), or ``'list'`` when no term matches.
    """
    first_hit = next(filter(None, map(AGG_PATTERN.match, select_terms)), None)
    return first_hit.group(1).lower() if first_hit is not None else "list"
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def _pick_question_style(select_terms: list[str], has_grouping: bool) -> str:
    """Choose a random opening phrase for a generated question.

    Grouped queries always draw from ``QUESTION_STARTS_GROUP``. Otherwise the first SELECT term
    matched by ``AGG_PATTERN`` decides the pool. COUNT, SUM, AVG, MIN and MAX each have dedicated
    phrasings, any other aggregate draws from ``QUESTION_STARTS_AGG``, and a query with no
    aggregate draws from ``QUESTION_STARTS_LIST``.

    Args:
        select_terms: SELECT expression strings used to detect aggregation.
        has_grouping: True if the query has GROUP BY columns.

    Returns:
        A randomly selected opening phrase.
    """
    detected = [m.group(1).lower() for m in map(AGG_PATTERN.match, select_terms) if m]

    if has_grouping:
        return random.choice(QUESTION_STARTS_GROUP)
    if not detected:
        return random.choice(QUESTION_STARTS_LIST)

    lead = detected[0]
    if lead in ("min", "max"):
        # "min"/"max" + "imum" spells "minimum"/"maximum".
        options = [f"What is the {lead}imum", f"Find the {lead}imum", f"Get the {lead}"]
    else:
        options = {
            "count": ["How many", "Count", "What is the number of"],
            "sum": ["What is the total", "Find the sum of", "Calculate the total"],
            "avg": ["What is the average", "Calculate the average", "Find the mean"],
        }.get(lead, QUESTION_STARTS_AGG)
    return random.choice(options)
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
def generate_question(
    tables: list[str],
    select_terms: list[str],
    filter_descriptions: list[dict[str, str]],
    group_by_terms: list[str],
    having_descriptions: list[dict[str, str]],
    schema: SchemaGraph,
) -> str | None:
    """Ask the LLM to phrase a structured query intent as a natural-language question.

    The prompt bundles four things:

    - the intent (tables, operation, SELECT terms, filters, grouping, HAVING);
    - per-table and per-column descriptions drawn from ``schema``;
    - the semantic roles of filtered and grouped ``table.column`` references;
    - a randomly chosen opening phrase that the answer must start with, to
      diversify phrasing.

    Callers must resolve all param placeholders into concrete values inside
    the condition strings of ``filter_descriptions`` and
    ``having_descriptions`` beforehand. Errors raised while assembling the
    prompt are not caught; only the LLM call and response handling are
    guarded.

    Args:
        tables: Table names involved in the query.
        select_terms: SELECT expression strings.
        filter_descriptions: Dicts with ``'column'`` and ``'condition'`` keys,
            concrete values already substituted.
        group_by_terms: GROUP BY column expression strings.
        having_descriptions: HAVING condition dicts, concrete values already
            substituted.
        schema: SchemaGraph used to look up table/column descriptions and roles.

    Returns:
        The stripped question. Returns None if the LLM call or response
        handling raises, if the response lacks a non-empty string
        ``"question"``, or if the question does not begin with the required
        opening phrase.
    """
    # Human-readable descriptions for every table the schema knows about.
    semantics = {}
    for table_name in tables:
        table_ir = schema.tables.get(table_name)
        if not table_ir:
            continue
        semantics[table_name] = {
            "description": table_ir.description or f"{table_name} records",
            "columns": {
                column_name: (getattr(column_meta, "description", None) or column_name)
                for column_name, column_meta in table_ir.columns.items()
            },
        }

    # Roles are looked up only for "table.column" references used in filters / GROUP BY.
    referenced_columns = set()
    for description in filter_descriptions:
        referenced_columns.add(description.get("column", ""))
    for term in group_by_terms:
        referenced_columns.add(term)

    roles = {}
    for qualified_name in referenced_columns:
        pieces = qualified_name.split(".")
        if len(pieces) != 2 or pieces[0] not in schema.tables:
            continue
        column_meta = schema.tables[pieces[0]].columns.get(pieces[1])
        if column_meta and column_meta.role:
            roles[qualified_name] = column_meta.role

    intent_structure = {
        "tables": tables,
        "operation": _describe_operation(select_terms),
        "columns": select_terms,
        "filters": filter_descriptions or None,
        "grouping": group_by_terms or None,
        "having": having_descriptions or None,
    }

    selected_style = _pick_question_style(select_terms, bool(group_by_terms))

    system = (
        "You are a natural language question generator for database queries. Convert structured query intent to conversational questions.\n\n"
        "Output Requirements:\n"
        '- Output ONLY valid JSON with single "question" field\n'
        "- Do NOT include markdown, explanations, or commentary\n"
        "- Identical inputs must produce identical outputs\n\n"
        "Generation Rules:\n"
        "- Question MUST reflect EXACTLY the intent structure (same tables, columns, filters, aggregation)\n"
        "- Do NOT add columns, tables, or filters not in the intent\n"
        "- Do NOT change aggregation type or omit filter values\n"
        "- Use natural, conversational language - avoid SQL jargon and raw table names\n"
        "- Use semantic context to refer to business concepts naturally\n"
        "- For multi-table queries, describe relationships naturally\n"
        "- Inject filter values naturally into the question\n"
        "- ONE sentence only\n"
        "- Vary phrasing naturally within the required start constraint\n"
    )

    user_prompt = {
        "task": "Generate a natural language question for this query intent",
        "intent": intent_structure,
        "semantic_context": semantics,
        "column_roles": roles,
        "phrasing_constraint": {
            "required_start": selected_style,
            "description": f"Question MUST start with exactly: '{selected_style}'",
            "strict_mode": not filter_descriptions and not group_by_terms,
            "phrasing_flexibility": "Vary word choice and sentence structure while keeping the required start",
        },
        "examples": [
            {
                "intent": {
                    "tables": ["table1"],
                    "operation": "count",
                    "columns": ["COUNT(table1.column1)"],
                    "filters": [{"column": "table1.column2", "condition": "= 1"}],
                },
                "required_start": "How many",
                "question": "How many active records do we have?",
            },
            {
                "intent": {
                    "tables": ["table1"],
                    "operation": "list",
                    "columns": ["table1.column1", "table1.column2"],
                    "filters": [{"column": "table1.column2", "condition": "= 'value1'"}],
                },
                "required_start": "What are",
                "question": "What are the records with column2 equal to value1?",
            },
            {
                "intent": {
                    "tables": ["table1", "table2"],
                    "operation": "sum",
                    "columns": ["SUM(table2.column1)", "table1.column1"],
                    "grouping": ["table2.column2"],
                },
                "required_start": "Show me",
                "question": "Show me the total column1 for each group.",
            },
        ],
        "output_format": {"question": "Your natural language question here"},
    }

    try:
        response = llm_json(system, stable_json(user_prompt), retries=1, task="default")
        question = response.get("question")
        if not question or not isinstance(question, str):
            debug("[utils.generate_question] missing_question_field")
            return None

        question = question.strip()
        # Only the literal text before any "{" placeholder is enforced as the prefix.
        template_start = selected_style.split("{")[0].strip()
        if template_start and not question.startswith(template_start):
            debug(
                f"[utils.generate_question] phrasing_violation: expected_start={template_start}, got={question[:30]}"
            )
            return None

        debug(f"[utils.generate_question] generated: {question[:50]}")
        return question
    except Exception as e:
        debug(f"[utils.generate_question] failed: {e}")
        return None
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
_QUESTION_FROM_SQL_SYSTEM = (
|
|
873
|
+
"You are given a SQL query and a schema description. "
|
|
874
|
+
"Your job is to decide whether the query represents a realistic, "
|
|
875
|
+
"meaningful business question and, if so, generate a natural language "
|
|
876
|
+
"question that a human analyst would ask to get this query's result.\n\n"
|
|
877
|
+
"Rules:\n"
|
|
878
|
+
"- If the query is unrealistic, nonsensical, or produces meaningless "
|
|
879
|
+
"results, set is_realistic to false and explain why in drop_reason.\n"
|
|
880
|
+
"- If realistic, write ONE clear, conversational question that a "
|
|
881
|
+
"non-technical user would ask. Do NOT use SQL jargon or raw column "
|
|
882
|
+
"names — use natural business language.\n"
|
|
883
|
+
"- Output ONLY valid JSON with exactly three fields: "
|
|
884
|
+
'"question", "is_realistic" (boolean), "drop_reason" (string or null).\n'
|
|
885
|
+
)
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
def generate_question_from_sql(
    sql: str,
    schema: SchemaGraph,
    tables: list[str],
) -> dict[str, Any] | None:
    """Generate an NL question from a substituted SQL query with a realism gate.

    Sends the SQL and a compact schema summary to the LLM using the
    ``"default"`` task profile. The summary lists each referenced table's
    columns with their data type and, when set, their semantic role. The
    model judges whether the query is a realistic business question and, if
    so, phrases it.

    Args:
        sql: Fully substituted SQL string with literal values.
        schema: Schema graph used to describe the referenced tables' columns.
        tables: Tables referenced by the query; names absent from ``schema``
            are skipped.

    Returns:
        A dict with ``question``, ``is_realistic`` and ``drop_reason`` keys.
        When the model judges the query unrealistic, ``question`` is ``""``,
        ``is_realistic`` is False and ``drop_reason`` is the model's reason
        (or ``"unrealistic"`` if none was given). A realistic result carries
        the stripped question and ``drop_reason`` None.

        Returns None when the LLM call or response handling raises, or when
        a realistic verdict comes with a missing, non-string or blank
        question. Errors raised while building the prompt are not caught.
    """
    col_descriptions: list[str] = []
    for table in tables:
        table_meta = schema.tables.get(table)
        if not table_meta:
            continue
        cols = []
        for col_name, col_meta in table_meta.columns.items():
            desc = f"{col_name} ({col_meta.data_type or 'unknown'})"
            if col_meta.role:
                desc += f" [{col_meta.role}]"
            cols.append(desc)
        col_descriptions.append(f"TABLE {table}: {', '.join(cols)}")
    schema_context = "\n".join(col_descriptions)

    user_prompt = stable_json({
        "sql": sql,
        "schema": schema_context,
        "output_format": {
            "question": "natural language question or empty string",
            "is_realistic": True,
            "drop_reason": None,
        },
    })

    try:
        response = llm_json(
            _QUESTION_FROM_SQL_SYSTEM, user_prompt,
            retries=1, task="default",
        )
        # A missing verdict is treated as "not realistic" (conservative default).
        is_realistic = response.get("is_realistic", False)
        question = response.get("question", "")
        drop_reason = response.get("drop_reason")

        # Tolerate non-boolean verdicts such as "true", " yes ", or 1.
        if not isinstance(is_realistic, bool):
            is_realistic = str(is_realistic).strip().lower() in ("true", "1", "yes")

        if not is_realistic:
            debug(
                f"[utils.generate_question_from_sql] dropped: "
                f"reason={drop_reason}"
            )
            return {
                "question": "",
                "is_realistic": False,
                "drop_reason": drop_reason or "unrealistic",
            }

        # Validate after stripping. Otherwise a whitespace-only answer would
        # pass the emptiness check and be returned as an empty "realistic"
        # question.
        if not isinstance(question, str) or not question.strip():
            debug("[utils.generate_question_from_sql] empty question")
            return None
        question = question.strip()

        debug(
            f"[utils.generate_question_from_sql] generated: "
            f"{question[:60]}"
        )
        return {
            "question": question,
            "is_realistic": True,
            "drop_reason": None,
        }
    except Exception as e:
        # Best-effort: any LLM or response-shape failure (e.g. a non-dict
        # response without .get) is logged and reported as None.
        debug(f"[utils.generate_question_from_sql] failed: {e}")
        return None
|