aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2133 @@
1
+ """Intent parsing pipeline orchestration and template matching."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import replace
6
+
7
+ from .config import VALID_ARITH_OPS, PolicyConfig
8
+ from .contracts_base import IntentIssue, SchemaGraph
9
+ from .contracts_core import (
10
+ ConcreteIntent,
11
+ FilterParam,
12
+ HavingParam,
13
+ NormalizedExpr,
14
+ OrderByCol,
15
+ RuntimeCteStep,
16
+ RuntimeIntent,
17
+ SelectCol,
18
+ SimulatorIntent,
19
+ Template,
20
+ )
21
+ from .core_utils import debug, llm_chat, safe_json_loads, stable_json
22
+ from .intent_expr import (
23
+ assign_param_keys,
24
+ build_cte_output_metadata,
25
+ collect_raw_param_values,
26
+ decompose_between_params,
27
+ derive_cte_output_columns,
28
+ normalize_date_diff_raw_values,
29
+ ensure_scalar_func_defaults,
30
+ extract_structural_params,
31
+ normalize_in_raw_values,
32
+ parse_intent_response,
33
+ repair_misclassified_date_diff,
34
+ tag_expr_numeric,
35
+ )
36
+ from .intent_repair import (
37
+ dedup_contradictory_filters,
38
+ expand_fk_select_to_descriptive,
39
+ normalize_boolean_filter_values,
40
+ normalize_in_filter_types,
41
+ normalize_null_filter_values,
42
+ normalize_pk_distinct,
43
+ prune_unreferenced_tables,
44
+ qualify_cte_output_columns,
45
+ repair_fk_filter_type_mismatch,
46
+ repair_null_equality_filters,
47
+ requalify_redundant_pk_references,
48
+ resolve_filter_value_case,
49
+ sanitize_table_names,
50
+ strip_impossible_having,
51
+ strip_join_conditions,
52
+ strip_spurious_group_by,
53
+ )
54
+ from .intent_expr import extract_columns_from_expr, replace_refs_in_expr
55
+ from .intent_resolve import (
56
+ enforce_cte_grain_consistency,
57
+ enforce_grain_consistency,
58
+ enforce_schema,
59
+ force_main_grain_when_using_grouped_cte,
60
+ normalize_count_star,
61
+ normalize_cte_names,
62
+ normalize_filters_havings,
63
+ rewrite_cte_output_refs_to_aliases,
64
+ resolve_column_map,
65
+ resolve_cte_column_maps,
66
+ simplify_exprs,
67
+ sort_order_by_cols,
68
+ sort_select_cols,
69
+ )
70
+ from .utils import exact_question_match
71
+ from .validation_execute import validate_semantics
72
+ from .validation_semantic import auto_repair_filter_having
73
+
74
# Lookup table consumed by _resolve_repair_instruction: each key is an
# IntentIssue.category string and each value is the fix instruction text that
# ends up in the "fix" field of an errors_to_fix entry in the semantic-repair
# prompt. A category that is not listed here falls back to the issue's own
# message at lookup time.
REPAIR_INSTRUCTIONS: dict[str, str] = {
    "extract_epoch": (
        "Do not use EXTRACT(EPOCH FROM ...). Use date column subtraction "
        "(e.g. <table>.<date_col> - <table>.<date_col>) or other supported "
        "date functions for time differences."
    ),
    "grain_validity": "Use one of the allowed grain values: 'scalar', 'grouped', or 'row_level'.",
    "grain_consistency": "Ensure grain matches the query structure. 'grouped' requires group_by_cols and aggregation in select_cols. 'row_level' means no GROUP BY and no aggregation. 'scalar' means a single aggregated value with no GROUP BY.",
    "schema_validation": "One or more tables or columns do not exist in the schema. Check allowed_tables and use only exact column names from each table.",
    "unknown_table": "The table does not exist in the schema. Remove it from tables and rewrite any references to use only tables that appear in allowed_tables.",
    "unknown_column": "The column does not exist in its table. Check the schema for available columns with a similar name or meaning and replace the reference.",
    "semantic_contradiction": "The intent contains contradictory operations. Keep only the aggregation or pattern that matches the question.",
    "expression_type": "Arithmetic expressions require numeric columns. Ensure all operands in arithmetic and all comparison sides have compatible types.",
    "filter_aggregation": "Conditions with aggregation functions (COUNT, SUM, AVG, MIN, MAX) belong in having_param, not filters_param. Move the condition.",
    "having_aggregation": "HAVING conditions must have an aggregation function in left_expr. Conditions without aggregation belong in filters_param.",
    "filter_semantic": "Fix the filter comparison: remove self-comparisons and ensure type compatibility between left and right expressions.",
    "having_semantic": "Fix the HAVING comparison: remove self-comparisons and ensure type compatibility between aggregated expressions.",
    "nested_aggregation": "Nested aggregation (e.g. SUM(COUNT(...))) is not allowed. Use a CTE: compute the inner aggregation in a CTE step, then aggregate the CTE output in the main query.",
    "mixed_aggregation": "An expression cannot mix aggregated and bare column terms. Either wrap all terms in an aggregation function or add bare columns to group_by_cols.",
    "group_by_membership": "Every non-aggregated column in select_cols must appear in group_by_cols when grain is 'grouped'. Add the missing column to group_by_cols or wrap it in an aggregation.",
    "order_by_aggregation": "ORDER BY cannot contain aggregation when grain is 'row_level'. Change grain to 'grouped' or remove the aggregation from order_by_cols.",
    "aggregation_hint": "The question contains a quantity-comparison phrase that typically requires COUNT or SUM aggregation with GROUP BY and HAVING. Add aggregation in select_cols, group_by_cols on the entity, having_param with the threshold, and set grain to 'grouped'.",
    "column_schema": "A referenced column does not exist in its table. Check the schema for the correct column name.",
    "table_schema": "A referenced table does not exist in the schema. Use only tables from allowed_tables.",
    "filter_type_ops": "The filter operator is not compatible with the column data type. Use an appropriate operator for the column type.",
    "null_filter": "NULL checks must use 'is null' or 'is not null' operator with no value or value_type field.",
    "between_filter": "BETWEEN is not supported. Decompose into two separate filters: one with op '>=' and one with op '<='.",
    "date_window": (
        "For relative time filters (e.g. 'last 90 days', 'past 6 months'), "
        "use value_type 'date_window' with the date column in left_expr, "
        "op '>=' and value as {\"unit\": \"day\"|\"week\"|\"month\"|\"year\", \"offset\": N}. "
        "Use singular unit names."
    ),
    "date_diff": (
        "For date-difference filters comparing two date columns "
        "(e.g. 'return_date - rental_date > 7 days'), use value_type 'date_diff' with "
        "left_expr as the date subtraction (<table>.<end_date> - <table>.<start_date>), "
        "op as the comparison operator, and value as "
        "{\"unit\": \"day\"|\"week\"|\"month\"|\"year\", \"amount\": N}. "
        "Use singular unit names. Do NOT use date_diff for relative-to-now filters; "
        "use date_window instead."
    ),
    "agg_role": "SUM and AVG should only be applied to numeric measure columns. Use COUNT for non-measure columns, or select a numeric column.",
    "agg_type": "SUM and AVG require a numeric column. The referenced column is not numeric. Use COUNT instead or choose a numeric column.",
    "for_each_grouping": "The question contains a 'for each', 'per', or 'by' phrase implying a GROUP BY on the referenced entity. Add the entity's identifying column to group_by_cols, include it as a non-aggregated entry in select_cols, and set grain to 'grouped'.",
    "scalar_func_type": "The scalar function is applied to an incompatible column type. Ensure the column type matches what the function expects (e.g. YEAR needs a date column, UPPER needs a text column).",
    "threshold_missing_having": "The question contains a threshold phrase (e.g. 'more than', 'at least') and the intent already has aggregation, but no HAVING condition is defined. Add a HAVING clause that compares the aggregated column to the numeric threshold in the question.",
    "cte_structure": "CTE steps require a cte_name string, an output_columns list of alias strings, and valid tables.",
    "cte_grain_consistency": "CTE grain must match its structure: same rules as the main query regarding grain, group_by_cols, and aggregation.",
    "cte_table_reference": "A CTE references an unknown table. A CTE can only reference schema tables or previously defined CTEs.",
    "cte_grain_compatibility": "A row_level query or CTE depends on an aggregated CTE. Ensure the grain is compatible with upstream CTE grains.",
    "cte_aggregation": "A CTE has HAVING conditions but no aggregation in its select_cols. Add aggregation or remove the HAVING.",
    "missing_scoping_table": "The question explicitly mentions a schema table that is missing from the intent. Add the table to 'tables', join it via its foreign-key relationship, and include relevant columns in select_cols or filters as appropriate.",
    "pk_fk_filter": "Do not filter on a primary-key or foreign-key integer column. Replace the identifier filter with a filter on the descriptive column of the referenced table, join that table if it is not already present, and use the human-readable value from the question.",
    "agg_keyword_missing": "The question asks for an aggregation (total, count, average, sum, etc.) but the intent has no aggregated column and no HAVING condition. Add the appropriate aggregation function to select_cols, include all tables needed to compute the aggregated value, and set grain to 'grouped' with the correct group_by_cols.",
}
"""Map from IntentIssue.category to a targeted fix instruction for the LLM semantic-repair prompt."""
131
+
132
+
133
def _resolve_repair_instruction(issue: IntentIssue) -> str:
    """Look up the fix instruction registered for an issue's category.

    Args:
        issue: Semantic issue whose ``category`` selects the instruction.

    Returns:
        The ``REPAIR_INSTRUCTIONS`` entry for ``issue.category``, or the
        issue's own ``message`` when the category has no entry.
    """
    category = issue.category
    fallback = issue.message
    return REPAIR_INSTRUCTIONS.get(category, fallback)
136
+
137
+
138
def _classify_schema_error(error_message: str) -> str:
    """Map a schema enforcement error message onto an issue category.

    Matching is case-insensitive and substring based.

    Args:
        error_message: Error string produced by enforce_schema.

    Returns:
        'unknown_table' when the message contains "unknown table",
        'unknown_column' when it contains both "unknown" and "column",
        otherwise 'schema_validation'.
    """
    text = error_message.lower()
    # The table check runs first, so a message naming both an unknown table
    # and a column is reported as a table problem.
    if "unknown table" in text:
        category = "unknown_table"
    elif "unknown" in text and "column" in text:
        category = "unknown_column"
    else:
        category = "schema_validation"
    return category
157
+
158
+
159
def _build_intent_semantic_repair_prompt(
    question: str,
    current_intent_json: str,
    errors: list[IntentIssue],
    warnings: list[IntentIssue],
    schema_literal_text: str,
) -> str:
    """Assemble the user prompt asking the LLM to repair semantic issues in an intent.

    Each error becomes an ``errors_to_fix`` entry carrying its category, its
    message and the fix instruction returned by ``_resolve_repair_instruction``.
    Warning messages are forwarded as optional ``suggestions``.

    Args:
        question: Original natural language question.
        current_intent_json: JSON string of the current flawed parsed intent.
        errors: IntentIssue objects to be fixed.
        warnings: IntentIssue objects whose messages are passed as suggestions.
        schema_literal_text: Human-readable schema text for the LLM context.

    Returns:
        Prompt string produced by ``stable_json`` from the assembled payload.
    """
    fixes: list[dict[str, str]] = [
        {
            "category": issue.category,
            "issue": issue.message,
            "fix": _resolve_repair_instruction(issue),
        }
        for issue in errors
    ]
    hints: list[str] = [warning.message for warning in warnings]

    task_text = (
        "Fix every error listed in errors_to_fix. Follow each "
        "fix instruction exactly. Suggestions are optional "
        "improvements. Output corrected intent JSON only."
    )

    # Rules repeated on every repair request, independent of the reported errors.
    critical_rules = [
        "Do not use EXTRACT(EPOCH FROM ...). Use date column "
        "subtraction or other supported date functions for time "
        "differences.",
        "Use only tables and columns that exist in schema_info.",
        "Every non-aggregated column in select_cols must appear "
        "in group_by_cols when grain is 'grouped'.",
        "HAVING conditions require an aggregation function in "
        "left_expr.",
        "Filters without aggregation go in filters_param; filters "
        "with aggregation go in having_param.",
        "SUM and AVG apply only to numeric measure columns; use "
        "COUNT for non-measure columns.",
        "Nested aggregation like SUM(COUNT(...)) is forbidden; "
        "use a CTE instead.",
        "BETWEEN is not supported; decompose into >= and <= "
        "filters.",
        "NULL checks must use 'is null' or 'is not null' with "
        "no value field.",
        "For relative time windows ('last N days') use value_type "
        "'date_window' with value {\"unit\": \"day\", \"offset\": N}. "
        "For date-difference between two columns use value_type "
        "'date_diff' with value {\"unit\": \"day\", \"amount\": N}.",
    ]

    field_specs = {
        "tables": (
            "Tables from schema needed for query. Sorted "
            "alphabetically."
        ),
        "select_cols": (
            "Array of {expr}. expr is a SQL expression string "
            "built from fully-qualified columns such as "
            "<table_1>.<column_1>. Use DISTINCT inside expr when "
            "needed, for example "
            "'COUNT(DISTINCT <table_1>.<column_1>)'."
        ),
        "group_by_cols": (
            "GROUP BY columns in <table_1>.<column_1> format. "
            "Required when select_cols has both aggregated and "
            "non-aggregated columns."
        ),
        "order_by_cols": (
            "Array of {expr, direction}. expr is a SQL "
            "expression; direction is 'asc' or 'desc'."
        ),
        "filters_param": (
            "Array of {left_expr, op, value_type, value} for "
            "expr-vs-value or {left_expr, op, right_expr} for "
            "expr-vs-expr. Optional: bool_op ('AND'|'OR', "
            "default 'AND'), filter_group (integer)."
        ),
        "having_param": (
            "Array of {left_expr, op, value_type, value} for "
            "agg-vs-value or {left_expr, op, right_expr} for "
            "agg-vs-agg. Optional: bool_op ('AND'|'OR', default "
            "'AND'), filter_group (integer)."
        ),
        "limit": "Integer limit or null.",
    }

    # Skeleton of the intent object the LLM is expected to return.
    expected_shape = {
        "tables": ["<table_1>"],
        "select_cols": [
            {"expr": "COUNT(<table_1>.<column_1>)"},
        ],
        "group_by_cols": [],
        "order_by_cols": [],
        "filters_param": [],
        "having_param": [],
        "limit": None,
        "cte_steps": [],
        "natural_language": (
            "<short direct-answer description of what this query "
            "returns>"
        ),
    }

    return stable_json(
        {
            "task": task_text,
            "critical_rules": critical_rules,
            "errors_to_fix": fixes,
            "suggestions": hints,
            "field_specifications": field_specs,
            "current_intent": current_intent_json,
            "question": question,
            "schema_info": schema_literal_text,
            "output_format": expected_shape,
        }
    )
286
+
287
+
288
def _build_intent_format_repair_prompt(question: str, raw_response: str, parse_error: str) -> str:
    """Assemble the user prompt asking the LLM to fix a malformed intent response.

    Args:
        question: Original natural language question.
        raw_response: The malformed LLM response that failed to parse.
        parse_error: Short description of why the response failed to parse.

    Returns:
        Prompt string produced by ``stable_json`` from the assembled payload.
    """
    field_specs = {
        "tables": "Array of string table names. REQUIRED.",
        "select_cols": (
            "Array of objects with {expr}. expr is a SQL "
            "expression string that uses fully-qualified "
            "columns such as <table_1>.<column_1>. Use DISTINCT "
            "inside expr when needed. REQUIRED, non-empty."
        ),
        "group_by_cols": (
            "Array of string column names in "
            "<table_1>.<column_1> format. REQUIRED "
            "(can be empty [])."
        ),
        "order_by_cols": (
            "Array of objects with {expr, direction}. expr is a "
            "SQL expression string; direction is 'asc' or "
            "'desc'. REQUIRED (can be empty [])."
        ),
        "filters_param": (
            "Array of objects with {left_expr, op, value_type, "
            "value} for expr-vs-value or {left_expr, op, "
            "right_expr} for expr-vs-expr. Optional: bool_op "
            "(string, default 'AND'), filter_group "
            "(integer|null). REQUIRED (can be empty [])."
        ),
        "having_param": (
            "Array of objects with {left_expr, op, value_type, "
            "value} for agg-vs-value or {left_expr, op, "
            "right_expr} for agg-vs-agg. Optional: bool_op "
            "(string, default 'AND'), filter_group "
            "(integer|null). REQUIRED (can be empty [])."
        ),
        "limit": "Integer or null. REQUIRED.",
        "cte_steps": (
            "Array of CTE step objects. REQUIRED (can be "
            "empty [])."
        ),
    }

    # Remedy text for each kind of formatting problem the LLM may have produced.
    issue_actions = {
        "missing_field": "Add the missing field with appropriate default: [] for arrays, null for scalars, false for boolean.",
        "invalid_json": "Fix syntax: ensure matching braces/brackets, proper commas, quoted strings.",
        "null_array": "Replace null with empty array [].",
        "invalid_type": "Convert to correct type: strings in quotes, numbers without quotes, boolean as true/false.",
        "extra_text": "Remove any text before/after the JSON object.",
        "truncated": "Complete the JSON structure with appropriate closing braces/brackets.",
    }

    type_map = {
        "tables": "string[]",
        "select_cols": "object[] with {expr: string}",
        "group_by_cols": "string[]",
        "order_by_cols": "object[] with {expr: string, direction: string}",
        "filters_param": "object[] with {left_expr: string, op: string, value_type: string, bool_op?: string, filter_group?: integer} or {left_expr: string, op: string, right_expr: string, bool_op?: string, filter_group?: integer}",
        "having_param": "object[] with {left_expr: string, op: string, value_type: string, bool_op?: string, filter_group?: integer} or {left_expr: string, op: string, right_expr: string, bool_op?: string, filter_group?: integer}",
        "limit": "integer|null",
        "cte_steps": "object[]",
    }

    instruction_list = [
        "Return ONLY valid JSON - no explanation, no markdown.",
        "Preserve the intent and all meaningful content.",
        "Fix any syntax errors (missing commas, brackets, quotes).",
        "Use empty arrays [] for list fields, not null.",
        "Use null for optional scalar fields.",
        "Ensure all required fields are present.",
        "Remove any trailing commas before closing braces/brackets.",
    ]

    # Skeleton of the intent object the LLM is expected to return.
    expected_shape = {
        "tables": ["<table_name>"],
        "select_cols": [{"expr": "<table_name>.<column_name>"}],
        "group_by_cols": [],
        "order_by_cols": [],
        "filters_param": [],
        "having_param": [],
        "limit": None,
        "cte_steps": [],
    }

    return stable_json(
        {
            "task": "The previous response was not valid JSON. Fix the formatting errors and return ONLY valid JSON.",
            "question": question,
            "invalid_response": raw_response,
            "parse_error": parse_error,
            "field_specifications": field_specs,
            "issue_action_mapping": issue_actions,
            "type_conversions": type_map,
            "instructions": instruction_list,
            "output_format": expected_shape,
        }
    )
387
+
388
+
389
def _build_intent_parse_prompt(
    question: str,
    schema_literal_text: str,
    table_list: list[str],
) -> tuple[str, str]:
    """Create the system and user messages for the intent parsing call.

    Args:
        question: Natural language question to parse.
        schema_literal_text: Schema text embedded as ``schema_summary``.
        table_list: Table names embedded as ``allowed_tables``.

    Returns:
        Tuple of (system message, user message). The user message is the
        request payload passed through ``stable_json``.
    """
    system_message = (
        "You are a deterministic intent parser for text-to-SQL. "
        "Output ONLY valid JSON that matches the required format. "
        "Identical inputs must produce identical outputs. The "
        "natural_language field is a short direct description of what "
        "the query returns, without boilerplate phrases."
    )

    task_text = (
        "Parse the question into a schema-aware intent JSON. "
        "Do not write SQL. Extract all literals needed for "
        "filters, HAVING, limits, and parameters."
    )

    # Placeholder tokens used throughout the specification texts below.
    naming_conventions = {
        "table_placeholder": "<table_1>",
        "column_placeholder": "<column_1>",
        "date_column_placeholder": "<date_column_1>",
        "value_placeholder": "<value_from_question>",
    }

    field_specs = {
        "tables": (
            "All base tables and CTE names whose columns appear in "
            "select_cols, group_by_cols, order_by_cols, "
            "filters_param, or having_param. Do not reason about "
            "joins here; include every table that owns a referenced "
            "column. Sort alphabetically."
        ),
        "select_cols": (
            "Array of objects {expr}. expr is a SQL expression "
            "string that uses fully-qualified columns "
            "<table_n>.<column_m>. Use DISTINCT inside expr when "
            "the question asks for distinct values. The array "
            "must be non-empty."
        ),
        "group_by_cols": (
            "Array of SQL expressions in <table_n>.<column_m> "
            "format. Required when select_cols mixes aggregated "
            "and non-aggregated expressions."
        ),
        "order_by_cols": (
            "Array of objects {expr, direction}. expr is the SQL "
            "expression only (no ASC/DESC in expr). direction is "
            "the only place for sort direction: 'asc' or 'desc'. "
            "Include only when the question implies ordering."
        ),
        "filters_param": (
            "Array of all row-level filter conditions implied by "
            "the question. Use either "
            "{left_expr, op, value_type, value} for expr-vs-value "
            "or {left_expr, op, right_expr} for expr-vs-expr. "
            "For relative time windows (e.g. 'last 90 days', 'past 6 months'), "
            "use value_type 'date_window' with the date column in left_expr, "
            "op '>=' and value as {\"unit\": \"day\", \"offset\": 90}. "
            "For date-difference comparisons between two columns "
            "(e.g. 'return_date - rental_date > 7 days'), "
            "use value_type 'date_diff' with left_expr as the subtraction "
            "and value as {\"unit\": \"day\", \"amount\": 7}. "
            "Use singular unit names (day, week, month, year). "
            "Use bool_op ('AND'|'OR', default 'AND') to connect "
            "to the next filter. Use filter_group (integer) to "
            "group filters that share parentheses."
        ),
        "having_param": (
            "Array of all aggregate-level filter conditions. Use "
            "{left_expr, op, value_type, value} for agg-vs-value "
            "or {left_expr, op, right_expr} for agg-vs-agg. "
            "left_expr and right_expr must contain aggregation."
        ),
        "limit": (
            "Integer limit or null. Use a limit for 'top N', "
            "'first N', or similar phrases. Do not reuse the "
            "same number as both limit and filter value."
        ),
        "natural_language": (
            "Short direct description of what the query returns. "
            "For example: 'total <column_1> per <column_2>' or "
            "'rows from <table_1> where <column_1> equals "
            "<value_from_question>'. No filler words."
        ),
        "cte_steps": (
            "Use CTEs only when the question explicitly describes "
            "multiple steps or when an intermediate aggregation is "
            "reused. Each CTE has: cte_name, description, tables, "
            "select_cols, group_by_cols, order_by_cols, "
            "filters_param, having_param, output_columns, limit. "
            "For every CTE select_cols[i], you must provide "
            "output_columns[i] as a simple alias (e.g. <measure_1>, "
            "<count_rows>) and use only these aliases in any main "
            "query expressions that reference CTE outputs."
        ),
    }

    expression_format = {
        "description": (
            "expr is a SQL expression built from fully-qualified "
            "columns, arithmetic, aggregation, and scalar "
            "functions. Use the same structure that the SQL "
            "generator should ultimately implement."
        ),
        "rules": [
            "Always qualify columns as <table_n>.<column_m>.",
            "Bare column: <table_1>.<column_1>.",
            "Aggregation: AGG(<table_1>.<column_1>) where "
            "AGG is COUNT, SUM, AVG, MIN, or MAX.",
            "Arithmetic: <table_1>.<column_1> * "
            "<table_1>.<column_2>, "
            "<table_1>.<column_1> + <table_1>.<column_2>, "
            "<table_1>.<column_1> / <table_1>.<column_2>. "
            "Date difference: <table_1>.<date_col> - <table_1>.<date_col>.",
            "Do not use EXTRACT(EPOCH FROM ...) for time differences; "
            "use date column subtraction or other supported date functions.",
            "Aggregation over arithmetic: "
            "SUM(<table_1>.<column_1> * <table_1>.<column_2>).",
            "Scalar over aggregation: "
            "ROUND(SUM(<table_1>.<column_1>), 2).",
            "COUNT all rows: COUNT(*).",
            "Distinct count: COUNT(DISTINCT <table_1>.<column_1>).",
        ],
    }

    general_rules = [
        "All fields in output_format are required. Use [] for "
        "arrays and null for optional scalars when empty.",
        "Always qualify columns with their table. Never output "
        "bare column names.",
        "Do not invent tables, columns, filters, or HAVING "
        "conditions that are not supported by schema_summary "
        "or implied by the question.",
        "Never include join conditions in filters_param or "
        "having_param. Joins are handled separately.",
        "Grain: 'row_level' when there is no aggregation, "
        "'grouped' when there is aggregation and grouping, "
        "'scalar' for a single aggregated value with no GROUP BY.",
        "If the question cannot be answered using only the given "
        "tables and columns, set intent_status to schema_invalid "
        "and leave all other fields empty or null. Do not invent "
        "tables or columns.",
    ]

    intent_status_text = (
        "Set to 'ok' (default) when the question is answerable "
        "with the schema. Set to 'schema_invalid' when the "
        "question cannot be answered using only the given tables "
        "and columns."
    )

    answer_style = [
        "First decide whether the question expects a single scalar "
        "answer, a grouped summary, or a list of rows.",
        "Scalar answers: when the question asks for a single total, "
        "count, average, minimum, or maximum without grouping.",
        "Grouped summaries: when the question uses phrases like "
        "'per <entity>', 'by <entity>', or 'for each <entity>'.",
        "Row-level lists: when the question asks to list or show "
        "individual records satisfying conditions.",
        "Avoid tables which are not semantically related to the "
        "question and would provide bad or irrelevant context.",
        "For grouped or row-level results, include at least one "
        "identifying or descriptive column that lets the user "
        "distinguish rows (for example, a primary key or name "
        "column) unless the question clearly asks for only a "
        "scalar.",
        "Start from the minimal set of columns needed to answer the "
        "question and uniquely identify each row. Add extra "
        "descriptive columns only when the question explicitly "
        "refers to them (e.g. names, titles) or when schema "
        "descriptions indicate they are essential for context.",
        "Questions that compare to an average (e.g. more X than "
        "average, above average) often need a multi-step structure: "
        "use cte_steps to compute the average or aggregate first, "
        "then compare in the main query.",
    ]

    # Worked example of the intent object the LLM is expected to return.
    expected_shape = {
        "intent_status": "ok",
        "tables": ["<table_1>", "<table_2>"],
        "select_cols": [
            {"expr": "<table_1>.<column_1>"},
            {"expr": "COUNT(<table_1>.<column_2>)"},
        ],
        "group_by_cols": ["<table_1>.<column_1>"],
        "order_by_cols": [
            {"expr": "<table_1>.<column_1>", "direction": "asc"},
        ],
        "filters_param": [
            {
                "left_expr": "<table_1>.<column_3>",
                "op": "=",
                "value_type": "string",
                "value": "<value_from_question>",
            }
        ],
        "having_param": [],
        "limit": None,
        "natural_language": (
            "<short direct description of the query>"
        ),
        "cte_steps": [],
    }

    operators = {
        "filter_ops": [
            "=",
            "!=",
            "<",
            "<=",
            ">",
            ">=",
            "like",
            "not like",
            "in",
            "not in",
            "is null",
            "is not null",
        ],
        "having_ops": ["=", "!=", "<", "<=", ">", ">="],
    }

    value_types = {
        "filter": [
            "string",
            "integer",
            "number",
            "date",
            "boolean",
            "null",
            "date_window",
            "date_diff",
        ],
        "having": ["integer", "number"],
    }

    payload: dict[str, object] = {
        "task": task_text,
        "question": question,
        "schema_summary": schema_literal_text,
        "allowed_tables": table_list,
        "naming_conventions": naming_conventions,
        "field_specifications": field_specs,
        "expression_format": expression_format,
        "rules": general_rules,
        "intent_status": intent_status_text,
        "answer_style_guidance": answer_style,
        "output_format": expected_shape,
        "operator_reference": operators,
        "value_type_reference": value_types,
    }

    return system_message, stable_json(payload)
627
+
628
+
629
def _format_repair_loop(system: str, raw: str, question: str, max_retries: int = 3) -> tuple[RuntimeIntent | None, int]:
    """Parse a raw LLM response into a RuntimeIntent, asking the LLM to fix its format on failure.

    The response is parsed once as-is. While parsing yields nothing usable, a
    format-repair prompt built from the latest response is sent to the LLM and
    the reply is parsed again, up to ``max_retries`` times.

    Args:
        system: System prompt for the LLM.
        raw: Raw LLM response string to parse.
        question: Original natural language question.
        max_retries: Maximum number of format-repair LLM attempts on parse failure.

    Returns:
        Tuple of (parsed RuntimeIntent or None, number of LLM calls made during format repair).
    """
    if parsed := parse_intent_response(raw, question):
        return parsed, 0

    repair_calls = 0
    for _attempt in range(max_retries):
        # Each round feeds the most recent (still unparseable) response back to the LLM.
        raw = llm_chat(
            system,
            _build_intent_format_repair_prompt(question, raw, "JSON parse failed"),
            task="intent",
        )
        repair_calls += 1
        if parsed := parse_intent_response(raw, question):
            return parsed, repair_calls
    return None, repair_calls
657
+
658
+
659
def _compute_error_signature_issues(issues: list[IntentIssue]) -> frozenset[tuple[str, str]]:
    """Compute a hashable, order-insensitive signature from IntentIssue objects.

    Duplicate issues collapse into one element, so two lists containing the
    same (category, message) pairs in any order or multiplicity produce equal
    signatures.

    Args:
        issues: List of IntentIssue instances, typically errors.

    Returns:
        Frozenset of (category, message) pairs for set comparison.
    """
    # The elements are 2-tuples, not plain strings; the return annotation
    # reflects that so type checkers see the real element type.
    return frozenset((issue.category, issue.message) for issue in issues)
673
+
674
+
675
def _compute_error_signature_strings(errors: list[str]) -> frozenset[str]:
    """Turn a list of error strings into a hashable, order-insensitive signature.

    Args:
        errors: List of raw error message strings.

    Returns:
        Frozenset of the error strings for set comparison.
    """
    signature = frozenset(errors)
    return signature
689
+
690
+
691
def _detect_oscillation(history: list[frozenset[str]]) -> bool:
    """Report whether the error-signature history shows an AA or ABAB pattern.

    AA means the last two signatures are equal. ABAB means the last four
    signatures alternate between two states.

    Args:
        history: Ordered list of error signatures from successive rounds.

    Returns:
        True if oscillation is detected and repair should terminate.
    """
    rounds = len(history)
    if rounds < 2:
        return False
    # AA: the latest round reproduced the previous round's signature.
    if history[-1] == history[-2]:
        return True
    if rounds < 4:
        return False
    # ABAB: the last four rounds flip between the same two signatures.
    return bool(history[-1] == history[-3] and history[-2] == history[-4])
711
+
712
+
713
+ def _normalize_cte_output_aliases(
714
+ intent: RuntimeIntent,
715
+ schema_graph: SchemaGraph,
716
+ ) -> RuntimeIntent:
717
+ """Replace LLM-provided CTE output column aliases with deterministic aliases derived from ``select_cols``.
718
+
719
+ Builds a mapping from LLM alias to deterministic alias for each CTE and rewrites all downstream references in the main query and later CTE steps so the entire intent is consistent with the canonical names.
720
+ This must run before qualify_cte_output_columns to avoid stale alias lookups.
721
+
722
+ Args:
723
+
724
+ intent: RuntimeIntent whose CTE aliases should be normalized.
725
+ schema_graph: SchemaGraph providing table and column metadata.
726
+
727
+ Returns:
728
+
729
+ RuntimeIntent with updated CTE output aliases and remapped references.
730
+ """
731
+ cte_steps = intent.cte_steps or []
732
+ if not cte_steps:
733
+ return intent
734
+
735
+ alias_map: dict[str, str] = {}
736
+ refreshed_cte_steps = []
737
+ for cte in cte_steps:
738
+ old_oc = list(cte.output_columns or [])
739
+ new_oc = derive_cte_output_columns(cte.select_cols or [])
740
+ ocm = build_cte_output_metadata(
741
+ cte.select_cols or [], new_oc, schema_graph,
742
+ )
743
+ for old_col, new_col in zip(old_oc, new_oc, strict=False):
744
+ if old_col != new_col:
745
+ alias_map[f"{cte.cte_name}.{old_col}"] = (
746
+ f"{cte.cte_name}.{new_col}"
747
+ )
748
+ refreshed_cte_steps.append(
749
+ replace(
750
+ cte,
751
+ output_columns=new_oc,
752
+ output_column_metadata=ocm,
753
+ )
754
+ )
755
+ intent = replace(intent, cte_steps=refreshed_cte_steps)
756
+
757
+ if not alias_map:
758
+ return intent
759
+
760
+ debug(f"[_normalize_cte_output_aliases] remap: {alias_map}")
761
+
762
+ def _remap_alias(s: str) -> str:
763
+ return alias_map.get(s, s)
764
+
765
+ def _remap_expr(expr: NormalizedExpr) -> NormalizedExpr:
766
+ return replace_refs_in_expr(expr, _remap_alias)
767
+
768
+ def _remap_cte_step(cte: RuntimeCteStep) -> RuntimeCteStep:
769
+ return replace(
770
+ cte,
771
+ select_cols=[
772
+ replace(sc, expr=_remap_expr(sc.expr))
773
+ for sc in (cte.select_cols or [])
774
+ ],
775
+ order_by_cols=[
776
+ replace(obc, expr=_remap_expr(obc.expr))
777
+ for obc in (cte.order_by_cols or [])
778
+ ],
779
+ group_by_cols=[_remap_expr(g) for g in (cte.group_by_cols or [])],
780
+ filters_param=[
781
+ replace(
782
+ fp,
783
+ left_expr=_remap_expr(fp.left_expr),
784
+ right_expr=(
785
+ _remap_expr(fp.right_expr) if fp.right_expr else None
786
+ ),
787
+ )
788
+ for fp in (cte.filters_param or [])
789
+ ],
790
+ having_param=[
791
+ replace(
792
+ hp,
793
+ left_expr=_remap_expr(hp.left_expr),
794
+ right_expr=(
795
+ _remap_expr(hp.right_expr) if hp.right_expr else None
796
+ ),
797
+ )
798
+ for hp in (cte.having_param or [])
799
+ ],
800
+ )
801
+
802
+ refreshed_cte_steps = [_remap_cte_step(cte) for cte in intent.cte_steps]
803
+
804
+ intent = replace(
805
+ intent,
806
+ cte_steps=refreshed_cte_steps,
807
+ select_cols=[
808
+ replace(sc, expr=_remap_expr(sc.expr))
809
+ for sc in (intent.select_cols or [])
810
+ ],
811
+ order_by_cols=[
812
+ replace(obc, expr=_remap_expr(obc.expr))
813
+ for obc in (intent.order_by_cols or [])
814
+ ],
815
+ group_by_cols=[
816
+ _remap_expr(g) for g in (intent.group_by_cols or [])
817
+ ],
818
+ filters_param=[
819
+ replace(
820
+ fp,
821
+ left_expr=_remap_expr(fp.left_expr),
822
+ right_expr=(
823
+ _remap_expr(fp.right_expr) if fp.right_expr else None
824
+ ),
825
+ )
826
+ for fp in (intent.filters_param or [])
827
+ ],
828
+ having_param=[
829
+ replace(
830
+ hp,
831
+ left_expr=_remap_expr(hp.left_expr),
832
+ right_expr=(
833
+ _remap_expr(hp.right_expr) if hp.right_expr else None
834
+ ),
835
+ )
836
+ for hp in (intent.having_param or [])
837
+ ],
838
+ )
839
+
840
+ return intent
841
+
842
+
843
def _apply_deterministic_repairs(
    intent: RuntimeIntent,
    schema_graph: SchemaGraph,
    natural_language: str = "",
) -> RuntimeIntent:
    """Apply the full deterministic normalization and repair chain to a RuntimeIntent.

    The steps run in a fixed order, each receiving the result of the previous
    one: count-star and CTE-name normalization, CTE output alias
    normalization and qualification, table-name sanitizing, grain
    enforcement, filter/having normalization, join-condition stripping,
    per-CTE grain and sort normalization, expression simplification, IN /
    date-diff / BETWEEN normalization, automatic filter/having repair for the
    main query and each CTE step, and a final group of filter, select and
    primary-key reference repairs.

    Args:
        intent: RuntimeIntent to repair.
        schema_graph: SchemaGraph providing table and column metadata.
        natural_language: Original user question; forwarded to
            ``resolve_filter_value_case``.

    Returns:
        Repaired RuntimeIntent.
    """
    intent = normalize_count_star(intent)
    intent = normalize_cte_names(intent)
    # Alias normalization has to precede qualify_cte_output_columns
    # (see the _normalize_cte_output_aliases docstring).
    intent = _normalize_cte_output_aliases(intent, schema_graph)
    intent = qualify_cte_output_columns(intent)
    intent = sanitize_table_names(intent, schema_graph)
    intent = force_main_grain_when_using_grouped_cte(intent)
    intent = enforce_grain_consistency(intent, schema_graph)
    intent = strip_spurious_group_by(intent)
    intent = normalize_filters_havings(intent)
    intent = repair_null_equality_filters(intent)
    intent = strip_join_conditions(intent, schema_graph)

    # Per-CTE pass: grain consistency, then sorted select and order-by lists.
    new_cte_steps = []
    for cte in intent.cte_steps or []:
        cte = enforce_cte_grain_consistency(cte)
        cte = replace(cte, select_cols=sort_select_cols(cte.select_cols or []))
        cte = replace(cte, order_by_cols=sort_order_by_cols(cte.order_by_cols or []))
        new_cte_steps.append(cte)
    intent = replace(intent, cte_steps=new_cte_steps)

    # Same sorting for the main query.
    intent = replace(intent, select_cols=sort_select_cols(intent.select_cols or []))
    intent = replace(intent, order_by_cols=sort_order_by_cols(intent.order_by_cols or []))

    intent = simplify_exprs(intent)

    intent = normalize_in_raw_values(intent)

    intent = repair_misclassified_date_diff(intent)

    intent = normalize_date_diff_raw_values(intent)

    intent = decompose_between_params(intent)

    # The main query is repaired with the names of all CTE steps.
    main_cte_names = {c.cte_name for c in intent.cte_steps or []}

    def process(
        fp: list[FilterParam],
        hp: list[HavingParam],
        cte_names: set[str] | None = None,
    ) -> tuple[list[FilterParam], list[HavingParam]]:
        # Thin wrapper around auto_repair_filter_having.
        return auto_repair_filter_having(fp, hp, cte_names=cte_names)

    repaired_fp, repaired_hp = process(
        intent.filters_param or [], intent.having_param or [], cte_names=main_cte_names
    )
    intent = replace(intent, filters_param=repaired_fp, having_param=repaired_hp)
    if intent.cte_steps:
        new_cte_steps = []
        # Each CTE step is repaired with only the names of the steps before it.
        # The same set object is passed on every call and grows as the loop advances.
        preceding: set[str] = set()
        for cte in intent.cte_steps:
            fp, hp = process(
                cte.filters_param or [],
                cte.having_param or [],
                cte_names=preceding,
            )
            new_cte_steps.append(replace(cte, filters_param=fp, having_param=hp))
            preceding.add(cte.cte_name)
        intent = replace(intent, cte_steps=new_cte_steps)

    # Final repairs that run after filters and havings have been settled.
    intent = strip_impossible_having(intent)
    intent = repair_fk_filter_type_mismatch(intent, schema_graph)
    intent = resolve_filter_value_case(intent, schema_graph, natural_language)
    intent = normalize_in_filter_types(intent, schema_graph)
    intent = normalize_boolean_filter_values(intent, schema_graph)
    intent = normalize_null_filter_values(intent)
    intent = expand_fk_select_to_descriptive(intent, schema_graph)
    intent = dedup_contradictory_filters(intent)
    intent = requalify_redundant_pk_references(intent, schema_graph)

    return intent
932
+
933
+
934
+ def _validate_cte_fk_connectivity(
935
+ intent: RuntimeIntent, schema: SchemaGraph,
936
+ ) -> RuntimeIntent:
937
+ """Ensure CTE steps with multiple tables form connected FK subgraphs.
938
+
939
+ For each CTE step referencing two or more tables, verifies that
940
+ every pair of tables has at least one FK join path in the schema.
941
+ Disconnected tables are removed from the CTE step and logged.
942
+
943
+ Args:
944
+ intent: The runtime intent with CTE steps to validate.
945
+ schema: The schema graph with ``join_paths_multi``.
946
+
947
+ Returns:
948
+ Updated intent with CTE table lists pruned of unreachable
949
+ tables.
950
+ """
951
+ cte_steps = intent.cte_steps or []
952
+ if not cte_steps:
953
+ return intent
954
+
955
+ updated_steps = []
956
+ changed = False
957
+ for cte in cte_steps:
958
+ tables = cte.tables or []
959
+ if len(tables) < 2:
960
+ updated_steps.append(cte)
961
+ continue
962
+ connected = _fk_connected_tables(tables, schema)
963
+ if connected == set(tables):
964
+ updated_steps.append(cte)
965
+ continue
966
+ pruned = sorted(connected)
967
+ debug(
968
+ f"[intent_process._validate_cte_fk_connectivity] CTE '{cte.cte_name}': "
969
+ f"pruned {set(tables) - connected} (no FK path)"
970
+ )
971
+ updated_steps.append(replace(cte, tables=pruned))
972
+ changed = True
973
+
974
+ if changed:
975
+ intent = replace(intent, cte_steps=updated_steps)
976
+ return intent
977
+
978
+
979
+ def _fk_connected_tables(
980
+ tables: list[str], schema: SchemaGraph,
981
+ ) -> set[str]:
982
+ """Find the largest FK-connected subset starting from the first table.
983
+
984
+ Uses the schema's ``join_paths_multi`` to determine which tables
985
+ are reachable from the first table through FK edges.
986
+
987
+ Args:
988
+ tables: Table names to check connectivity for.
989
+ schema: The schema graph.
990
+
991
+ Returns:
992
+ Set of table names reachable from ``tables[0]`` via FK paths.
993
+ """
994
+ if not tables:
995
+ return set()
996
+ visited: set[str] = {tables[0]}
997
+ queue = [tables[0]]
998
+ target_set = set(tables)
999
+ while queue:
1000
+ current = queue.pop(0)
1001
+ for other in target_set - visited:
1002
+ paths = schema.join_paths_multi.get(current, {}).get(other, [])
1003
+ reverse_paths = schema.join_paths_multi.get(other, {}).get(current, [])
1004
+ if paths or reverse_paths:
1005
+ visited.add(other)
1006
+ queue.append(other)
1007
+ return visited
1008
+
1009
+
1010
def _apply_post_processing(intent: RuntimeIntent, schema_graph: SchemaGraph, question: str) -> RuntimeIntent | None:
    """Finalize an intent after the semantic loop: resolve columns, wire CTEs, and extract parameters.

    Steps, in order: build the main-query column map; resolve CTE column
    maps; recompute each CTE's output columns and metadata and remap renamed
    aliases in the main query; qualify CTE output columns and rewrite CTE
    output references to aliases; assign parameter keys; prune unreferenced
    tables, sanitize table names, validate CTE FK connectivity, and normalize
    primary-key DISTINCT; tag numeric expressions; collect raw parameter
    values and verify that every expected key has one; distribute values to
    the CTE steps; apply scalar-function defaults; and extract structural
    parameters.

    Args:
        intent: RuntimeIntent after the semantic loop.
        schema_graph: SchemaGraph providing table and column metadata.
        question: Original natural language question. NOTE(review): not read
            anywhere in this function body.

    Returns:
        Processed RuntimeIntent, or None when a parameter key that requires a
        value has no entry in the collected parameter values.
    """
    # Gather column references from select, order-by, group-by and filter expressions.
    # NOTE(review): having_param expressions are not collected here — confirm
    # that is intentional.
    all_cols = []
    for sc in intent.select_cols or []:
        all_cols.extend(extract_columns_from_expr(sc.expr))
    for obc in intent.order_by_cols or []:
        all_cols.extend(extract_columns_from_expr(obc.expr))
    for g in intent.group_by_cols or []:
        all_cols.extend(extract_columns_from_expr(g))
    for fp in intent.filters_param or []:
        all_cols.extend(extract_columns_from_expr(fp.left_expr))
        if fp.right_expr:
            all_cols.extend(extract_columns_from_expr(fp.right_expr))

    column_map = resolve_column_map(all_cols, schema_graph, intent.tables or [])
    intent = replace(intent, column_map=column_map)

    if intent.cte_steps:
        intent = replace(intent, cte_steps=resolve_cte_column_maps(intent.cte_steps))

    if intent.cte_steps:
        debug(f"[_apply_post_processing] CTE chain: {[c.cte_name for c in intent.cte_steps]}")
        # Maps "<cte>.<old alias>" to "<cte>.<derived alias>" for aliases that changed.
        alias_map: dict[str, str] = {}
        refreshed_cte_steps = []
        for cte in intent.cte_steps:
            old_oc = list(cte.output_columns or [])
            new_oc = derive_cte_output_columns(cte.select_cols or [])
            ocm = build_cte_output_metadata(cte.select_cols or [], new_oc, schema_graph)
            debug(
                f"[_apply_post_processing] CTE '{cte.cte_name}' "
                f"tables={cte.tables} grain={cte.grain} "
                f"old_oc={old_oc} new_oc={new_oc}"
            )
            # Old and new aliases are paired by position.
            for old_col, new_col in zip(old_oc, new_oc, strict=False):
                if old_col != new_col:
                    alias_map[f"{cte.cte_name}.{old_col}"] = f"{cte.cte_name}.{new_col}"
            refreshed_cte_steps.append(replace(cte, output_columns=new_oc, output_column_metadata=ocm))
        intent = replace(intent, cte_steps=refreshed_cte_steps)
        if alias_map:
            debug(f"[_apply_post_processing] CTE alias remap: {alias_map}")

            def _remap_alias(s: str) -> str:
                return alias_map.get(s, s)

            def _remap_expr(expr: NormalizedExpr) -> NormalizedExpr:
                return replace_refs_in_expr(expr, _remap_alias)

            # Only the main query's expressions are rewritten here; the CTE
            # steps' own expressions are left as they are.
            intent = replace(
                intent,
                select_cols=[replace(sc, expr=_remap_expr(sc.expr)) for sc in (intent.select_cols or [])],
                order_by_cols=[replace(obc, expr=_remap_expr(obc.expr)) for obc in (intent.order_by_cols or [])],
                group_by_cols=[_remap_expr(g) for g in (intent.group_by_cols or [])],
                filters_param=[
                    replace(
                        fp,
                        left_expr=_remap_expr(fp.left_expr),
                        right_expr=(_remap_expr(fp.right_expr) if fp.right_expr else None),
                    )
                    for fp in (intent.filters_param or [])
                ],
                having_param=[
                    replace(
                        hp,
                        left_expr=_remap_expr(hp.left_expr),
                        right_expr=(_remap_expr(hp.right_expr) if hp.right_expr else None),
                    )
                    for hp in (intent.having_param or [])
                ],
            )

    if intent.cte_steps:
        intent = qualify_cte_output_columns(intent)

    intent = rewrite_cte_output_refs_to_aliases(intent)

    # The fourth element returned by assign_param_keys is not used here.
    filters_param, having_param, cte_steps, _ = assign_param_keys(
        intent.filters_param or [],
        intent.having_param or [],
        intent.cte_steps,
    )
    intent = replace(
        intent,
        filters_param=filters_param,
        having_param=having_param,
        cte_steps=cte_steps,
    )

    intent = prune_unreferenced_tables(intent, schema_graph=schema_graph)
    intent = sanitize_table_names(intent, schema_graph)
    intent = _validate_cte_fk_connectivity(intent, schema_graph)
    intent = normalize_pk_distinct(intent, schema_graph)

    intent = tag_expr_numeric(intent, schema_graph)

    all_pv = collect_raw_param_values(intent)

    # A key needs a value when the condition has no right-hand expression;
    # for filters, the null-check operators are also exempt.
    expected_keys: list[str] = []
    for cte in intent.cte_steps or []:
        for fp in cte.filters_param or []:
            if fp.param_key and fp.op not in ("is null", "is not null") and not fp.right_expr:
                expected_keys.append(fp.param_key)
        for hp in cte.having_param or []:
            if hp.param_key and not hp.right_expr:
                expected_keys.append(hp.param_key)
    for fp in intent.filters_param or []:
        if fp.param_key and fp.op not in ("is null", "is not null") and not fp.right_expr:
            expected_keys.append(fp.param_key)
    for hp in intent.having_param or []:
        if hp.param_key and not hp.right_expr:
            expected_keys.append(hp.param_key)
    missing_keys = [k for k in expected_keys if k not in all_pv]
    if missing_keys:
        debug(f"[intent_process.apply_post_processing] missing_param_values — auto-terminating: {missing_keys}")
        return None

    # The intent carries the full value map; each CTE step gets only its own keys.
    intent = replace(intent, param_values=all_pv)

    if intent.cte_steps:
        new_cte_steps = []
        for cte in intent.cte_steps:
            cte_pks: set[str] = set()
            for fp in cte.filters_param or []:
                if fp.param_key:
                    cte_pks.add(fp.param_key)
            for hp in cte.having_param or []:
                if hp.param_key:
                    cte_pks.add(hp.param_key)
            cte_pv = {k: v for k, v in all_pv.items() if k in cte_pks}
            new_cte_steps.append(replace(cte, param_values=cte_pv))
        intent = replace(intent, cte_steps=new_cte_steps)

    intent = ensure_scalar_func_defaults(intent)
    intent = extract_structural_params(intent)

    return intent
1158
+
1159
+
1160
+ def _attempt_fresh_restart(
1161
+ question: str,
1162
+ system: str,
1163
+ schema_literal_text: str,
1164
+ table_list: list[str],
1165
+ failure_hints: list[str],
1166
+ max_retries: int,
1167
+ schema_graph: SchemaGraph,
1168
+ llm_calls: int,
1169
+ ) -> tuple[RuntimeIntent | None, list[str], int]:
1170
+ """Make one fresh LLM call with accumulated failure hints.
1171
+
1172
+ Called when the normal repair loops have been exhausted or oscillation was detected and builds an enriched prompt that includes the previous failure messages so the LLM can avoid repeating the same mistakes.
1173
+
1174
+ Args:
1175
+
1176
+ question: Original natural language question.
1177
+ system: System prompt for the LLM.
1178
+ schema_literal_text: Human-readable schema text.
1179
+ table_list: List of table names from the schema.
1180
+ failure_hints: Accumulated error messages from previous rounds.
1181
+ max_retries: Maximum number of format-repair retries.
1182
+ schema_graph: SchemaGraph for deterministic repairs and post-processing.
1183
+ llm_calls: Running count of LLM calls made so far.
1184
+
1185
+ Returns:
1186
+
1187
+ Tuple of (RuntimeIntent or None, semantic_warnings, total_llm_calls).
1188
+ """
1189
+ if not failure_hints:
1190
+ return None, [], llm_calls
1191
+
1192
+ debug(f"[intent_process._attempt_fresh_restart] attempting fresh restart with {len(failure_hints)} failure hints")
1193
+
1194
+ deduped_hints = list(dict.fromkeys(failure_hints))[:10]
1195
+
1196
+ _, fresh_user = _build_intent_parse_prompt(question, schema_literal_text, table_list)
1197
+ hint_block = "\n".join(f"- {h}" for h in deduped_hints)
1198
+ augmented_user = fresh_user + "\n\nPrevious attempts failed with these issues — avoid them:\n" + hint_block
1199
+
1200
+ raw = llm_chat(system, augmented_user, task="intent")
1201
+ llm_calls += 1
1202
+ intent, fmt_calls = _format_repair_loop(system, raw, question, max_retries)
1203
+ llm_calls += fmt_calls
1204
+
1205
+ if not intent:
1206
+ debug("[intent_process._attempt_fresh_restart] fresh restart format repair failed")
1207
+ return None, [], llm_calls
1208
+
1209
+ intent = _apply_deterministic_repairs(intent, schema_graph, question)
1210
+ intent, schema_errors = enforce_schema(intent, schema_graph)
1211
+ if schema_errors:
1212
+ debug(
1213
+ f"[intent_process._attempt_fresh_restart] fresh restart has {len(schema_errors)} schema errors — giving up"
1214
+ )
1215
+ return None, [], llm_calls
1216
+
1217
+ validation_result = validate_semantics(intent, schema_graph)
1218
+ errors = [iss for iss in validation_result.issues if iss.severity == "error"]
1219
+ warnings = [iss for iss in validation_result.issues if iss.severity == "warning"]
1220
+
1221
+ if errors:
1222
+ debug(f"[intent_process._attempt_fresh_restart] fresh restart has {len(errors)} semantic errors — giving up")
1223
+ return None, [], llm_calls
1224
+
1225
+ result = _apply_post_processing(intent, schema_graph, question)
1226
+ if result is None:
1227
+ return None, [], llm_calls
1228
+
1229
+ semantic_warnings = [w.message for w in warnings]
1230
+ debug(
1231
+ f"[intent_process._attempt_fresh_restart] fresh restart succeeded "
1232
+ f"with {len(result.tables or [])} tables, {llm_calls} total LLM calls"
1233
+ )
1234
+ return result, semantic_warnings, llm_calls
1235
+
1236
+
1237
def full_intent_parse(
    question: str, schema_graph: SchemaGraph, max_retries: int = 3
) -> tuple[RuntimeIntent | None, list[str], int]:
    """Parse a natural language question into a RuntimeIntent.

    Runs the LLM intent-parse prompt and then applies interleaved format and semantic repair loops.
    Schema errors are handled in a nested sub-loop that does not consume semantic repair rounds, and
    oscillation detection terminates repair early when the same error set repeats or alternates.
    When repair rounds are exhausted, oscillation is detected, or a repair response cannot be parsed,
    a single fresh-restart LLM call is attempted with the accumulated failure hints before giving up.

    Args:
        question: Natural language question to parse.
        schema_graph: SchemaGraph providing table and column metadata and schema text.
        max_retries: Maximum number of format-repair LLM attempts on JSON parse failure.

    Returns:
        Tuple of (RuntimeIntent, semantic_warnings, llm_call_count) where semantic_warnings is a list
        of non-fatal warning strings and llm_call_count is the total number of LLM API calls made;
        RuntimeIntent is None when parsing fails.
    """
    max_schema_sub_rounds = 3
    table_list = list(schema_graph.tables.keys())
    schema_literal_text = schema_graph.schema_literal_text
    llm_calls = 0

    system, user = _build_intent_parse_prompt(question, schema_literal_text, table_list)

    raw = llm_chat(system, user, task="intent")
    llm_calls += 1
    debug(f"[intent_process.full_intent_parse] raw_llm_response: {raw}")
    intent, fmt_calls = _format_repair_loop(system, raw, question, max_retries)
    llm_calls += fmt_calls

    if not intent:
        # Nothing parseable and no failure hints yet, so a fresh restart has nothing to add.
        debug("[intent_process.full_intent_parse] format repair exhausted on initial parse — terminating")
        return None, [], llm_calls

    debug(
        f"[intent_process.full_intent_parse] normalized intent after initial parse:\n"
        f"{stable_json(intent.to_dict())}"
    )

    semantic_warnings: list[str] = []
    seen_warning_ids: set[str] = set()
    semantic_error_history: list[frozenset[str]] = []
    # Error messages handed to the fresh-restart prompt if the repair loops fail.
    accumulated_failure_hints: list[str] = []

    for sem_round in range(PolicyConfig.MAX_REPAIR_LOOPS):
        debug(f"[intent_process.full_intent_parse] semantic round {sem_round + 1}/{PolicyConfig.MAX_REPAIR_LOOPS}")

        intent = _apply_deterministic_repairs(intent, schema_graph, question)
        debug(
            f"[intent_process.full_intent_parse] full intent after deterministic repairs:\n"
            f"{stable_json(intent.to_dict())}"
        )

        # Schema sub-loop: its history is reset for every semantic round.
        schema_error_history: list[frozenset[str]] = []
        schema_resolved = False
        for schema_sub in range(max_schema_sub_rounds):
            intent, schema_errors = enforce_schema(intent, schema_graph)
            if not schema_errors:
                schema_resolved = True
                break
            debug(
                f"[intent_process.full_intent_parse] schema sub-round {schema_sub + 1}/{max_schema_sub_rounds}: "
                f"{len(schema_errors)} errors"
            )
            sig = _compute_error_signature_strings(schema_errors)
            schema_error_history.append(sig)
            if _detect_oscillation(schema_error_history):
                debug("[intent_process.full_intent_parse] schema oscillation detected — breaking sub-loop")
                accumulated_failure_hints.extend(schema_errors)
                break
            if schema_sub >= max_schema_sub_rounds - 1:
                # Last sub-round only checks; no further LLM repair is attempted.
                accumulated_failure_hints.extend(schema_errors)
                break
            schema_issues = [
                IntentIssue(
                    issue_id=f"schema_error_{idx}",
                    category=_classify_schema_error(err),
                    severity="error",
                    message=err,
                )
                for idx, err in enumerate(schema_errors)
            ]
            for iss in schema_issues:
                debug(f"[intent_process.full_intent_parse] issue_id={iss.issue_id} message={iss.message}")
            intent_json = stable_json(intent.to_dict())
            debug(
                f"[intent_process.full_intent_parse] intent being sent to schema repair LLM:\n{intent_json}"
            )
            debug(
                f"[intent_process.full_intent_parse] schema errors_to_fix: "
                f"{[(iss.category, iss.message) for iss in schema_issues]}"
            )
            repair_prompt = _build_intent_semantic_repair_prompt(
                question, intent_json, schema_issues, [], schema_literal_text
            )
            repaired_raw = llm_chat(system, repair_prompt, task="intent")
            llm_calls += 1
            repaired, fmt_calls = _format_repair_loop(system, repaired_raw, question, max_retries)
            llm_calls += fmt_calls
            if not repaired:
                debug("[intent_process.full_intent_parse] format repair exhausted after schema repair — terminating")
                # These schema errors are still unresolved. Record them like the
                # other terminal exits of this sub-loop do, so the fresh restart
                # has hints and is not skipped for lack of them.
                accumulated_failure_hints.extend(schema_errors)
                return _attempt_fresh_restart(
                    question,
                    system,
                    schema_literal_text,
                    table_list,
                    accumulated_failure_hints,
                    max_retries,
                    schema_graph,
                    llm_calls,
                )
            intent = repaired
            debug(
                f"[intent_process.full_intent_parse] normalized intent after schema repair:\n"
                f"{stable_json(intent.to_dict())}"
            )
            intent = _apply_deterministic_repairs(intent, schema_graph, question)

        if not schema_resolved:
            debug("[intent_process.full_intent_parse] schema errors persist after sub-loop — terminating")
            return _attempt_fresh_restart(
                question,
                system,
                schema_literal_text,
                table_list,
                accumulated_failure_hints,
                max_retries,
                schema_graph,
                llm_calls,
            )

        validation_result = validate_semantics(intent, schema_graph)

        errors = [iss for iss in validation_result.issues if iss.severity == "error"]
        warnings = [iss for iss in validation_result.issues if iss.severity == "warning"]

        for iss in validation_result.issues:
            debug(f"[intent_process.full_intent_parse] issue_id={iss.issue_id} message={iss.message}")
        # Warnings are collected across rounds, once per issue_id.
        for w in warnings:
            if w.issue_id not in seen_warning_ids:
                seen_warning_ids.add(w.issue_id)
                semantic_warnings.append(w.message)

        if not errors:
            debug(f"[intent_process.full_intent_parse] no semantic errors in round {sem_round + 1}")
            break

        debug(
            f"[intent_process.full_intent_parse] {len(errors)} errors, {len(warnings)} warnings in round {sem_round + 1}"
        )
        accumulated_failure_hints.extend(iss.message for iss in errors)

        sig = _compute_error_signature_issues(errors)
        semantic_error_history.append(sig)
        if _detect_oscillation(semantic_error_history):
            debug("[intent_process.full_intent_parse] semantic oscillation detected — trying fresh restart")
            return _attempt_fresh_restart(
                question,
                system,
                schema_literal_text,
                table_list,
                accumulated_failure_hints,
                max_retries,
                schema_graph,
                llm_calls,
            )

        if sem_round >= PolicyConfig.MAX_REPAIR_LOOPS - 1:
            debug("[intent_process.full_intent_parse] semantic errors persist after max rounds — trying fresh restart")
            return _attempt_fresh_restart(
                question,
                system,
                schema_literal_text,
                table_list,
                accumulated_failure_hints,
                max_retries,
                schema_graph,
                llm_calls,
            )

        intent_json = stable_json(intent.to_dict())
        debug(
            f"[intent_process.full_intent_parse] intent being sent to semantic repair LLM:\n{intent_json}"
        )
        debug(
            f"[intent_process.full_intent_parse] errors_to_fix: "
            f"{[(e.category, e.message) for e in errors]}"
        )
        repair_prompt = _build_intent_semantic_repair_prompt(
            question, intent_json, errors, warnings, schema_literal_text
        )
        repaired_raw = llm_chat(system, repair_prompt, task="intent")
        llm_calls += 1
        repaired, fmt_calls = _format_repair_loop(system, repaired_raw, question, max_retries)
        llm_calls += fmt_calls

        if not repaired:
            debug(
                "[intent_process.full_intent_parse] format repair exhausted after semantic repair — trying fresh restart"
            )
            return _attempt_fresh_restart(
                question,
                system,
                schema_literal_text,
                table_list,
                accumulated_failure_hints,
                max_retries,
                schema_graph,
                llm_calls,
            )

        # The repaired intent is normalized and re-validated at the top of the next round.
        intent = repaired
        debug(
            f"[intent_process.full_intent_parse] normalized intent after semantic repair:\n"
            f"{stable_json(intent.to_dict())}"
        )

    result = _apply_post_processing(intent, schema_graph, question)
    if result is None:
        return None, [], llm_calls

    debug(
        f"[intent_process.full_intent_parse] parsed intent with {len(result.tables or [])} tables, "
        f"{len(result.filters_param or [])} filters, {llm_calls} LLM calls"
    )

    return result, semantic_warnings, llm_calls
1466
+
1467
+
1468
+ def _compute_filters_similarity(filters1: list[FilterParam], filters2: list[FilterParam]) -> float:
1469
+ """Compute Jaccard similarity between two filter parameter lists.
1470
+
1471
+ Args:
1472
+
1473
+
1474
+ filters1: First list of FilterParam objects.
1475
+ filters2: Second list of FilterParam objects.
1476
+
1477
+ Returns:
1478
+
1479
+
1480
+ Float in the range [0.0, 1.0] with 1.0 when both lists are empty or identical.
1481
+ """
1482
+ if not filters1 and not filters2:
1483
+ return 1.0
1484
+ if not filters1 or not filters2:
1485
+ return 0.0
1486
+ keys1 = {fp.signature_key for fp in filters1}
1487
+ keys2 = {fp.signature_key for fp in filters2}
1488
+ score = _jaccard(keys1, keys2)
1489
+ if score > 0:
1490
+ bops1 = sorted(fp.bool_op for fp in filters1)
1491
+ bops2 = sorted(fp.bool_op for fp in filters2)
1492
+ if bops1 != bops2:
1493
+ score *= 0.9
1494
+ return score
1495
+
1496
+
1497
+ def _compute_having_similarity(having1: list[HavingParam], having2: list[HavingParam]) -> float:
1498
+ """Compute Jaccard similarity between two having parameter lists.
1499
+
1500
+ Args:
1501
+
1502
+
1503
+ having1: First list of HavingParam objects.
1504
+ having2: Second list of HavingParam objects.
1505
+
1506
+ Returns:
1507
+
1508
+
1509
+ Float in the range [0.0, 1.0] with 1.0 when both lists are empty or identical.
1510
+ """
1511
+ if not having1 and not having2:
1512
+ return 1.0
1513
+ if not having1 or not having2:
1514
+ return 0.0
1515
+ keys1 = {hp.signature_key for hp in having1}
1516
+ keys2 = {hp.signature_key for hp in having2}
1517
+ score = _jaccard(keys1, keys2)
1518
+ if score > 0:
1519
+ bops1 = sorted(hp.bool_op for hp in having1)
1520
+ bops2 = sorted(hp.bool_op for hp in having2)
1521
+ if bops1 != bops2:
1522
+ score *= 0.9
1523
+ return score
1524
+
1525
+
1526
+ def _compute_select_cols_similarity(cols1: list[SelectCol], cols2: list[SelectCol]) -> float:
1527
+ """Compute Jaccard similarity between two select column lists.
1528
+
1529
+ Args:
1530
+
1531
+
1532
+ cols1: First list of SelectCol objects.
1533
+ cols2: Second list of SelectCol objects.
1534
+
1535
+ Returns:
1536
+
1537
+
1538
+ Float in the range [0.0, 1.0] with 1.0 when both lists are empty or identical.
1539
+ """
1540
+ if not cols1 and not cols2:
1541
+ return 1.0
1542
+ if not cols1 or not cols2:
1543
+ return 0.0
1544
+ keys1 = {sc.signature_key for sc in cols1}
1545
+ keys2 = {sc.signature_key for sc in cols2}
1546
+ return _jaccard(keys1, keys2)
1547
+
1548
+
1549
+ def _compute_order_by_cols_similarity(cols1: list[OrderByCol], cols2: list[OrderByCol]) -> float:
1550
+ """Compute Jaccard similarity between two order-by column lists.
1551
+
1552
+ Args:
1553
+
1554
+
1555
+ cols1: First list of OrderByCol objects.
1556
+ cols2: Second list of OrderByCol objects.
1557
+
1558
+ Returns:
1559
+
1560
+
1561
+ Float in the range [0.0, 1.0] with 1.0 when both lists are empty or identical.
1562
+ """
1563
+ if not cols1 and not cols2:
1564
+ return 1.0
1565
+ if not cols1 or not cols2:
1566
+ return 0.0
1567
+ keys1 = {obc.signature_key for obc in cols1}
1568
+ keys2 = {obc.signature_key for obc in cols2}
1569
+ return _jaccard(keys1, keys2)
1570
+
1571
+
1572
def _base_similarity(
    tables1: list[str],
    tables2: list[str],
    select1: list[SelectCol],
    select2: list[SelectCol],
    group1: list[NormalizedExpr],
    group2: list[NormalizedExpr],
    order1: list[OrderByCol],
    order2: list[OrderByCol],
    filters1: list[FilterParam],
    filters2: list[FilterParam],
    having1: list[HavingParam],
    having2: list[HavingParam],
) -> float:
    """Combine six per-clause similarity scores with equal weights.

    Tables, select columns, group-by expressions, order-by columns, filters
    and having conditions are each scored separately and contribute one sixth
    of the total.

    Returns:
        The weighted sum of the six scores.
    """
    weight = 1 / 6
    by_tables = _jaccard(set(tables1), set(tables2))
    by_select = _compute_select_cols_similarity(select1, select2)
    by_group = _jaccard(
        {expr.signature_key for expr in group1},
        {expr.signature_key for expr in group2},
    )
    by_order = _compute_order_by_cols_similarity(order1, order2)
    by_filters = _compute_filters_similarity(filters1, filters2)
    by_having = _compute_having_similarity(having1, having2)
    # Explicit left-to-right sum: keeps the floating-point result stable.
    return (
        weight * by_tables
        + weight * by_select
        + weight * by_group
        + weight * by_order
        + weight * by_filters
        + weight * by_having
    )
1609
+
1610
+
1611
def _cte_step_similarity(cte1: RuntimeCteStep, cte2: RuntimeCteStep) -> float:
    """Score the structural similarity of two CTE steps.

    Delegates to ``_base_similarity``, substituting an empty list for any
    clause attribute that is falsy (for example ``None``).

    Args:
        cte1: First CTE step.
        cte2: Second CTE step.

    Returns:
        Float in the range [0.0, 1.0].
    """
    return _base_similarity(
        tables1=cte1.tables or [],
        tables2=cte2.tables or [],
        select1=cte1.select_cols or [],
        select2=cte2.select_cols or [],
        group1=cte1.group_by_cols or [],
        group2=cte2.group_by_cols or [],
        order1=cte1.order_by_cols or [],
        order2=cte2.order_by_cols or [],
        filters1=cte1.filters_param or [],
        filters2=cte2.filters_param or [],
        having1=cte1.having_param or [],
        having2=cte2.having_param or [],
    )
1639
+
1640
+
1641
def intent_similarity(intent1: RuntimeIntent | ConcreteIntent, intent2: RuntimeIntent | ConcreteIntent) -> float:
    """Compute the overall similarity of two intents, weighting CTE steps dynamically.

    Without CTE steps on either side the result is the main-query
    similarity. Otherwise the main query is weighted 0.7 (one step),
    0.6 (two steps) or 0.4 (three or more), where the step count is the
    larger of the two intents' counts. The remaining weight is divided
    equally over that many step positions. Steps are paired by position,
    and a position present in only one intent contributes nothing.

    Args:
        intent1: First intent, either a RuntimeIntent or a ConcreteIntent.
        intent2: Second intent.

    Returns:
        Float in the range [0.0, 1.0].
    """
    main_score = _base_similarity(
        tables1=intent1.tables or [],
        tables2=intent2.tables or [],
        select1=intent1.select_cols or [],
        select2=intent2.select_cols or [],
        group1=intent1.group_by_cols or [],
        group2=intent2.group_by_cols or [],
        order1=intent1.order_by_cols or [],
        order2=intent2.order_by_cols or [],
        filters1=intent1.filters_param or [],
        filters2=intent2.filters_param or [],
        having1=intent1.having_param or [],
        having2=intent2.having_param or [],
    )

    left_steps = intent1.cte_steps or []
    right_steps = intent2.cte_steps or []
    step_count = max(len(left_steps), len(right_steps))
    if not step_count:
        return main_score

    # The more CTE steps there are, the less the main query dominates.
    if step_count == 1:
        main_weight = 0.7
    elif step_count == 2:
        main_weight = 0.6
    else:
        main_weight = 0.4
    per_step_weight = (1.0 - main_weight) / step_count

    # zip stops at the shorter list, so unpaired steps add no score.
    steps_score = 0.0
    for left_step, right_step in zip(left_steps, right_steps):
        steps_score += per_step_weight * _cte_step_similarity(left_step, right_step)
    return main_weight * main_score + steps_score
1684
+
1685
+
1686
+ def _jaccard(set1: set[str], set2: set[str]) -> float:
1687
+ """Compute Jaccard similarity between two string sets.
1688
+
1689
+ Args:
1690
+
1691
+
1692
+ set1: First set of strings.
1693
+ set2: Second set of strings.
1694
+
1695
+ Returns:
1696
+
1697
+
1698
+ Float in the range [0.0, 1.0] with 1.0 when both sets are empty.
1699
+ """
1700
+ if not set1 and not set2:
1701
+ return 1.0
1702
+ if not set1 or not set2:
1703
+ return 0.0
1704
+ intersection = set1 & set2
1705
+ union = set1 | set2
1706
+ return len(intersection) / len(union) if union else 1.0
1707
+
1708
+
1709
def intent_approval(intent: SimulatorIntent, runtime_intent: RuntimeIntent) -> tuple[float, list[str]]:
    """Check how closely a SimulatorIntent matches a RuntimeIntent.

    Collects human-readable differences for tables, select columns,
    group-by, order-by, filters, having, and the CTE step count. The score
    is computed independently of those differences: ``intent`` is converted
    with ``to_runtime_intent()`` and compared via ``intent_similarity``.

    Args:
        intent: SimulatorIntent to evaluate.
        runtime_intent: RuntimeIntent to compare against.

    Returns:
        Tuple of (score, diffs) where score is the ``intent_similarity``
        result and diffs is a list of difference descriptions, in the
        order the clauses are checked.
    """
    diffs: list[str] = []

    def _sig_keys(items) -> list:
        return sorted(item.signature_key for item in items)

    def _compare_columns(label: str, src_items, tgt_items) -> None:
        # Signatures are only compared when the counts already agree.
        if len(src_items) != len(tgt_items):
            diffs.append(f"{label}_count: {len(src_items)} vs {len(tgt_items)}")
            return
        src_keys, tgt_keys = _sig_keys(src_items), _sig_keys(tgt_items)
        if src_keys != tgt_keys:
            diffs.append(f"{label}_sigs: {src_keys} vs {tgt_keys}")

    def _compare_conditions(label: str, src_items, tgt_items) -> None:
        # Three tiers: count, then signatures, then boolean operator/grouping.
        if len(src_items) != len(tgt_items):
            diffs.append(f"{label}_count: {len(src_items)} vs {len(tgt_items)}")
            return
        src_keys, tgt_keys = _sig_keys(src_items), _sig_keys(tgt_items)
        if src_keys != tgt_keys:
            diffs.append(f"{label}_sigs: {src_keys} vs {tgt_keys}")
            return
        src_ops = sorted((item.signature_key, item.bool_op, item.filter_group) for item in src_items)
        tgt_ops = sorted((item.signature_key, item.bool_op, item.filter_group) for item in tgt_items)
        if src_ops != tgt_ops:
            diffs.append(f"{label}_bool_op: {src_ops} vs {tgt_ops}")

    src_tables = sorted(intent.tables or [])
    tgt_tables = sorted(runtime_intent.tables or [])
    if src_tables != tgt_tables:
        diffs.append(f"tables: {src_tables} vs {tgt_tables}")

    _compare_columns("select_cols", intent.select_cols or [], runtime_intent.select_cols or [])

    src_group = _sig_keys(intent.group_by_cols or [])
    tgt_group = _sig_keys(runtime_intent.group_by_cols or [])
    if src_group != tgt_group:
        diffs.append(f"group_by_cols: {src_group} vs {tgt_group}")

    _compare_columns("order_by_cols", intent.order_by_cols or [], runtime_intent.order_by_cols or [])
    _compare_conditions("filters", intent.filters_param or [], runtime_intent.filters_param or [])
    _compare_conditions("having", intent.having_param or [], runtime_intent.having_param or [])

    src_cte_count = len(intent.cte_steps or [])
    tgt_cte_count = len(runtime_intent.cte_steps or [])
    if src_cte_count != tgt_cte_count:
        diffs.append(f"cte_steps_count: {src_cte_count} vs {tgt_cte_count}")

    score = intent_similarity(intent.to_runtime_intent(), runtime_intent)
    return score, diffs
1783
+
1784
+
1785
def find_trusted_template_match(
    question: str, templates: list[Template]
) -> tuple[RuntimeIntent | None, Template] | None:
    """Check whether the LLM intent parse can be skipped via a trusted template match.

    Only templates with ``trust_level`` equal to 2 are considered. The skip
    path is taken as soon as ``exact_question_match`` accepts the question
    against any historical question stored on such a template.

    Args:
        question: Normalized question to look up.
        templates: List of Template objects to search.

    Returns:
        ``(None, matching_template)`` for the first match. The first
        element is always ``None`` here; no intent is built by this
        function. Returns ``None`` when nothing matches and the LLM must
        be called.
    """
    if not templates:
        return None

    for tpl in templates:
        if tpl.trust_level != 2:
            debug(f"[intent_process.find_trusted_template_match] {tpl.id} skipped (trust={tpl.trust_level})")
            continue

        # any() stops at the first accepted historical question.
        if any(
            exact_question_match(question, hist_q or "", label=tpl.id)
            for hist_q in tpl.value_history.questions
        ):
            debug(f"[intent_process.find_trusted_template_match] fuzzy match with template {tpl.id}")
            return None, tpl

    return None
1815
+
1816
+
1817
def llm_ux_explain(intent: RuntimeIntent, question: str) -> str:
    """Generate a short user-friendly explanation of an intent for disambiguation.

    Builds a compact description of the intent (tables, select terms,
    filters without their values, group-by) and asks the LLM for a
    one-phrase summary.

    Args:
        intent: Parsed RuntimeIntent to summarize.
        question: Original question; not used by this function.

    Returns:
        The stripped ``summary`` string when the LLM reply parses to a
        dict carrying one, otherwise the stripped raw reply. The prompt
        asks for a phrase under 10 words, but the length is not enforced.
    """
    system = "You are a helpful assistant that explains SQL queries in plain English. Output JSON with a single 'summary' key."

    tables = ", ".join(intent.tables or [])

    cols_desc = [sc.expr.primary_term or "*" for sc in intent.select_cols or []]
    cols_str = ", ".join(cols_desc) if cols_desc else "*"

    filters_desc = []
    all_filters = intent.filters_param or []
    last_index = len(all_filters) - 1
    for i, fp in enumerate(all_filters):
        left = fp.left_expr.primary_column or fp.left_expr.primary_term
        if fp.right_expr:
            right = fp.right_expr.primary_column or fp.right_expr.primary_term
        else:
            # No right-hand expression: the literal is deliberately hidden.
            right = "?"
        filters_desc.append(f"{left} {fp.op} {right}")
        # The connective joins this filter to the next, so the last one has none.
        if i < last_index:
            filters_desc.append(fp.bool_op)
    filters_str = " ".join(filters_desc) if filters_desc else "none"

    # Use the same primary_column -> primary_term fallback as the filter
    # description above, so an expression group-by without a primary column
    # gets a label instead of an empty or non-string entry.
    # NOTE(review): assumes group-by expressions expose ``primary_term`` like
    # the filter and select expressions do - confirm against NormalizedExpr.
    group_labels = []
    for g in intent.group_by_cols or []:
        label = g.primary_column or g.primary_term
        if label:
            group_labels.append(label)

    user = stable_json(
        {
            "task": "Summarize what this query does in one short phrase (under 10 words). Do not mention specific filter values or literals. Return JSON: {\"summary\": \"<phrase>\"}.",
            "tables": tables,
            "columns": cols_str,
            "filters": filters_str,
            "group_by": ", ".join(group_labels) or "none",
        }
    )

    raw = llm_chat(system, user, task="intent")
    parsed = safe_json_loads(raw)
    if isinstance(parsed, dict) and isinstance(parsed.get("summary"), str):
        return parsed["summary"].strip()
    return raw.strip()
1869
+
1870
+
1871
+ MAX_NON_AGG_COL_DIFF = 2  # Largest non-aggregate select-column difference match_template_for_union still accepts.
1872
+
1873
+
1874
+ def _cte_structural_signature(steps: list) -> list[tuple[str, str]]:
1875
+ """Build sorted structural fingerprints for CTE steps.
1876
+
1877
+ Each CTE step produces a ``(cte_name, body_signature)`` tuple where the
1878
+ body signature covers filters, group-by, order-by, having, and grain but
1879
+ deliberately excludes select columns so that column-only deltas are
1880
+ handled by the union matching logic.
1881
+
1882
+ Args:
1883
+ steps: CTE step objects (``RuntimeCteStep`` or ``ConcreteCteStep``).
1884
+
1885
+ Returns:
1886
+ Sorted list of ``(cte_name, body_signature)`` tuples.
1887
+ """
1888
+ sigs: list[tuple[str, str]] = []
1889
+ for cte in steps:
1890
+ parts: list[str] = [
1891
+ cte.grain or "row_level",
1892
+ ",".join(sorted(f.signature_key for f in (cte.filters_param or []))),
1893
+ ",".join(sorted(g.signature_key for g in (cte.group_by_cols or []))),
1894
+ ",".join(sorted(o.signature_key for o in (cte.order_by_cols or []))),
1895
+ ",".join(sorted(h.signature_key for h in (cte.having_param or []))),
1896
+ ]
1897
+ sigs.append((cte.cte_name, "|".join(parts)))
1898
+ return sorted(sigs, key=lambda t: t[0])
1899
+
1900
+
1901
+ def _structural_body_matches(
1902
+ intent: RuntimeIntent,
1903
+ concrete: ConcreteIntent,
1904
+ ) -> bool:
1905
+ """Check whether every non-select structural element matches exactly.
1906
+
1907
+ Compares grain, limit, filters, group-by, order-by, having, and CTE body
1908
+ signatures between the incoming runtime intent and a stored concrete
1909
+ intent. Select columns and tables are deliberately excluded because they
1910
+ are handled by the union diff logic.
1911
+
1912
+ Args:
1913
+ intent: The newly parsed runtime intent.
1914
+ concrete: The stored concrete intent from an existing template.
1915
+
1916
+ Returns:
1917
+ ``True`` when all structural clauses match.
1918
+ """
1919
+ if (intent.grain or "row_level") != (concrete.grain or "row_level"):
1920
+ return False
1921
+ if intent.limit != concrete.limit:
1922
+ return False
1923
+
1924
+ i_filters = sorted(f.signature_key for f in (intent.filters_param or []))
1925
+ c_filters = sorted(f.signature_key for f in (concrete.filters_param or []))
1926
+ if i_filters != c_filters:
1927
+ return False
1928
+
1929
+ i_gb = sorted(g.signature_key for g in (intent.group_by_cols or []))
1930
+ c_gb = sorted(g.signature_key for g in (concrete.group_by_cols or []))
1931
+ if i_gb != c_gb:
1932
+ return False
1933
+
1934
+ i_ob = sorted(o.signature_key for o in (intent.order_by_cols or []))
1935
+ c_ob = sorted(o.signature_key for o in (concrete.order_by_cols or []))
1936
+ if i_ob != c_ob:
1937
+ return False
1938
+
1939
+ i_hav = sorted(h.signature_key for h in (intent.having_param or []))
1940
+ c_hav = sorted(h.signature_key for h in (concrete.having_param or []))
1941
+ if i_hav != c_hav:
1942
+ return False
1943
+
1944
+ if _cte_structural_signature(intent.cte_steps or []) != _cte_structural_signature(
1945
+ concrete.cte_steps or []
1946
+ ):
1947
+ return False
1948
+
1949
+ return True
1950
+
1951
+
1952
+ def _select_col_diff(
1953
+ intent_cols: list[SelectCol],
1954
+ concrete_cols: list[SelectCol],
1955
+ ) -> tuple[bool, int]:
1956
+ """Compute the select-column delta between two intents.
1957
+
1958
+ Aggregated columns must match exactly; non-aggregated columns may differ
1959
+ by up to ``MAX_NON_AGG_COL_DIFF``.
1960
+
1961
+ Args:
1962
+ intent_cols: Select columns from the runtime intent.
1963
+ concrete_cols: Select columns from the stored concrete intent.
1964
+
1965
+ Returns:
1966
+ ``(agg_match, non_agg_diff)`` — whether aggregate columns are
1967
+ identical and the count of non-aggregate column key differences.
1968
+ """
1969
+ i_agg = sorted(s.signature_key for s in intent_cols if s.is_aggregated)
1970
+ c_agg = sorted(s.signature_key for s in concrete_cols if s.is_aggregated)
1971
+ agg_match = i_agg == c_agg
1972
+
1973
+ i_non = set(s.signature_key for s in intent_cols if not s.is_aggregated)
1974
+ c_non = set(s.signature_key for s in concrete_cols if not s.is_aggregated)
1975
+ non_agg_diff = len(i_non.symmetric_difference(c_non))
1976
+
1977
+ return agg_match, non_agg_diff
1978
+
1979
+
1980
+ def _diff_cols_span_disjoint_tables(
1981
+ intent_cols: list[SelectCol],
1982
+ concrete_cols: list[SelectCol],
1983
+ intent_tables: list[str],
1984
+ concrete_tables: list[str],
1985
+ ) -> bool:
1986
+ """Return True when any differing non-agg column references a
1987
+ table outside the intersection of both table lists.
1988
+
1989
+ Prevents a union match from absorbing columns that would
1990
+ introduce joins the template cannot satisfy, or from inheriting
1991
+ stale template columns that reference tables the intent does not
1992
+ need.
1993
+
1994
+ Args:
1995
+ intent_cols: Select columns from the runtime intent.
1996
+ concrete_cols: Select columns from the stored concrete intent.
1997
+ intent_tables: Table list from the runtime intent.
1998
+ concrete_tables: Table list from the stored concrete intent.
1999
+
2000
+ Returns:
2001
+ ``True`` if any diff column belongs to a table not shared by
2002
+ both intents.
2003
+ """
2004
+ shared_tables = set(intent_tables or []) & set(concrete_tables or [])
2005
+ i_non = {s.signature_key for s in intent_cols if not s.is_aggregated}
2006
+ c_non = {s.signature_key for s in concrete_cols if not s.is_aggregated}
2007
+ diff_keys = i_non.symmetric_difference(c_non)
2008
+ if not diff_keys:
2009
+ return False
2010
+ all_cols = list(intent_cols) + list(concrete_cols)
2011
+ for sc in all_cols:
2012
+ if sc.is_aggregated or sc.signature_key not in diff_keys:
2013
+ continue
2014
+ term = sc.expr.primary_term
2015
+ tbl = term.split(".")[0] if "." in term else ""
2016
+ if tbl and tbl not in shared_tables:
2017
+ return True
2018
+ return False
2019
+
2020
+
2021
def match_template_for_union(
    intent: RuntimeIntent,
    templates: dict[str, Template],
) -> tuple[Template, list[SelectCol], bool] | None:
    """Find the best template whose structure can absorb the new intent
    via column-only union.

    A template qualifies when all of the following hold:

    * its ``trust_level`` is at least 1;
    * its stored intent passes ``_structural_body_matches`` against
      ``intent``;
    * aggregate select columns match exactly and at most
      ``MAX_NON_AGG_COL_DIFF`` non-aggregate columns differ;
    * no differing non-aggregate column references a table outside the
      intersection of the two table lists, which would introduce
      foreign joins or inherit stale context columns.

    Every qualifying template has its merged select columns computed via
    ``compute_intent_union``. The one with the smallest non-aggregate
    column diff wins, and on a tie the template seen first is kept.

    Args:
        intent: The newly parsed and validated runtime intent.
        templates: All stored templates keyed by template id.

    Returns:
        ``(matched_template, union_select_cols, cols_changed)`` for the
        best match, or ``None`` when no template qualifies.
    """
    winner = None
    winner_cols: list[SelectCol] = []
    winner_changed = False
    lowest_diff = 0

    for candidate in templates.values():
        if candidate.trust_level < 1:
            continue

        stored = candidate.intent_signature
        if not _structural_body_matches(intent, stored):
            continue

        incoming_cols = intent.select_cols or []
        stored_cols = stored.select_cols or []
        agg_match, non_agg_diff = _select_col_diff(incoming_cols, stored_cols)
        if not agg_match or non_agg_diff > MAX_NON_AGG_COL_DIFF:
            continue

        if _diff_cols_span_disjoint_tables(
            incoming_cols,
            stored_cols,
            intent.tables or [],
            stored.tables or [],
        ):
            continue

        union_cols, cols_changed = compute_intent_union(intent, stored)

        # Strict comparison keeps the earliest candidate on equal diffs.
        if winner is None or non_agg_diff < lowest_diff:
            winner = candidate
            winner_cols = union_cols
            winner_changed = cols_changed
            lowest_diff = non_agg_diff

    if winner is None:
        return None

    debug(
        f"[intent_process.match_template_for_union] matched template={winner.id} "
        f"non_agg_diff={lowest_diff} union_cols={len(winner_cols)}"
    )
    return winner, winner_cols, winner_changed
2091
+
2092
+
2093
def compute_intent_union(
    intent: RuntimeIntent,
    concrete: ConcreteIntent,
) -> tuple[list[SelectCol], bool]:
    """Compute the column-only union between two intents.

    Select columns are deduplicated by ``signature_key``. Columns from the
    concrete intent come first in their original order, followed by any
    runtime-intent columns whose key has not been seen yet. The first
    column seen for a key is the one kept.

    Args:
        intent: The newly parsed runtime intent.
        concrete: The stored concrete intent from the matched template.

    Returns:
        ``(union_select_cols, cols_changed)`` where ``cols_changed`` is
        ``True`` when the sorted union keys differ from the sorted keys
        of the concrete intent's own select columns.
    """
    stored_cols = concrete.select_cols or []
    incoming_cols = intent.select_cols or []

    # An insertion-ordered dict gives first-wins dedup and stable ordering.
    merged: dict[str, SelectCol] = {}
    for col in [*stored_cols, *incoming_cols]:
        merged.setdefault(col.signature_key, col)
    union_cols = list(merged.values())

    stored_keys = sorted(col.signature_key for col in stored_cols)
    cols_changed = sorted(merged) != stored_keys

    debug(
        f"[intent_process.compute_intent_union] "
        f"union_cols={len(union_cols)} cols_changed={cols_changed}"
    )
    return union_cols, cols_changed