aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
text2sql/utils.py ADDED
@@ -0,0 +1,973 @@
1
+ """Shared utility functions for intent normalisation, question validation, fingerprint generation, and natural language question generation.
2
+
3
+ Provides SHA-256 intent fingerprinting over canonical stable JSON, fuzzy question matching with Levenshtein distance, and filter/HAVING/CTE normalisation helpers used by both the main pipeline and the query simulator. Also exposes the ``generate_question`` LLM helper used by QSim and Simulator to produce human-readable questions from structured query intent.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import random
9
+ import re
10
+ from typing import Any
11
+
12
+ from .config import (
13
+ AGG_PATTERN,
14
+ QUESTION_STARTS_AGG,
15
+ QUESTION_STARTS_GROUP,
16
+ QUESTION_STARTS_LIST,
17
+ VALID_AGGREGATION_FUNCTIONS,
18
+ VALID_EXPECTED_ROWS,
19
+ VALID_FILTER_OPS,
20
+ VALID_GRAINS,
21
+ VALID_HAVING_OPS,
22
+ PolicyConfig,
23
+ )
24
+ from .contracts_base import CteOutputColumnMeta, SchemaGraph, SQLShape
25
+ from .contracts_core import (
26
+ FilterParam,
27
+ HavingParam,
28
+ NormalizedExpr,
29
+ OrderByCol,
30
+ RuntimeCteStep,
31
+ RuntimeIntent,
32
+ SelectCol,
33
+ )
34
+ from .core_utils import debug, llm_json, normalize_question, sha256, stable_json
35
+ from .intent_resolve import sort_filters, sort_having
36
+
37
# System prompt used by ``validate_question``. It asks the LLM to classify the
# input as a database query request or not, label the operation as
# allowed/restricted/unspecified, and return a typo-corrected copy of the input,
# all as a three-field JSON object.
_QUESTION_VALIDATION_SYSTEM = (
    "You decide if user input is a database query request or not.\n\n"
    "A database query request asks to retrieve, count, sum, filter, sort, or analyze data.\n"
    "When in doubt, treat it as a valid query request.\n\n"
    "Only mark as invalid if it is clearly:\n"
    '- Chitchat (e.g. "hello", "thanks", "who are you")\n'
    '- A help or instruction request (e.g. "how do I query", "explain SQL")\n'
    '- General knowledge unrelated to data (e.g. "what is the capital of France")\n\n'
    "Respond with JSON containing exactly three fields:\n"
    '- "valid_database_question": "yes" or "no"\n'
    '- "query_type": "allowed" if read/SELECT operation, "restricted" if write/destructive operation (UPDATE/DELETE/INSERT/DROP/ALTER/CREATE/TRUNCATE), "unspecified" if unclear\n'
    '- "corrected": the input with spelling typos fixed only. Do NOT remove, reorder, or rephrase any words. If no corrections needed, return the original text unchanged.\n\n'
    "Respond ONLY with valid JSON, no explanation."
)
51
+
52
+
53
def validate_question(question: str) -> tuple[bool, str, str]:
    """Validate a question and return typo-corrected text via LLM.

    The LLM reply is treated as untrusted: fields that are missing or are not
    strings are handled as "not yes" / "unspecified" instead of raising.

    Args:
        question: The raw user input string.

    Returns:
        Tuple of ``(is_valid, query_type, corrected_text)``. ``is_valid`` is
        True only for valid, non-restricted database questions. ``query_type``
        is ``'allowed'`` when valid, ``'restricted'`` when the LLM flagged a
        write/destructive operation, and ``'invalid'`` when the LLM call
        produced nothing usable or the input is not a database question.
        ``corrected_text`` is the typo-corrected input, falling back to the
        original ``question`` when no usable correction was returned.
    """
    result = llm_json(_QUESTION_VALIDATION_SYSTEM, question, task="default")
    if not result or not isinstance(result, dict):
        return False, "invalid", question

    def _flag(value: Any) -> str:
        # LLM output may contain null / booleans / numbers; only strings are meaningful here.
        return value.strip().lower() if isinstance(value, str) else ""

    valid = _flag(result.get("valid_database_question")) == "yes"
    query_type = _flag(result.get("query_type")) or "unspecified"

    corrected_raw = result.get("corrected")
    corrected = corrected_raw if isinstance(corrected_raw, str) and corrected_raw else question

    if not valid:
        return False, "invalid", corrected
    if query_type == "restricted":
        return False, "restricted", corrected
    # Both "allowed" and "unspecified" are reported to callers as "allowed".
    return True, "allowed", corrected
75
+
76
+
77
def sql_shape(sql: str, intent: RuntimeIntent) -> SQLShape:
    """Summarise the structure of a SQL string and its intent as an SQLShape.

    Joins, GROUP BY, aggregation calls, and SELECT DISTINCT are detected with
    regular expressions over the lowercased SQL text. Filter, HAVING, and CTE
    counts are taken from the intent (main query plus every CTE step).

    Args:
        sql: The generated SQL string.
        intent: The RuntimeIntent providing filter, HAVING, and CTE step lists.

    Returns:
        SQLShape with join count, GROUP BY flag, aggregation flag, CTE count,
        filter count, HAVING count, and DISTINCT flag.
    """
    lowered = sql.lower()
    cte_steps = intent.cte_steps or []

    # Main-query counts plus the counts contributed by each CTE step.
    filter_total = len(intent.filters_param or []) + sum(len(step.filters_param or []) for step in cte_steps)
    having_total = len(intent.having_param or []) + sum(len(step.having_param or []) for step in cte_steps)

    join_hits = re.findall(r"\bjoin\b", lowered)
    group_by_hit = re.search(r"\bgroup\s+by\b", lowered)
    agg_hit = re.search(r"\b(count|sum|avg|min|max)\s*\(", lowered)
    distinct_hit = re.search(r"\bselect\s+distinct\b", lowered)

    return SQLShape(
        num_joins=len(join_hits),
        has_group_by=group_by_hit is not None,
        has_agg=agg_hit is not None,
        num_cte=len(cte_steps),
        num_filters=filter_total,
        num_having=having_total,
        has_distinct=distinct_hit is not None,
    )
104
+
105
+
106
def _levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein edit distance between two strings.

    Uses the two-row dynamic-programming formulation, iterating over the
    longer string in the outer loop so the rows are sized by the shorter one.

    Args:
        s1: The first string.
        s2: The second string.

    Returns:
        Number of single-character insertions, deletions, or substitutions
        needed to turn one string into the other.
    """
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    if not shorter:
        return len(longer)

    prev = list(range(len(shorter) + 1))
    for row_no, ch_long in enumerate(longer, start=1):
        curr = [row_no]
        for col_no, ch_short in enumerate(shorter, start=1):
            cost_insert = prev[col_no] + 1
            cost_delete = curr[col_no - 1] + 1
            cost_substitute = prev[col_no - 1] + (ch_long != ch_short)
            curr.append(min(cost_insert, cost_delete, cost_substitute))
        prev = curr
    return prev[-1]
132
+
133
+
134
def exact_question_match(
    q1: str,
    q2: str,
    max_distance: int = PolicyConfig.FUZZY_MATCH_MAX_DISTANCE,
    label: str = "",
) -> bool:
    """Check if two questions are identical after normalisation, within a fuzzy tolerance.

    Both questions are normalised and tokenised with ``_tokenize`` (which
    returns tokens in sorted order), so tokens are paired by sorted position
    rather than by their position in the sentence. The questions match when
    they have the same non-zero number of tokens and the summed per-pair
    Levenshtein distance does not exceed ``max_distance``.

    Args:
        q1: The first question string.
        q2: The second question string.
        max_distance: Maximum total Levenshtein distance across all token pairs.
        label: Optional debug label appended to log messages.

    Returns:
        True if the questions match within the specified fuzzy tolerance.
    """
    tag = f"[utils.exact_question_match] {label}" if label else "[utils.exact_question_match]"
    tokens1 = _tokenize(normalize_question(q1))
    tokens2 = _tokenize(normalize_question(q2))
    if not tokens1 or not tokens2:
        debug(f"{tag} FAIL empty_tokens t1={len(tokens1)} t2={len(tokens2)}")
        return False
    if len(tokens1) != len(tokens2):
        debug(f"{tag} FAIL token_count t1={len(tokens1)} t2={len(tokens2)}")
        return False

    total_dist = 0
    worst_dist = 0
    worst_pair = ("", "")
    # Lengths were checked above, so strict pairing documents the invariant.
    for t1, t2 in zip(tokens1, tokens2, strict=True):
        dist = _levenshtein_distance(t1, t2)
        total_dist += dist
        # Track the pair with the largest distance so the log names the real
        # outlier rather than whichever differing pair happened to come last.
        if dist > worst_dist:
            worst_dist = dist
            worst_pair = (t1, t2)

    if total_dist > max_distance:
        debug(f"{tag} FAIL total_dist={total_dist} worst_pair='{worst_pair[0]}'→'{worst_pair[1]}'")
        return False
    debug(f"{tag} MATCH tokens={len(tokens1)} total_dist={total_dist}")
    return True
176
+
177
+
178
def _normalize_filters(filters: list[Any]) -> list[FilterParam]:
    """Normalise a filter list into sorted FilterParam objects for intent-key hashing.

    Accepts both FilterParam instances and raw dicts (as returned by the LLM).
    Operators and value types are stripped and lowercased; a missing operator
    defaults to ``'='`` and a missing or non-string value type defaults to
    ``'unknown'``. Operators are not checked against an allow-list here.
    Entries that are neither FilterParam nor dict, have a falsy left
    expression, or carry a non-string operator are skipped. ``param_key`` is
    always blanked on the returned objects.

    Args:
        filters: List of FilterParam instances or raw dicts.

    Returns:
        Normalised FilterParam instances ordered by ``sort_filters``.
    """
    if not filters:
        return []
    out: list[FilterParam] = []
    for f in filters:
        if isinstance(f, FilterParam):
            left_expr = f.left_expr
            op = f.op.strip().lower() if f.op else "="
            vtype = f.value_type.strip().lower() if isinstance(f.value_type, str) else "unknown"
            right_expr = f.right_expr
            bool_op = f.bool_op
            filter_group = f.filter_group
        elif isinstance(f, dict):
            left_raw = f.get("left_expr")
            if isinstance(left_raw, dict):
                left_expr = NormalizedExpr.from_dict(left_raw)
            else:
                col = f.get("column", "")
                left_expr = NormalizedExpr.from_column(col.strip().lower() if isinstance(col, str) else "")
            # LLM dicts may carry null or non-string values; normalise only real
            # strings so the isinstance guard below can reject the rest instead
            # of this branch raising AttributeError.
            op_raw = f.get("op", "=")
            if op_raw is None:
                op_raw = "="
            op = op_raw.strip().lower() if isinstance(op_raw, str) else op_raw
            vtype_raw = f.get("value_type", "unknown")
            vtype = vtype_raw.strip().lower() if isinstance(vtype_raw, str) else "unknown"
            right_raw = f.get("right_expr")
            right_expr = NormalizedExpr.from_dict(right_raw) if isinstance(right_raw, dict) and right_raw else None
            bool_op = f.get("bool_op", "AND")
            fg_raw = f.get("filter_group")
            filter_group = int(fg_raw) if fg_raw is not None else None
        else:
            continue
        if not left_expr or not isinstance(op, str):
            continue
        out.append(
            FilterParam(
                left_expr=left_expr,
                op=op,
                value_type=vtype,
                param_key="",
                right_expr=right_expr,
                bool_op=bool_op,
                filter_group=filter_group,
            )
        )
    return sort_filters(out)
231
+
232
+
233
def _normalize_having_conditions(conditions: list[Any]) -> list[HavingParam]:
    """Normalise a HAVING conditions list into sorted HavingParam objects.

    Accepts both HavingParam instances and raw dicts. Operators are stripped,
    lowercased, and replaced with ``'='`` when they are not in
    ``VALID_HAVING_OPS``. Value types are stripped and lowercased in both input
    forms, with ``'number'`` used when the value is missing or empty, so a dict
    and an equivalent HavingParam normalise to the same result. Entries without
    a left expression or without a ``primary_term`` are skipped. ``param_key``
    is always blanked on the returned objects.

    Args:
        conditions: List of HavingParam instances or raw dicts.

    Returns:
        Normalised HavingParam instances ordered by ``sort_having``.
    """
    if not conditions:
        return []
    out: list[HavingParam] = []
    for c in conditions:
        if isinstance(c, HavingParam):
            left_expr = c.left_expr
            op = c.op.strip().lower() if c.op else "="
            value_type = c.value_type.strip().lower() if c.value_type else "number"
            right_expr = c.right_expr
            bool_op = c.bool_op
            filter_group = c.filter_group
        elif isinstance(c, dict):
            left_raw = c.get("left_expr")
            if isinstance(left_raw, dict):
                left_expr = NormalizedExpr.from_dict(left_raw)
            else:
                agg = c.get("aggregation", "")
                left_expr = NormalizedExpr.from_column(str(agg)) if agg else NormalizedExpr()
            op = c.get("op", "=")
            # Mirror the HavingParam branch: strip/lowercase and fall back to
            # "number", so null does not become the literal string "None".
            vt_raw = c.get("value_type", "number")
            if vt_raw is None:
                vt_clean = ""
            elif isinstance(vt_raw, str):
                vt_clean = vt_raw.strip().lower()
            else:
                vt_clean = str(vt_raw).strip().lower()
            value_type = vt_clean or "number"
            right_raw = c.get("right_expr")
            right_expr = NormalizedExpr.from_dict(right_raw) if isinstance(right_raw, dict) and right_raw else None
            bool_op = c.get("bool_op", "AND")
            fg_raw = c.get("filter_group")
            filter_group = int(fg_raw) if fg_raw is not None else None
        else:
            continue
        if not left_expr or not left_expr.primary_term:
            continue
        op_norm = op.strip().lower() if isinstance(op, str) else "="
        if op_norm not in VALID_HAVING_OPS:
            op_norm = "="
        out.append(
            HavingParam(
                left_expr=left_expr,
                op=op_norm,
                value_type=str(value_type),
                param_key="",
                right_expr=right_expr,
                bool_op=bool_op,
                filter_group=filter_group,
            )
        )
    return sort_having(out)
289
+
290
+
291
def _normalize_cte_steps(steps: Any, available_ctes: dict[str, list[str]] | None = None) -> list[RuntimeCteStep]:
    """Normalise a CTE steps list into RuntimeCteStep objects with resolved column maps.

    Handles both RuntimeCteStep instances and raw dicts. For each step it:

    * validates the grain (falling back to ``'row_level'``),
    * rebuilds filter / HAVING params with lowercased, validated operators and
      value types while keeping ``bool_op`` and ``filter_group``,
    * derives a ``column -> table/CTE`` map from dotted column references whose
      prefix is one of the step's tables or an already-processed CTE, then
      overlays any explicitly supplied ``column_map``,
    * fills output-column metadata, inferring aggregation columns from an
      ``<agg>_`` name prefix when no metadata was supplied.

    Steps without a ``cte_name`` and items that are neither RuntimeCteStep nor
    dict are skipped.

    Args:
        steps: List of RuntimeCteStep instances or raw dicts. Any non-list value yields ``[]``.
        available_ctes: Dict of already-processed CTE name -> output columns (mutated in place as each step is processed).

    Returns:
        List of normalised RuntimeCteStep instances in input order.
    """
    if not isinstance(steps, list):
        return []
    if available_ctes is None:
        available_ctes = {}

    def _param_sort_key(p: Any) -> tuple[Any, ...]:
        # Shared ordering for filter and HAVING params inside a step.
        return (
            p.left_expr.signature_key,
            p.op,
            p.right_expr.signature_key if p.right_expr else "",
            p.value_type,
        )

    out: list[RuntimeCteStep] = []
    for s in steps:
        if isinstance(s, RuntimeCteStep):
            cte_name = s.cte_name
            tables = s.tables or []
            grain = s.grain or "row_level"
            select_cols = s.select_cols or []
            group_by_cols = s.group_by_cols or []
            output_columns = s.output_columns or []
            filters_param = s.filters_param or []
            having_param = s.having_param or []
            param_values = s.param_values or {}
            order_by_cols = s.order_by_cols or []
            limit = s.limit
            column_map = s.column_map or {}
            output_column_metadata = s.output_column_metadata or {}
            chosen_join_candidate_id = s.chosen_join_candidate_id or ""
            chosen_join_path_signature = s.chosen_join_path_signature or []
        elif isinstance(s, dict):
            cte_name = s.get("cte_name", "")
            tables = s.get("tables", [])
            grain = s.get("grain", "row_level")
            sc_raw = s.get("select_cols", [])
            select_cols = [
                (
                    SelectCol.from_dict(sc)
                    if isinstance(sc, dict)
                    else (SelectCol(expr=NormalizedExpr.from_column(sc)) if isinstance(sc, str) else sc)
                )
                for sc in sc_raw
            ]
            group_by_cols = s.get("group_by_cols", [])
            group_by_cols = [
                (
                    NormalizedExpr.from_dict(g)
                    if isinstance(g, dict)
                    else (NormalizedExpr.from_column(g) if isinstance(g, str) else g)
                )
                for g in group_by_cols
            ]
            output_columns = s.get("output_columns", [])
            fp_raw = s.get("filters_param", [])
            filters_param = [FilterParam.from_dict(f) if isinstance(f, dict) else f for f in fp_raw]
            hp_raw = s.get("having_param", [])
            having_param = [HavingParam.from_dict(h) if isinstance(h, dict) else h for h in hp_raw]
            param_values = s.get("param_values", {})
            obc_raw = s.get("order_by_cols", [])
            order_by_cols = [
                (
                    OrderByCol.from_dict(o)
                    if isinstance(o, dict)
                    else (OrderByCol(expr=NormalizedExpr.from_column(o)) if isinstance(o, str) else o)
                )
                for o in obc_raw
            ]
            limit = s.get("limit")
            column_map = s.get("column_map", {})
            ocm_raw = s.get("output_column_metadata", {})
            output_column_metadata = {
                k: CteOutputColumnMeta.from_dict(v) if isinstance(v, dict) else v for k, v in ocm_raw.items()
            }
            chosen_join_candidate_id = s.get("chosen_join_candidate_id", "")
            chosen_join_path_signature = s.get("chosen_join_path_signature", [])
        else:
            continue
        if not cte_name:
            continue
        if grain not in VALID_GRAINS:
            grain = "row_level"

        normalized_fp = []
        for f in filters_param:
            if not isinstance(f, FilterParam):
                continue
            op = f.op.strip().lower() if f.op else "="
            if op not in VALID_FILTER_OPS:
                op = "="
            normalized_fp.append(
                FilterParam(
                    left_expr=f.left_expr,
                    op=op,
                    value_type=f.value_type.strip().lower() if f.value_type else "unknown",
                    param_key=f.param_key or "",
                    right_expr=f.right_expr,
                    # Keep the connector and grouping: _normalize_cte_steps_for_key
                    # hashes both, so dropping them would make AND/OR variants of
                    # the same CTE filter indistinguishable.
                    bool_op=f.bool_op,
                    filter_group=f.filter_group,
                )
            )

        normalized_hp = []
        for h in having_param:
            if not isinstance(h, HavingParam):
                continue
            op = h.op.strip().lower() if h.op else "="
            if op not in VALID_HAVING_OPS:
                op = "="
            normalized_hp.append(
                HavingParam(
                    left_expr=h.left_expr,
                    op=op,
                    value_type=h.value_type.strip().lower() if h.value_type else "number",
                    param_key=h.param_key or "",
                    right_expr=h.right_expr,
                    bool_op=h.bool_op,
                    filter_group=h.filter_group,
                )
            )

        # Names a dotted column prefix may legitimately refer to: this step's
        # tables plus every CTE processed so far. Built once per step.
        known_refs = {str(t).lower() for t in tables} | {str(c).lower() for c in available_ctes}

        all_cols_raw = [g.primary_column for g in group_by_cols] + [f.left_expr.primary_column for f in normalized_fp]
        for sc in select_cols:
            if isinstance(sc, SelectCol):
                all_cols_raw.append(sc.expr.primary_column)
            elif isinstance(sc, str):
                all_cols_raw.append(sc)

        cte_column_map = {}
        for col in all_cols_raw:
            # Expressions without a usable column reference are ignored.
            if not isinstance(col, str) or "." not in col:
                continue
            parts = col.split(".", 1)
            table_ref = parts[0].lower()
            col_name = parts[1].lower()
            if table_ref in known_refs:
                cte_column_map[col_name] = table_ref
        if column_map:
            # Explicitly supplied mappings win over inferred ones.
            cte_column_map.update(column_map)

        ocm = {}
        for out_col in output_columns:
            out_col_lower = str(out_col).lower()
            if out_col_lower in output_column_metadata:
                ocm[out_col_lower] = output_column_metadata[out_col_lower]
                continue
            # No metadata supplied: infer it from an "<agg>_" name prefix.
            agg_match = next(
                (agg for agg in VALID_AGGREGATION_FUNCTIONS if out_col_lower.startswith(f"{agg}_")),
                None,
            )
            is_agg = agg_match is not None
            agg_func = agg_match if is_agg else ""
            # NOTE(review): every underscore in the remainder becomes a dot, so a
            # base column containing underscores is split into several segments.
            # Confirm this matches the naming scheme that generates these aliases.
            base_col = out_col_lower[len(agg_func) + 1 :].replace("_", ".") if is_agg else ""
            ocm[out_col_lower] = CteOutputColumnMeta(
                source="aggregation" if is_agg else "passthrough",
                base_column=base_col,
                agg_func=agg_func,
                filterable=True,
                aggregatable=True,
                data_type=("integer" if is_agg and agg_func == "count" else "unknown"),
            )

        normalized_select_cols = (
            select_cols
            if select_cols and isinstance(select_cols[0], SelectCol)
            else ([SelectCol(expr=NormalizedExpr.from_column(c)) for c in select_cols] if select_cols else [])
        )
        normalized_order_by = (
            order_by_cols
            if order_by_cols and isinstance(order_by_cols[0], OrderByCol)
            else (
                [(OrderByCol(expr=NormalizedExpr.from_column(c)) if isinstance(c, str) else c) for c in order_by_cols]
                if order_by_cols
                else []
            )
        )

        cte = RuntimeCteStep(
            cte_name=str(cte_name),
            tables=sorted({str(t) for t in tables}),
            grain=grain,
            select_cols=normalized_select_cols,
            group_by_cols=group_by_cols,
            output_columns=[str(c) for c in output_columns],
            filters_param=sorted(normalized_fp, key=_param_sort_key),
            having_param=sorted(normalized_hp, key=_param_sort_key),
            param_values=param_values,
            order_by_cols=normalized_order_by,
            limit=limit,
            column_map=cte_column_map,
            output_column_metadata=ocm,
            chosen_join_candidate_id=chosen_join_candidate_id,
            chosen_join_path_signature=chosen_join_path_signature,
        )
        out.append(cte)
        # Make this step visible to later steps that reference it by name.
        available_ctes[cte_name] = output_columns
    return out
497
+
498
+
499
def _normalize_cte_steps_for_key(steps: list[RuntimeCteStep]) -> list[dict[str, Any]]:
    """Reduce RuntimeCteStep objects to plain dicts of signatures for intent-key hashing.

    Each step becomes a dict of its name, sorted tables, sorted select / group-by /
    order-by signatures, sorted output columns, and filter / HAVING entries rendered
    as ``"<signature>|<bool_op>|<filter_group>"`` in a deterministic order.

    Args:
        steps: List of normalised RuntimeCteStep instances.

    Returns:
        One dict per step, in input order, suitable for stable JSON serialisation.
    """

    def _grouped_order(param: Any) -> tuple[Any, ...]:
        # Ungrouped params (filter_group is None) sort ahead of grouped ones.
        group = param.filter_group if param.filter_group is not None else -1
        right_sig = param.right_expr.signature_key if param.right_expr else ""
        return (group, param.left_expr.signature_key, param.op, right_sig, param.value_type)

    def _render_params(params: Any) -> list[str]:
        ordered = sorted(params or [], key=_grouped_order)
        return [f"{p.signature_key}|{p.bool_op}|{p.filter_group}" for p in ordered]

    def _collect_signatures(items: Any, typed: type) -> list[str]:
        # Typed entries contribute their signature; bare strings pass through;
        # anything else is ignored.
        collected = []
        for item in items or []:
            if isinstance(item, typed):
                collected.append(item.signature_key)
            elif isinstance(item, str):
                collected.append(item)
        return collected

    return [
        {
            "cte_name": step.cte_name,
            "tables": sorted(step.tables or []),
            "select_cols": sorted(_collect_signatures(step.select_cols, SelectCol)),
            "group_by_cols": sorted(g.signature_key for g in (step.group_by_cols or [])),
            "output_columns": sorted(step.output_columns or []),
            "filters_param": _render_params(step.filters_param),
            "having_param": _render_params(step.having_param),
            "order_by_cols": sorted(_collect_signatures(step.order_by_cols, OrderByCol)),
        }
        for step in steps
    ]
560
+
561
+
562
def flatten_param_values(intent: RuntimeIntent) -> dict[str, Any]:
    """Merge the parameter values of every CTE step and the main query into one dict.

    Layers are applied in order (CTE steps first, main query last), so a later
    layer overwrites an earlier one when keys collide and the main query wins.

    Args:
        intent: The RuntimeIntent whose CTE steps and main param_values are merged.

    Returns:
        A new dict containing all parameter values.
    """
    layers = [step.param_values or {} for step in (intent.cte_steps or [])]
    layers.append(intent.param_values or {})

    flat: dict[str, Any] = {}
    for layer in layers:
        flat.update(layer)
    return flat
580
+
581
+
582
def intent_key(intent: RuntimeIntent) -> str:
    """Compute a stable SHA-256 fingerprint for a RuntimeIntent.

    Tables, select columns, filters, GROUP BY, ORDER BY, HAVING, and CTE steps
    are normalised and sorted before hashing, so the order in which they were
    supplied does not affect the key.

    Args:
        intent: The RuntimeIntent to fingerprint.

    Returns:
        The value returned by ``sha256`` over the stable JSON of the normalised intent.
    """
    # Grain and expected_rows are validated here for the debug log only;
    # neither value is part of the hashed payload below.
    grain = intent.grain or "row_level"
    if grain not in VALID_GRAINS:
        debug(f"[utils.intent_key] invalid_grain: '{grain}' defaulting to 'row_level'")
        grain = "row_level"

    expected_rows = intent.expected_rows or "many"
    if expected_rows not in VALID_EXPECTED_ROWS:
        expected_rows = {"scalar": "one", "grouped": "few"}.get(grain, "many")
        debug(f"[utils.intent_key] inferred_expected_rows: grain={grain} -> expected_rows={expected_rows}")

    filter_params = _normalize_filters(intent.filters_param or [])
    having_params = _normalize_having_conditions(intent.having_param or [])
    cte_payload = _normalize_cte_steps_for_key(_normalize_cte_steps(intent.cte_steps or []))

    select_signatures = sorted(col.signature_key for col in (intent.select_cols or []))
    order_signatures = sorted(col.signature_key for col in (intent.order_by_cols or []))
    group_signatures = sorted(col.signature_key for col in (intent.group_by_cols or []))

    normalized = {
        "tables": sorted(intent.tables or []),
        "select_cols": select_signatures,
        "filters": [fp.to_dict() for fp in filter_params],
        "group_by_cols": group_signatures,
        "order_by_cols": order_signatures,
        "having_conditions": [hp.to_dict() for hp in having_params],
        "cte_steps": cte_payload,
    }
    key = sha256(stable_json(normalized))
    debug(f"[utils.intent_key] computed: tables={normalized['tables']} key={key[:16]}...")
    return key
632
+
633
+
634
def _tokenize(q: str) -> list[str]:
    """Split a normalised question into sorted tokens with stopwords removed.

    Tokens are maximal runs of lowercase letters, digits, and underscores.
    Repeated tokens are kept, so the result may contain duplicates.

    Args:
        q: The normalised question string.

    Returns:
        Alphabetically sorted list of tokens not present in ``PolicyConfig.STOPWORDS``.
    """
    kept = [tok for tok in re.findall(r"[a-z0-9_]+", q) if tok not in PolicyConfig.STOPWORDS]
    kept.sort()
    return kept
646
+
647
+
648
def extract_tables_from_sql(sql: str, known_tables: list[str]) -> list[str]:
    """Return the known tables that appear in a SQL string, ignoring CTE names.

    Matching is case-insensitive and whole-word. When the SQL contains the word
    ``with``, every ``<name> AS (`` occurrence is treated as a CTE definition
    and that name is excluded from the result.

    Args:
        sql: The SQL string to scan.
        known_tables: List of base table names to match against.

    Returns:
        Sorted, de-duplicated list of matched table names (in their original
        casing from ``known_tables``) that are not CTE aliases.
    """
    lowered = sql.lower()

    cte_aliases: set[str] = set()
    if re.search(r"\bwith\b", lowered) is not None:
        cte_aliases = {
            found.group(1).lower() for found in re.finditer(r"\b(\w+)\s+AS\s*\(", lowered, re.IGNORECASE)
        }

    matched: set[str] = set()
    for table in known_tables:
        needle = table.lower()
        if needle in cte_aliases:
            continue
        if re.search(rf"\b{re.escape(needle)}\b", lowered):
            matched.add(table)
    return sorted(matched)
673
+
674
+
675
def _describe_operation(select_terms: list[str]) -> str:
    """Name the primary operation implied by a list of SELECT expressions.

    The first expression matched by ``AGG_PATTERN`` decides the result.

    Args:
        select_terms: List of SELECT column/expression strings (e.g. ``["COUNT(t.id)"]``).

    Returns:
        Lowercased first capture group of the first ``AGG_PATTERN`` match
        (e.g. ``'count'``, ``'sum'``), or ``'list'`` when no expression matches.
    """
    candidates = (AGG_PATTERN.match(term) for term in select_terms)
    first_hit = next((hit for hit in candidates if hit), None)
    if first_hit is None:
        return "list"
    return first_hit.group(1).lower()
691
+
692
+
693
def _pick_question_style(select_terms: list[str], has_grouping: bool) -> str:
    """Choose a random opening phrase that suits the shape of the query.

    Grouped queries draw from ``QUESTION_STARTS_GROUP``. Otherwise the first
    aggregation found in ``select_terms`` selects a function-specific phrase
    set (count / sum / avg / min / max), with ``QUESTION_STARTS_AGG`` for any
    other aggregation and ``QUESTION_STARTS_LIST`` when there is none.

    Args:
        select_terms: List of SELECT expression strings used to detect aggregation.
        has_grouping: True if the query has GROUP BY columns.

    Returns:
        A randomly selected opening phrase string.
    """
    detected = [hit.group(1).lower() for hit in map(AGG_PATTERN.match, select_terms) if hit]

    if has_grouping:
        return random.choice(QUESTION_STARTS_GROUP)
    if not detected:
        return random.choice(QUESTION_STARTS_LIST)

    lead = detected[0]
    fixed_openers = {
        "count": ["How many", "Count", "What is the number of"],
        "sum": ["What is the total", "Find the sum of", "Calculate the total"],
        "avg": ["What is the average", "Calculate the average", "Find the mean"],
    }
    if lead in fixed_openers:
        return random.choice(fixed_openers[lead])
    if lead in ("min", "max"):
        # "min" + "imum" / "max" + "imum" spell out the full word.
        return random.choice([f"What is the {lead}imum", f"Find the {lead}imum", f"Get the {lead}"])
    return random.choice(QUESTION_STARTS_AGG)
726
+
727
+
728
def generate_question(
    tables: list[str],
    select_terms: list[str],
    filter_descriptions: list[dict[str, str]],
    group_by_terms: list[str],
    having_descriptions: list[dict[str, str]],
    schema: SchemaGraph,
) -> str | None:
    """Generate a natural language question from a plain-data intent description via LLM.

    Builds a structured prompt including semantic context from the schema graph and enforces a randomly selected
    question-opening phrase for variety. Callers are responsible for resolving all param placeholder values into
    the condition strings of ``filter_descriptions`` and ``having_descriptions`` before calling.

    This is a best-effort helper: any exception raised while calling the LLM is logged via ``debug`` and turned
    into a ``None`` result rather than propagated.

    Args:
        tables: List of table names involved in the query.
        select_terms: List of SELECT expression strings.
        filter_descriptions: List of dicts with ``'column'`` and ``'condition'`` keys, with concrete values already substituted.
        group_by_terms: List of GROUP BY column expression strings.
        having_descriptions: List of dicts with HAVING condition descriptions, with concrete values already substituted.
        schema: The SchemaGraph used to look up table and column descriptions.

    Returns:
        The generated question string, or None if the LLM call fails, returns no usable ``question`` field, or
        returns a question that does not start with the required opening phrase.
    """
    # Table and column descriptions for the tables the schema knows about.
    semantics = {}
    for table in tables:
        table_ir = schema.tables.get(table)
        if table_ir:
            semantics[table] = {
                "description": table_ir.description or f"{table} records",
                "columns": {col: (getattr(meta, "description", None) or col) for col, meta in table_ir.columns.items()},
            }

    # Roles for filter / group-by columns written as "table.column".
    roles = {}
    all_columns = {fd.get("column", "") for fd in filter_descriptions}
    all_columns.update(group_by_terms)
    for col in all_columns:
        parts = col.split(".")
        if len(parts) == 2 and parts[0] in schema.tables:
            col_meta = schema.tables[parts[0]].columns.get(parts[1])
            if col_meta and col_meta.role:
                roles[col] = col_meta.role

    operation = _describe_operation(select_terms)

    intent_structure = {
        "tables": tables,
        "operation": operation,
        "columns": select_terms,
        "filters": filter_descriptions if filter_descriptions else None,
        "grouping": group_by_terms if group_by_terms else None,
        "having": having_descriptions if having_descriptions else None,
    }

    selected_style = _pick_question_style(select_terms, bool(group_by_terms))

    system = (
        "You are a natural language question generator for database queries. Convert structured query intent to conversational questions.\n\n"
        "Output Requirements:\n"
        '- Output ONLY valid JSON with single "question" field\n'
        "- Do NOT include markdown, explanations, or commentary\n"
        "- Identical inputs must produce identical outputs\n\n"
        "Generation Rules:\n"
        "- Question MUST reflect EXACTLY the intent structure (same tables, columns, filters, aggregation)\n"
        "- Do NOT add columns, tables, or filters not in the intent\n"
        "- Do NOT change aggregation type or omit filter values\n"
        "- Use natural, conversational language - avoid SQL jargon and raw table names\n"
        "- Use semantic context to refer to business concepts naturally\n"
        "- For multi-table queries, describe relationships naturally\n"
        "- Inject filter values naturally into the question\n"
        "- ONE sentence only\n"
        "- Vary phrasing naturally within the required start constraint\n"
    )

    user_prompt = {
        "task": "Generate a natural language question for this query intent",
        "intent": intent_structure,
        "semantic_context": semantics,
        "column_roles": roles,
        "phrasing_constraint": {
            "required_start": selected_style,
            "description": f"Question MUST start with exactly: '{selected_style}'",
            "strict_mode": not filter_descriptions and not group_by_terms,
            "phrasing_flexibility": "Vary word choice and sentence structure while keeping the required start",
        },
        "examples": [
            {
                "intent": {
                    "tables": ["table1"],
                    "operation": "count",
                    "columns": ["COUNT(table1.column1)"],
                    "filters": [{"column": "table1.column2", "condition": "= 1"}],
                },
                "required_start": "How many",
                "question": "How many active records do we have?",
            },
            {
                "intent": {
                    "tables": ["table1"],
                    "operation": "list",
                    "columns": ["table1.column1", "table1.column2"],
                    "filters": [{"column": "table1.column2", "condition": "= 'value1'"}],
                },
                "required_start": "What are",
                "question": "What are the records with column2 equal to value1?",
            },
            {
                "intent": {
                    "tables": ["table1", "table2"],
                    "operation": "sum",
                    "columns": ["SUM(table2.column1)", "table1.column1"],
                    "grouping": ["table2.column2"],
                },
                "required_start": "Show me",
                "question": "Show me the total column1 for each group.",
            },
        ],
        "output_format": {"question": "Your natural language question here"},
    }

    # Only the LLM round-trip is allowed to fail silently; the response is
    # validated explicitly below instead of relying on an AttributeError.
    try:
        response = llm_json(system, stable_json(user_prompt), retries=1, task="default")
    except Exception as e:
        debug(f"[utils.generate_question] failed: {e}")
        return None

    question = response.get("question") if isinstance(response, dict) else None
    if not question or not isinstance(question, str):
        debug("[utils.generate_question] missing_question_field")
        return None

    question = question.strip()
    # An opening phrase may contain a "{...}" placeholder; only the literal
    # text before the first brace is enforced.
    template_start = selected_style.split("{")[0].strip()
    if template_start and not question.startswith(template_start):
        debug(f"[utils.generate_question] phrasing_violation: expected_start={template_start}, got={question[:30]}")
        return None
    debug(f"[utils.generate_question] generated: {question[:50]}")
    return question
870
+
871
+
872
+ _QUESTION_FROM_SQL_SYSTEM = (
873
+ "You are given a SQL query and a schema description. "
874
+ "Your job is to decide whether the query represents a realistic, "
875
+ "meaningful business question and, if so, generate a natural language "
876
+ "question that a human analyst would ask to get this query's result.\n\n"
877
+ "Rules:\n"
878
+ "- If the query is unrealistic, nonsensical, or produces meaningless "
879
+ "results, set is_realistic to false and explain why in drop_reason.\n"
880
+ "- If realistic, write ONE clear, conversational question that a "
881
+ "non-technical user would ask. Do NOT use SQL jargon or raw column "
882
+ "names — use natural business language.\n"
883
+ "- Output ONLY valid JSON with exactly three fields: "
884
+ '"question", "is_realistic" (boolean), "drop_reason" (string or null).\n'
885
+ )
886
+
887
+
888
def generate_question_from_sql(
    sql: str,
    schema: SchemaGraph,
    tables: list[str],
) -> dict[str, Any] | None:
    """Generate an NL question from a substituted SQL query with a realism gate.

    Builds a one-line-per-table schema summary for ``tables``, sends it with
    the SQL to the LLM (``"default"`` task profile, one retry) and normalises
    the reply.

    Args:

        sql: Fully substituted SQL string with literal values.
        schema: Schema graph for table/column descriptions.
        tables: Tables referenced by the query. Names missing from
            ``schema.tables`` are skipped.

    Returns:

        ``{"question": str, "is_realistic": True, "drop_reason": None}`` when
        the model accepts the query and returns a non-blank question;
        ``{"question": "", "is_realistic": False, "drop_reason": ...}`` when
        the model rejects it (``drop_reason`` falls back to ``"unrealistic"``);
        ``None`` when the LLM call raises or the question is missing, not a
        string, or blank after stripping.
    """
    # One "TABLE <name>: <col> (<type>) [<role>], ..." line per known table.
    table_lines: list[str] = []
    for table in tables:
        table_meta = schema.tables.get(table)
        if not table_meta:
            continue
        cols = [
            f"{col_name} ({col_meta.data_type or 'unknown'})"
            + (f" [{col_meta.role}]" if col_meta.role else "")
            for col_name, col_meta in table_meta.columns.items()
        ]
        table_lines.append(f"TABLE {table}: {', '.join(cols)}")
    schema_context = "\n".join(table_lines)

    user_prompt = stable_json({
        "sql": sql,
        "schema": schema_context,
        "output_format": {
            "question": "natural language question or empty string",
            "is_realistic": True,
            "drop_reason": None,
        },
    })

    # Deliberate best-effort boundary: any failure in the LLM call or in
    # reading its reply is logged and reported to the caller as ``None``.
    try:
        response = llm_json(
            _QUESTION_FROM_SQL_SYSTEM, user_prompt,
            retries=1, task="default",
        )
        is_realistic = response.get("is_realistic", False)
        question = response.get("question", "")
        drop_reason = response.get("drop_reason")

        # Models sometimes return the flag as a string or number.
        if not isinstance(is_realistic, bool):
            is_realistic = str(is_realistic).lower() in ("true", "1", "yes")

        if not is_realistic:
            debug(
                f"[utils.generate_question_from_sql] dropped: "
                f"reason={drop_reason}"
            )
            return {
                "question": "",
                "is_realistic": False,
                "drop_reason": drop_reason or "unrealistic",
            }

        # Strip BEFORE validating so a whitespace-only answer is treated as
        # empty rather than returned as a "realistic" result with no text.
        question = question.strip() if isinstance(question, str) else ""
        if not question:
            debug("[utils.generate_question_from_sql] empty question")
            return None

        debug(
            f"[utils.generate_question_from_sql] generated: "
            f"{question[:60]}"
        )
        return {
            "question": question,
            "is_realistic": True,
            "drop_reason": None,
        }
    except Exception as e:
        debug(f"[utils.generate_question_from_sql] failed: {e}")
        return None