aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
|
@@ -0,0 +1,1733 @@
|
|
|
1
|
+
"""Structural intent repairs and value normalization.
|
|
2
|
+
|
|
3
|
+
Detects foreign key filter type mismatches, where an integer foreign-key column is compared to a string value, so they can be surfaced for repair, and expands foreign key selects to descriptive columns.
|
|
4
|
+
|
|
5
|
+
Strips spurious GROUP BY clauses, impossible HAVING conditions such as COUNT < 0, and hallucinated SQL keywords in table names.
|
|
6
|
+
|
|
7
|
+
Strips foreign key equi-join conditions from filters, prunes unreferenced tables, and normalizes boolean filter values, IN-list types, and filter value casing against schema statistics and question text.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import re
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from dataclasses import replace
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from .config import (
|
|
18
|
+
BOOLEAN_FALSY_VALUES,
|
|
19
|
+
BOOLEAN_TRUTHY_VALUES,
|
|
20
|
+
DISTINCT_RE,
|
|
21
|
+
IMPOSSIBLE_HAVING_RE,
|
|
22
|
+
NUMERIC_DATA_TYPES,
|
|
23
|
+
NUMERIC_LITERAL_RE,
|
|
24
|
+
RANGE_OPS,
|
|
25
|
+
SQL_KEYWORDS,
|
|
26
|
+
TOP_N_RE,
|
|
27
|
+
PolicyConfig,
|
|
28
|
+
)
|
|
29
|
+
from .contracts_base import ColumnMetadata, SchemaGraph, TableMetadata
|
|
30
|
+
from .contracts_core import (
|
|
31
|
+
FilterParam,
|
|
32
|
+
HavingParam,
|
|
33
|
+
MulGroup,
|
|
34
|
+
NormalizedExpr,
|
|
35
|
+
RuntimeCteStep,
|
|
36
|
+
RuntimeIntent,
|
|
37
|
+
SelectCol,
|
|
38
|
+
)
|
|
39
|
+
from .core_utils import debug
|
|
40
|
+
from .intent_expr import extract_columns_from_expr, replace_refs_in_expr
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _english_plurals(word: str) -> list[str]:
    """Return *word* followed by its most common English plural form.

    The first element is always *word* exactly as given, so callers can
    iterate one list. The second element is a lower-cased plural built by
    simple suffix rules:
    - a consonant followed by ``y`` becomes ``-ies``;
    - words ending in ``s``, ``sh``, ``ch``, ``x`` or ``z`` take ``-es``;
    - everything else takes ``-s``.
    """
    lowered = word.lower()
    # Length guard keeps lowered[-2] in range and leaves two-letter words alone.
    consonant_y = len(lowered) > 2 and lowered.endswith("y") and lowered[-2] not in "aeiou"
    if consonant_y:
        plural = lowered[:-1] + "ies"
    elif lowered.endswith(("s", "sh", "ch", "x", "z")):
        plural = lowered + "es"
    else:
        plural = lowered + "s"
    return [word, plural]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _apply_filters_to_main_and_ctes(
    intent: RuntimeIntent,
    process_fn: Callable[[list[FilterParam]], tuple[list[FilterParam], bool]],
) -> RuntimeIntent:
    """Run *process_fn* over the main filter list and every CTE's filter list.

    ``process_fn`` receives a filter list (``[]`` when the field is unset) and
    returns ``(new_filters, changed)``. When nothing changed anywhere, the
    very same *intent* object is returned. Otherwise a copy is built with
    ``dataclasses.replace`` that carries the new main filters and, only if
    some CTE changed, the rebuilt CTE list.
    """
    main_filters, main_dirty = process_fn(intent.filters_param or [])
    steps = intent.cte_steps
    if not steps:
        if not main_dirty:
            return intent
        return replace(intent, filters_param=main_filters)

    rebuilt_steps = []
    any_cte_dirty = False
    for step in steps:
        step_filters, step_dirty = process_fn(step.filters_param or [])
        if step_dirty:
            any_cte_dirty = True
        rebuilt_steps.append(replace(step, filters_param=step_filters))

    if not (main_dirty or any_cte_dirty):
        return intent
    updated = replace(intent, filters_param=main_filters)
    if any_cte_dirty:
        updated = replace(updated, cte_steps=rebuilt_steps)
    return updated
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _dedup_contradictory_filters_list(
    filters: list[FilterParam],
) -> tuple[list[FilterParam], bool]:
    """Drop range filters on columns that already carry an equality filter.

    If a column is constrained with '=' and also with a range operator from
    ``RANGE_OPS`` (such as '>', '<', '>=', '<='), the range filter is either
    redundant with or contradicts the equality, so it is discarded.

    Returns:
        ``(filters, False)``, with the very same list object, when no column
        has an '=' filter. Otherwise a new list of surviving filters, plus a
        flag telling whether anything was dropped.
    """
    columns = [fp.left_expr.primary_column or "" for fp in filters]
    pinned = {col for fp, col in zip(filters, columns) if fp.op == "=" and col}
    if not pinned:
        return filters, False

    survivors: list[FilterParam] = []
    for fp, col in zip(filters, columns):
        if col in pinned and fp.op in RANGE_OPS:
            debug(f"[intent_repair.dedup_contradictory_filters] dropping {fp.op} on '{col}' that contradicts =")
            continue
        survivors.append(fp)
    return survivors, len(survivors) != len(filters)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def dedup_contradictory_filters(intent: RuntimeIntent) -> RuntimeIntent:
    """Drop range filters that clash with an equality filter on the same column.

    The check runs on the main query's filters and on each CTE step's filters.
    The original intent is returned when nothing was dropped.
    """
    processor = _dedup_contradictory_filters_list
    return _apply_filters_to_main_and_ctes(intent, processor)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _is_null_value(raw_value: Any) -> bool:
    """Tell whether a raw filter value stands for SQL NULL.

    ``None`` qualifies. So does any string that reads ``null`` once
    surrounding whitespace is trimmed and case is ignored.
    """
    if isinstance(raw_value, str):
        return raw_value.strip().lower() == "null"
    return raw_value is None
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def repair_null_equality_filters(intent: RuntimeIntent) -> RuntimeIntent:
    """Convert comparisons against NULL into IS NULL / IS NOT NULL filters.

    A filter ``col = NULL`` becomes ``col IS NULL``; ``col != NULL`` and
    ``col <> NULL`` become ``col IS NOT NULL``. NULL here means a value of
    ``None`` or the string ``"null"`` in any case. Main-query filters and
    every CTE step's filters are processed.

    Args:
        intent: RuntimeIntent to inspect.

    Returns:
        The original intent when nothing needed fixing, otherwise an updated
        copy with corrected null operators.
    """
    fixer = _repair_null_equality_list
    return _apply_filters_to_main_and_ctes(intent, fixer)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _repair_null_equality_list(
    filters: list[FilterParam],
) -> tuple[list[FilterParam], bool]:
    """Turn ``= NULL`` / ``!= NULL`` / ``<> NULL`` filters into IS [NOT] NULL.

    Rewritten filters get ``raw_value=None`` and ``value_type="null"``.
    Returns a new list and a flag that is True when any filter was rewritten.
    """
    null_ops = {"=": "is null", "!=": "is not null", "<>": "is not null"}
    out: list[FilterParam] = []
    touched = False
    for fp in filters:
        target_op = null_ops.get(fp.op)
        if target_op is not None and _is_null_value(fp.raw_value):
            out.append(replace(fp, op=target_op, raw_value=None, value_type="null"))
            touched = True
        else:
            out.append(fp)
    return out, touched
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _infer_cte_output_columns(cte: Any) -> list[str]:
    """Guess a CTE's output column names from its select_cols.

    Used when the LLM leaves ``output_columns`` empty. For each select
    expression with a primary column, the trailing identifier after the last
    dot is taken, stripped and lower-cased. For aggregated columns it is
    prefixed with the lower-cased aggregation function (e.g. ``sum_amount``)
    to avoid ambiguity. Select entries without an expression or primary
    column are skipped, and duplicate names are kept only once, in first-seen
    order.

    Args:
        cte: A RuntimeCteStep with populated select_cols.

    Returns:
        List of bare column-name strings suitable for use as CTE output aliases.
    """
    inferred: list[str] = []
    for select_col in cte.select_cols or []:
        expr = select_col.expr
        qualified = expr.primary_column if expr else ""
        if not qualified:
            continue
        alias = qualified.rsplit(".", 1)[-1].strip().lower()
        if select_col.is_aggregated and expr.agg_func:
            alias = f"{expr.agg_func.lower()}_{alias}"
        if alias and alias not in inferred:
            inferred.append(alias)
    return inferred
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _qualify_term(term: str, output_to_cte: dict[str, str]) -> str:
    """Prefix unqualified column references in *term* with their CTE name.

    *term* is a single ``MulGroup.multiply`` or ``MulGroup.divide`` entry.
    Every standalone identifier in it that matches a key of *output_to_cte*
    (case-insensitively) is rewritten to ``cte_name.column``. Function-wrapped
    references are covered too: ``SUM(total_amount)`` becomes
    ``SUM(cte.total_amount)``.

    Identifiers that touch a dot on either side are left alone:
    - the column part of ``t.col`` is already qualified;
    - the table part of ``col.x`` is a qualifier, not a column. Rewriting it
      would turn ``orders.total`` into ``cte.orders.total`` whenever some CTE
      exposes an ``orders`` column.

    Args:
        term: Raw expression term.
        output_to_cte: Mapping of lowered bare output column name to the owning CTE name.

    Returns:
        The possibly rewritten term string. The inserted column name is the
        lowered key taken from *output_to_cte*.
    """
    for col_lower, cte_name in output_to_cte.items():
        # Identifier boundary on both sides; a dot also counts as "attached",
        # so neither qualified columns nor qualifier prefixes are touched.
        pattern = re.compile(
            r"(?<![A-Za-z0-9_.])" + re.escape(col_lower) + r"(?![A-Za-z0-9_.])",
            re.IGNORECASE,
        )
        qualified = f"{cte_name}.{col_lower}"
        # A callable replacement inserts the text literally, so a backslash in
        # a CTE name cannot be misread as a regex template escape.
        term = pattern.sub(lambda _match, _q=qualified: _q, term)
    return term
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _qualify_expr(expr: NormalizedExpr, output_to_cte: dict[str, str]) -> NormalizedExpr:
    """Return a copy of *expr* with every multiply/divide term CTE-qualified.

    Each ``MulGroup`` in ``add_groups`` and ``sub_groups`` is rebuilt with its
    terms passed through ``_qualify_term``. All other fields are carried over
    by ``dataclasses.replace``.
    """

    def qualify_all(terms):
        return [_qualify_term(t, output_to_cte) for t in terms]

    def rebuild(group: MulGroup) -> MulGroup:
        return replace(group, multiply=qualify_all(group.multiply), divide=qualify_all(group.divide))

    new_add = [rebuild(g) for g in expr.add_groups]
    new_sub = [rebuild(g) for g in expr.sub_groups]
    return replace(expr, add_groups=new_add, sub_groups=new_sub)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def qualify_cte_output_columns(intent: RuntimeIntent) -> RuntimeIntent:
    """Qualify unqualified column references in the main query that match CTE output columns.

    The LLM sometimes writes a main-query expression that uses a CTE output
    column without the CTE-name prefix. This repair prepends the owning CTE
    name. Only the main query's select_cols, group_by_cols and order_by_cols
    are touched; CTE steps are left alone.

    Output names come from each CTE's ``output_columns``, or are inferred from
    its select_cols when those are absent. If several CTEs expose the same
    bare name, the last one wins. Expressions whose primary column is already
    qualified with one of the main query's tables are skipped. So are
    expressions with no primary column qualified by a main table.

    Args:
        intent: RuntimeIntent with CTE steps whose output_columns may be
            referenced without qualification in the main query.

    Returns:
        Updated RuntimeIntent with qualified main-query expressions, or the
        original intent if nothing changed.
    """
    cte_steps = intent.cte_steps or []
    if not cte_steps:
        return intent

    output_to_cte: dict[str, str] = {}
    for cte in cte_steps:
        explicit_outputs = cte.output_columns or []
        if not explicit_outputs:
            explicit_outputs = _infer_cte_output_columns(cte)
        for oc in explicit_outputs:
            bare = oc.split(".")[-1].strip().lower()
            if bare:
                output_to_cte[bare] = cte.cte_name
    if not output_to_cte:
        return intent

    # Loop-invariant: computed once instead of on every _should_skip call.
    main_tables_lower = {t.lower() for t in intent.tables or []}

    def _should_skip(column: str | None) -> bool:
        """Return True if *column* is already qualified with a real main-query table."""
        # primary_column may be empty/None for expressions without a column.
        if not column or "." not in column:
            return False
        return column.split(".")[0].lower() in main_tables_lower

    old_select = intent.select_cols or []
    old_group_by = intent.group_by_cols or []
    old_order_by = intent.order_by_cols or []

    new_select_cols = [
        sc if _should_skip(sc.expr.primary_column) else replace(sc, expr=_qualify_expr(sc.expr, output_to_cte))
        for sc in old_select
    ]
    new_group_by = [
        g if _should_skip(g.primary_column) else _qualify_expr(g, output_to_cte)
        for g in old_group_by
    ]
    new_order_by = [
        obc if _should_skip(obc.expr.primary_column) else replace(obc, expr=_qualify_expr(obc.expr, output_to_cte))
        for obc in old_order_by
    ]

    # Compare against the normalized (``or []``) lists so a None field is not
    # mistaken for a change just because the rebuilt list is ``[]``.
    if new_select_cols == old_select and new_group_by == old_group_by and new_order_by == old_order_by:
        return intent

    debug("[qualify_cte_output_columns] qualified unqualified CTE output references in main query")
    return replace(
        intent,
        select_cols=new_select_cols,
        group_by_cols=new_group_by,
        order_by_cols=new_order_by,
    )
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
# Profiled value types a column may have and still be chosen as a descriptive
# (human-readable) column by best_descriptive_columns.
DESCRIPTIVE_ALLOWED_VALUE_TYPES = frozenset(("string", "integer"))
# Value types that are never considered descriptive.
DESCRIPTIVE_EXCLUDED_VALUE_TYPES = frozenset(("date", "boolean", "number"))
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _descriptive_column_score(col_name: str, col_meta: ColumnMetadata) -> tuple[int, int]:
    """Score a column for use as a descriptive column; higher is better.

    Returns ``(name_score, distinct_count)``:
    - ``name_score`` is 3 for person-name columns (``first_name`` /
      ``last_name``), 2 for any other column whose name contains ``name`` or
      ``title``, and 0 otherwise.
    - ``distinct_count`` falls back to 0 when not profiled.

    There is no maximum cardinality cap.
    """
    name_lower = col_name.lower()
    # The specific person-name check must come first. "first_name" and
    # "last_name" also contain "name", so testing the generic substring first
    # would make the score-3 branch unreachable.
    if "first_name" in name_lower or "last_name" in name_lower:
        name_score = 3
    elif "name" in name_lower or "title" in name_lower:
        name_score = 2
    else:
        name_score = 0
    return (name_score, col_meta.distinct_count or 0)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def best_descriptive_columns(
    table: str,
    schema_graph: SchemaGraph,
    exclude: set[str],
    max_count: int = 2,
) -> list[str]:
    """Pick up to *max_count* human-readable columns from *table*.

    A column is a candidate when all of the following hold:
    - it is neither a primary nor a foreign key;
    - it is not listed (as ``"table.column"``) in *exclude*;
    - its lower-cased value type is in ``DESCRIPTIVE_ALLOWED_VALUE_TYPES``
      and not in ``DESCRIPTIVE_EXCLUDED_VALUE_TYPES``;
    - its ``distinct_ratio`` is unprofiled (``None``) or not below 0.95.

    Candidates are ranked by ``_descriptive_column_score``, best first. When
    *max_count* >= 2 and ``_best_composite_name_pair`` finds a name-like pair
    whose profiled composite distinct ratio beats every single candidate,
    that pair is returned instead. Returns ``[]`` for unknown tables or when
    nothing qualifies.
    """
    table_meta = schema_graph.tables.get(table)
    if not table_meta:
        return []

    def qualifies(name: str, meta: ColumnMetadata) -> bool:
        if meta.is_primary_key or meta.is_foreign_key:
            return False
        if f"{table}.{name}" in exclude:
            return False
        value_type = (meta.value_type or "").lower()
        if value_type in DESCRIPTIVE_EXCLUDED_VALUE_TYPES or value_type not in DESCRIPTIVE_ALLOWED_VALUE_TYPES:
            return False
        ratio = meta.distinct_ratio
        return ratio is None or not ratio < 0.95

    ranked = sorted(
        ((name, meta) for name, meta in table_meta.columns.items() if qualifies(name, meta)),
        key=lambda item: _descriptive_column_score(item[0], item[1]),
        reverse=True,
    )
    if not ranked:
        return []
    if max_count >= 2 and len(ranked) >= 2:
        pair = _best_composite_name_pair(table_meta, ranked)
        if pair is not None:
            return list(pair)
    return [name for name, _ in ranked[:max_count]]
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _best_composite_name_pair(
    tbl_meta: TableMetadata,
    candidates: list[tuple[str, ColumnMetadata]],
) -> tuple[str, str] | None:
    """Return a pair of name-like columns whose combined uniqueness beats any single column.

    Only candidates with a name score of at least 2 (per
    ``_descriptive_column_score``) are paired. Pairs are tried in candidate
    order. The ratio of each pair is looked up in
    ``tbl_meta.composite_descriptive_ratios`` under either key order, and the
    first pair whose ratio is strictly greater than the highest single
    ``distinct_ratio`` among all *candidates* is returned. Returns ``None``
    when fewer than two name-like candidates exist or no pair qualifies.
    """
    named = [(n, m) for n, m in candidates if _descriptive_column_score(n, m)[0] >= 2]
    if len(named) < 2:
        return None
    ceiling = max((m.distinct_ratio or 0.0) for _, m in candidates)
    lookup = tbl_meta.composite_descriptive_ratios
    for idx, (first, _) in enumerate(named):
        for second, _ in named[idx + 1:]:
            score = lookup.get((first, second)) or lookup.get((second, first))
            if score is not None and score > ceiling:
                return (first, second)
    return None
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def best_descriptive_column(table: str, schema_graph: SchemaGraph, exclude: set[str]) -> str | None:
    """Return the single best descriptive column of *table*, or ``None``.

    This is a thin wrapper over ``best_descriptive_columns`` with
    ``max_count=1``, so the composite-pair shortcut never applies. The same
    eligibility rules hold: no PK/FK columns, and only string/integer value
    types (decimals, dates and booleans are excluded).
    """
    picked = best_descriptive_columns(table, schema_graph, exclude, max_count=1)
    if not picked:
        return None
    return picked[0]
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def _repair_fk_filters(
    filters: list[FilterParam],
    select_cols: list,
    tables: list[str],
    schema_graph: SchemaGraph,
    label: str = "",
) -> tuple[list[FilterParam], list[str], bool]:
    """Detect foreign-key filters that should use descriptive columns.

    Filters are never rewritten here: the returned filter list is a new list
    holding the same elements. A filter is flagged when all of these hold:
    - its value type is ``string``/``enum`` with a non-null value;
    - its column is a ``table.column`` reference to an integer/number
      foreign key;
    - the FK target table has a descriptive column, per
      ``best_descriptive_column`` with already-selected terms excluded.

    Flagged filters should be surfaced as a semantic issue for repair rather
    than rewritten deterministically. Filters whose ``primary_column`` is
    empty or ``None`` are passed through instead of raising.

    Returns:
        ``(filters_copy, tables_copy, flagged)``.
    """
    tables = list(tables)
    changed = False
    existing_terms = {sc.expr.primary_term for sc in select_cols or []}
    for fp in filters:
        if fp.value_type not in {"string", "enum"} or fp.raw_value is None:
            continue
        # primary_column can be empty/None for column-less expressions;
        # "." in None would raise TypeError.
        col = fp.left_expr.primary_column or ""
        table_name, dot, column_name = col.partition(".")
        if not dot:
            continue
        col_meta = schema_graph.get_column(table_name, column_name)
        if not col_meta or not col_meta.is_foreign_key or col_meta.value_type not in {"integer", "number"}:
            continue
        fk_target = col_meta.fk_target
        if not fk_target:
            continue
        target_table, _ = fk_target
        if best_descriptive_column(target_table, schema_graph, existing_terms):
            changed = True
            debug(f"[intent_resolve.repair_fk_filter_type_mismatch{label}] detected fk filter {col} needing descriptive column")
    return list(filters), tables, changed
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def repair_fk_filter_type_mismatch(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Run foreign-key filter detection on the main query and every CTE step.

    ``_repair_fk_filters`` reports whether a string-valued filter targets an
    integer foreign key whose target table has a descriptive column.

    Args:
        intent: RuntimeIntent whose filters (main and CTE) are inspected.
        schema_graph: SchemaGraph used for column and foreign-key lookups.

    Returns:
        *intent* unchanged when nothing is reported. Otherwise a copy built
        with ``dataclasses.replace`` that carries the filter lists returned
        for the reported parts (main filters and/or CTE steps).
    """
    main_filters, _, main_flag = _repair_fk_filters(
        intent.filters_param or [],
        intent.select_cols or [],
        list(intent.tables or []),
        schema_graph,
    )
    steps = []
    any_step_flag = False
    for step in intent.cte_steps or []:
        step_filters, _, step_flag = _repair_fk_filters(
            step.filters_param or [],
            step.select_cols or [],
            list(step.tables or []),
            schema_graph,
            label=f" CTE '{step.cte_name}'",
        )
        if step_flag:
            any_step_flag = True
            step = replace(step, filters_param=step_filters)
        steps.append(step)

    if not (main_flag or any_step_flag):
        return intent
    updated = intent
    if main_flag:
        updated = replace(updated, filters_param=main_filters)
    if any_step_flag:
        updated = replace(updated, cte_steps=steps)
    return updated
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def expand_fk_select_to_descriptive(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Replace foreign-key integer columns in select_cols with the target table's descriptive columns.

    Each non-aggregated SelectCol whose column is a ``table.column``
    reference to an integer/number foreign key is replaced by up to two
    descriptive columns of the FK target table (``best_descriptive_columns``
    with ``max_count=2``). The target table is added to ``intent.tables`` so
    join enumeration discovers the path.

    Columns already in the select list are excluded as candidates, and so are
    columns added by earlier rewrites in the same call. Aggregated columns,
    columns without a ``table.column`` reference (including an empty or
    ``None`` primary column), and FKs whose target has no descriptive column
    are kept as-is.

    Args:
        intent: RuntimeIntent whose select_cols may reference foreign-key integer columns.
        schema_graph: SchemaGraph for foreign-key relationship and descriptive column lookups.

    Returns:
        The original intent when nothing was rewritten. Otherwise a copy with
        the expanded select_cols and with ``tables`` sorted.
    """
    tables = list(intent.tables or [])
    new_select: list[SelectCol] = []
    changed = False
    existing_terms = {sc.expr.primary_term for sc in intent.select_cols or []}
    for sc in intent.select_cols or []:
        if sc.is_aggregated:
            new_select.append(sc)
            continue
        # primary_column can be empty/None for column-less expressions;
        # "." in None would raise TypeError.
        col = sc.expr.primary_column or ""
        table_name, dot, column_name = col.partition(".")
        if not dot:
            new_select.append(sc)
            continue
        col_meta = schema_graph.get_column(table_name, column_name)
        if not col_meta or not col_meta.is_foreign_key or col_meta.value_type not in {"integer", "number"}:
            new_select.append(sc)
            continue
        fk_target = col_meta.fk_target
        if not fk_target:
            new_select.append(sc)
            continue
        target_table, _ = fk_target
        descs = best_descriptive_columns(
            target_table, schema_graph, existing_terms, max_count=2,
        )
        if not descs:
            new_select.append(sc)
            continue
        for desc in descs:
            fq = f"{target_table}.{desc}"
            new_select.append(SelectCol(expr=NormalizedExpr.from_column(fq)))
            # Prevent a later FK to the same table from picking this column again.
            existing_terms.add(fq)
        if target_table not in tables:
            tables.append(target_table)
        changed = True
        debug(
            f"[intent_resolve.expand_fk_select_to_descriptive] "
            f"rewired select {col} -> {[f'{target_table}.{d}' for d in descs]}"
        )
    if not changed:
        return intent
    return replace(intent, select_cols=new_select, tables=sorted(tables))
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def strip_spurious_group_by(intent: RuntimeIntent) -> RuntimeIntent:
|
|
560
|
+
"""Remove group_by_cols when no aggregation exists in select or having.
|
|
561
|
+
|
|
562
|
+
Guards against LLM hallucinations that produce GROUP BY without any aggregate function in select_cols or having_param, and when group_by_cols are present but neither select nor having contains an aggregate the GROUP BY is stripped and the grain is downgraded to 'row_level' if it was 'grouped' with the same logic applied to each CTE step independently.
|
|
563
|
+
|
|
564
|
+
Args:
|
|
565
|
+
|
|
566
|
+
intent: RuntimeIntent to inspect.
|
|
567
|
+
|
|
568
|
+
Returns:
|
|
569
|
+
|
|
570
|
+
Updated RuntimeIntent with group_by_cols cleared when spurious, or the original intent unchanged.
|
|
571
|
+
"""
|
|
572
|
+
main_changed = False
|
|
573
|
+
new_grain = intent.grain
|
|
574
|
+
new_gb = intent.group_by_cols or []
|
|
575
|
+
if intent.group_by_cols:
|
|
576
|
+
has_agg = any(sc.is_aggregated for sc in (intent.select_cols or []))
|
|
577
|
+
has_agg = has_agg or any(hp.left_expr.has_aggregation for hp in (intent.having_param or []))
|
|
578
|
+
if not has_agg:
|
|
579
|
+
debug(
|
|
580
|
+
f"[intent_resolve.strip_spurious_group_by] group_by_cols present without aggregation — stripping {[g.primary_term for g in intent.group_by_cols]}"
|
|
581
|
+
)
|
|
582
|
+
new_grain = "row_level" if intent.grain == "grouped" else intent.grain
|
|
583
|
+
new_gb = []
|
|
584
|
+
main_changed = True
|
|
585
|
+
|
|
586
|
+
new_cte_steps = []
|
|
587
|
+
cte_changed = False
|
|
588
|
+
for cte in intent.cte_steps or []:
|
|
589
|
+
if not (cte.group_by_cols or []):
|
|
590
|
+
new_cte_steps.append(cte)
|
|
591
|
+
continue
|
|
592
|
+
cte_has_agg = any(sc.is_aggregated for sc in (cte.select_cols or []))
|
|
593
|
+
cte_has_agg = cte_has_agg or any(hp.left_expr.has_aggregation for hp in (cte.having_param or []))
|
|
594
|
+
if cte_has_agg:
|
|
595
|
+
new_cte_steps.append(cte)
|
|
596
|
+
continue
|
|
597
|
+
debug(
|
|
598
|
+
f"[intent_resolve.strip_spurious_group_by] CTE '{cte.cte_name}' group_by_cols present without aggregation — stripping {[g.primary_term for g in cte.group_by_cols]}"
|
|
599
|
+
)
|
|
600
|
+
cte_grain = "row_level" if cte.grain == "grouped" else cte.grain
|
|
601
|
+
new_cte_steps.append(replace(cte, group_by_cols=[], grain=cte_grain))
|
|
602
|
+
cte_changed = True
|
|
603
|
+
|
|
604
|
+
if not main_changed and not cte_changed:
|
|
605
|
+
return intent
|
|
606
|
+
return replace(
|
|
607
|
+
intent,
|
|
608
|
+
group_by_cols=new_gb,
|
|
609
|
+
grain=new_grain,
|
|
610
|
+
cte_steps=new_cte_steps if cte_changed else (intent.cte_steps or []),
|
|
611
|
+
)
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def _is_impossible_having(hp: HavingParam) -> bool:
    """Return True when a HAVING condition on COUNT can never be satisfied.

    COUNT is never negative, so these comparisons are impossible:
    - ``COUNT(...) < v`` for ``v <= 0``;
    - ``COUNT(...) <= v`` for ``v < 0``;
    - ``COUNT(...) = v`` for ``v < 0``.

    ``COUNT(...) <= 0`` is *not* impossible: it matches groups where
    ``COUNT(col)`` is 0, for example when ``col`` is all NULL. It is kept,
    consistent with ``= 0``. Only COUNT is considered, because SUM can
    legitimately be negative.

    The condition is treated as a COUNT when the first add-group's
    ``agg_func`` is ``COUNT`` (structured form), or when ``primary_term``
    matches ``IMPOSSIBLE_HAVING_RE`` (raw-string form).

    Args:
        hp: A single HavingParam to inspect.

    Returns:
        True if the condition can never be true.
    """
    left_expr = hp.left_expr
    if not left_expr:
        return False
    agg_func = ""
    if left_expr.add_groups and left_expr.add_groups[0].agg_func:
        agg_func = left_expr.add_groups[0].agg_func.upper()
    # Check the cheap structured form first; the regex is only needed for raw
    # strings. ``or ""`` keeps a missing primary_term from crashing re.match.
    is_count = agg_func == "COUNT" or bool(IMPOSSIBLE_HAVING_RE.match(left_expr.primary_term or ""))
    if not is_count:
        return False
    op = (hp.op or "").strip().lower()
    val = hp.raw_value
    if val is None:
        return False
    try:
        numeric_val = val if isinstance(val, (int, float)) else float(val)
    except (ValueError, TypeError):
        return False
    if op == "<":
        return numeric_val <= 0
    if op in ("<=", "="):
        return numeric_val < 0
    return False
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def strip_impossible_having(intent: RuntimeIntent) -> RuntimeIntent:
    """Remove HAVING conditions that are logically impossible.

    Filters out conditions such as ``COUNT(...) < 0`` (as judged by
    ``_is_impossible_having``) from the main query and, independently, from
    each CTE step. Only ``having_param`` is modified. ``grain`` and
    ``group_by_cols`` are left unchanged, even when every HAVING condition is
    removed.

    Args:
        intent: RuntimeIntent to inspect.

    Returns:
        Updated RuntimeIntent with impossible HAVING params removed, or the
        original intent unchanged when none were found.
    """
    main_having = intent.having_param or []
    kept_main = [hp for hp in main_having if not _is_impossible_having(hp)]
    main_changed = len(kept_main) != len(main_having)
    if main_changed:
        removed = len(main_having) - len(kept_main)
        debug(f"[strip_impossible_having] removed {removed} impossible HAVING condition(s)")

    new_cte_steps = []
    cte_changed = False
    for cte in intent.cte_steps or []:
        cte_having = cte.having_param or []
        kept_cte = [hp for hp in cte_having if not _is_impossible_having(hp)]
        if len(kept_cte) != len(cte_having):
            cte_changed = True
            new_cte_steps.append(replace(cte, having_param=kept_cte))
        else:
            new_cte_steps.append(cte)

    if not main_changed and not cte_changed:
        return intent
    return replace(
        intent,
        having_param=kept_main,
        cte_steps=new_cte_steps if cte_changed else (intent.cte_steps or []),
    )
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def sanitize_table_names(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Repair table references that carry a leaked SQL keyword prefix.

    LLM output occasionally lists tables as ``"FROM orders"`` or
    ``"JOIN products"`` instead of the bare table name.

    - An entry that already names a schema table (case-insensitive) is
      kept verbatim.
    - Otherwise the entry is split on whitespace. When its last word names
      a schema table and at least one earlier word (lowercased) is in
      ``SQL_KEYWORDS``, the entry is replaced with the schema's spelling
      of that table.
    - Every other entry is kept unchanged.

    Args:
        intent: RuntimeIntent whose ``tables`` list may contain polluted
            names.
        schema_graph: SchemaGraph whose ``tables`` keys are the valid table
            names.

    Returns:
        The same ``intent`` object when no entry needed repair, otherwise a
        copy with the corrected ``tables`` list.
    """
    # NOTE(review): words are lowercased before the SQL_KEYWORDS lookup, so
    # this assumes SQL_KEYWORDS holds lowercase entries -- confirm.
    canonical_by_lower = {name.lower(): name for name in schema_graph.tables}
    sanitized: list[str] = []
    repaired = False

    for raw_name in intent.tables or []:
        if raw_name.lower() in canonical_by_lower:
            sanitized.append(raw_name)
            continue

        words = raw_name.split()
        last_word = words[-1].lower() if words else ""
        has_keyword_prefix = last_word in canonical_by_lower and any(
            word.lower() in SQL_KEYWORDS for word in words[:-1]
        )
        if not has_keyword_prefix:
            sanitized.append(raw_name)
            continue

        fixed_name = canonical_by_lower[last_word]
        debug(f"[sanitize_table_names] corrected '{raw_name}' → '{fixed_name}'")
        sanitized.append(fixed_name)
        repaired = True

    return replace(intent, tables=sanitized) if repaired else intent
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def _strip_join_condition_filters(filters: list[FilterParam], schema_graph: SchemaGraph) -> list[FilterParam]:
    """Remove FilterParam entries that restate an FK equi-join condition.

    A filter is dropped when its ``op`` is ``"="`` and it has a
    ``right_expr``, and the pair ``(left_expr.primary_term,
    right_expr.primary_term)`` matches a single-column foreign key of the
    schema in either direction (``child.fk = parent.pk`` or
    ``parent.pk = child.fk``). Composite foreign keys are ignored.

    Args:
        filters: FilterParam objects to process.
        schema_graph: SchemaGraph providing FK edge definitions.

    Returns:
        A new list containing the filters that are not FK join conditions,
        in their original order.
    """
    if not filters:
        return []

    fk_pairs: set[tuple[str, str]] = set()
    for tbl in schema_graph.tables.values():
        for fk in tbl.foreign_keys:
            # Only single-column FKs can be expressed as one ``a = b`` filter.
            if len(fk.src_cols) == 1 and len(fk.dst_cols) == 1:
                left = f"{fk.src_table}.{fk.src_cols[0]}"
                right = f"{fk.dst_table}.{fk.dst_cols[0]}"
                # Register both orientations so ``a = b`` and ``b = a`` match.
                fk_pairs.add((left, right))
                fk_pairs.add((right, left))

    result: list[FilterParam] = []
    for fp in filters:
        # Literal comparisons and non-equality operators are never joins.
        if fp.right_expr is None or fp.op != "=":
            result.append(fp)
            continue
        left_term = fp.left_expr.primary_term
        right_term = fp.right_expr.primary_term
        if (left_term, right_term) in fk_pairs:
            debug(f"[intent_repair.strip_join_condition_filters] dropping FK join filter: {left_term} = {right_term}")
            continue
        result.append(fp)
    return result
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
def strip_join_conditions(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Drop FK equi-join filters from the main query and every CTE step.

    Each ``filters_param`` list (a missing list is treated as empty) is
    passed through ``_strip_join_condition_filters``. The main query is
    processed first, then the CTE steps in order.

    Args:
        intent: RuntimeIntent whose filter lists should be cleaned.
        schema_graph: SchemaGraph providing FK edge definitions.

    Returns:
        A new RuntimeIntent with FK join conditions removed. A copy is
        returned even when nothing was dropped.
    """

    def _strip(filters):
        return _strip_join_condition_filters(filters or [], schema_graph)

    main_filters = _strip(intent.filters_param)
    cleaned_steps = []
    for step in intent.cte_steps or []:
        cleaned_steps.append(replace(step, filters_param=_strip(step.filters_param)))
    return replace(intent, filters_param=main_filters, cte_steps=cleaned_steps)
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
def _is_pk_column(col_ref: str, schema_graph: SchemaGraph) -> bool:
    """Tell whether ``"table.column"`` names a primary-key column.

    The reference is split at the first ``.``. The table and column are
    looked up (case-sensitively) in ``schema_graph.tables`` and the
    table's ``columns`` mapping.

    Args:
        col_ref: Column reference in ``table.column`` form.
        schema_graph: SchemaGraph for metadata lookups.

    Returns:
        The column's ``is_primary_key`` flag. Returns False when the
        reference has no ``.`` or the table/column is unknown.
    """
    table_name, dot, column_name = col_ref.partition(".")
    if not dot:
        return False
    table_meta = schema_graph.tables.get(table_name)
    column_meta = table_meta.columns.get(column_name) if table_meta else None
    return column_meta.is_primary_key if column_meta else False
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
def _strip_distinct_prefix(term: str) -> str:
    """Remove a ``DISTINCT`` prefix that leaked into an expression term.

    The keyword is matched case-insensitively. It may be preceded by
    whitespace and must be followed by at least one whitespace character
    (space, tab, newline). Words that merely start with the letters, such
    as ``"DISTINCTION"``, and a bare ``"DISTINCT"`` are left alone.

    Args:
        term: Expression term string, e.g. ``"DISTINCT orders.id"``.

    Returns:
        The remainder after the keyword with surrounding whitespace
        stripped, or ``term`` unchanged when no prefix is present.
    """
    body = term.lstrip()
    keyword, rest = body[:8], body[8:]
    # ``rest[:1]`` is "" for a bare keyword, and "".isspace() is False.
    if keyword.upper() == "DISTINCT" and rest[:1].isspace():
        return rest.strip()
    return term
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
def _normalize_sc_pk_distinct(sc: SelectCol, schema_graph: SchemaGraph) -> SelectCol:
    """Drop a leaked ``DISTINCT`` prefix from a COUNT over a primary key.

    The aggregate name is taken from ``sc.expr.agg_func``, falling back to
    the first add-group's ``agg_func``. Only ``"count"`` is handled. When
    the expression's ``primary_term`` carries a ``DISTINCT`` prefix and
    the remaining term is a primary-key column, matching entries in the
    first add-group's ``multiply`` list are replaced with the bare term.

    Args:
        sc: SelectCol to inspect.
        schema_graph: SchemaGraph used for the primary-key check.

    Returns:
        ``sc`` itself when nothing applies, otherwise a copy with the
        rewritten expression.
    """
    expr = sc.expr
    if expr.agg_func:
        agg_name = expr.agg_func
    elif expr.add_groups and expr.add_groups[0].agg_func:
        agg_name = expr.add_groups[0].agg_func
    else:
        agg_name = None
    if agg_name != "count":
        return sc

    original_term = expr.primary_term
    stripped_term = _strip_distinct_prefix(original_term)
    if not _is_pk_column(stripped_term, schema_graph):
        return sc
    if stripped_term == original_term:
        # No DISTINCT prefix was present; nothing to repair.
        return sc

    groups = list(expr.add_groups)
    if groups:
        head = groups[0]
        groups[0] = replace(
            head,
            multiply=[stripped_term if _strip_distinct_prefix(m) == stripped_term else m for m in head.multiply],
        )
    new_expr = replace(expr, add_groups=groups)
    debug(f"[normalize_pk_distinct] stripped DISTINCT prefix from PK term: {original_term} → {stripped_term}")
    return replace(sc, expr=new_expr)
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
def normalize_pk_distinct(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Strip a redundant DISTINCT from COUNT expressions on primary keys.

    Every select column of the main query and of each CTE step is passed
    through ``_normalize_sc_pk_distinct``. The main query is processed
    first.

    NOTE(review): COUNT(DISTINCT pk) equals COUNT(pk) only when no join
    duplicates rows of the PK's table (a one-to-many join fans parent rows
    out). This function does not inspect joins -- confirm callers only
    rely on it where no fan-out is possible.

    Args:
        intent: RuntimeIntent to normalize.
        schema_graph: SchemaGraph for PK lookups.

    Returns:
        A new RuntimeIntent with redundant DISTINCT prefixes removed. A copy
        is returned even when nothing changed.
    """

    def _normalize_all(cols):
        return [_normalize_sc_pk_distinct(sc, schema_graph) for sc in (cols or [])]

    main_select = _normalize_all(intent.select_cols)
    steps = [replace(step, select_cols=_normalize_all(step.select_cols)) for step in (intent.cte_steps or [])]
    return replace(intent, select_cols=main_select, cte_steps=steps)
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
def _tables_from_columns(cols: list[str]) -> set[str]:
    """Return the distinct table prefixes of ``table.column`` references.

    Entries without a ``.`` are ignored. For dotted entries, everything
    before the first ``.`` is taken as the table name.

    Args:
        cols: Column reference strings.

    Returns:
        Set of table names.
    """
    return {ref.partition(".")[0] for ref in cols if "." in ref}
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
def _collect_referenced_tables(
    select_cols: list,
    order_by_cols: list,
    group_by_cols: list,
    filters_param: list,
    having_param: list,
) -> set[str]:
    """Collect every table referenced by expressions in the given clauses.

    The result is the tables found in the select columns plus everything
    ``_collect_essential_tables`` finds in the ORDER BY, GROUP BY, filter
    and HAVING clauses. Any argument may be ``None``.

    Args:
        select_cols: SelectCol list.
        order_by_cols: OrderByCol list.
        group_by_cols: NormalizedExpr list.
        filters_param: FilterParam list.
        having_param: HavingParam list.

    Returns:
        Set of table names taken from fully-qualified column references.
    """
    select_refs: list[str] = []
    for sc in select_cols or []:
        select_refs.extend(extract_columns_from_expr(sc.expr))
    # Delegate the non-select clauses to the essential-table collector so
    # the two cannot drift apart. Callers compute ``referenced - essential``
    # and rely on essential being a subset of referenced.
    essential = _collect_essential_tables(order_by_cols, group_by_cols, filters_param, having_param)
    return _tables_from_columns(select_refs) | essential
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
def _collect_essential_tables(
    order_by_cols: list,
    group_by_cols: list,
    filters_param: list,
    having_param: list,
) -> set[str]:
    """Collect tables referenced outside the SELECT list.

    Tables used for ordering, grouping, filtering or HAVING are needed for
    query semantics and must never be pruned. Any argument may be ``None``.

    Args:
        order_by_cols: OrderByCol list (its ``expr`` is inspected).
        group_by_cols: NormalizedExpr list.
        filters_param: FilterParam list (left and optional right expr).
        having_param: HavingParam list (left and optional right expr).

    Returns:
        Set of table names taken from fully-qualified column references.
    """
    column_refs: list[str] = []
    for obc in order_by_cols or []:
        column_refs += extract_columns_from_expr(obc.expr)
    for gb_expr in group_by_cols or []:
        column_refs += extract_columns_from_expr(gb_expr)
    # Filters and HAVING conditions share the left/right expression shape.
    for predicate in [*(filters_param or []), *(having_param or [])]:
        column_refs += extract_columns_from_expr(predicate.left_expr)
        if predicate.right_expr:
            column_refs += extract_columns_from_expr(predicate.right_expr)
    return _tables_from_columns(column_refs)
|
|
942
|
+
|
|
943
|
+
|
|
944
|
+
def _find_fk_column_for_pk(
    parent_table: str,
    pk_column: str,
    candidate_tables: set[str],
    schema_graph: SchemaGraph,
) -> str | None:
    """Find an FK column on a candidate table that references a parent PK.

    Candidates are scanned in sorted name order. Within a table, columns
    are scanned in ``columns`` mapping order. The first column flagged
    ``is_foreign_key`` whose ``fk_target`` (table, column) equals
    ``(parent_table, pk_column)`` case-insensitively is returned.

    Args:
        parent_table: Table that owns the primary key.
        pk_column: Primary-key column name on ``parent_table``.
        candidate_tables: Table names that may hold the foreign key.
        schema_graph: SchemaGraph for column metadata.

    Returns:
        The matching column as ``"table.column"``, or ``None`` when no
        candidate holds such a foreign key.
    """
    target_key = (parent_table.lower(), pk_column.lower())
    # Set iteration order of str differs between runs (hash randomization).
    # Sorting makes the choice reproducible when several tables qualify.
    for tbl in sorted(candidate_tables):
        tbl_meta = schema_graph.tables.get(tbl)
        if not tbl_meta:
            continue
        for col_name, col_meta in tbl_meta.columns.items():
            if not col_meta.is_foreign_key or not col_meta.fk_target:
                continue
            fk_tgt = (col_meta.fk_target[0].lower(), col_meta.fk_target[1].lower())
            if fk_tgt == target_key:
                return f"{tbl}.{col_name}"
    return None
|
|
968
|
+
|
|
969
|
+
|
|
970
|
+
def _rewrite_redundant_pk_aggregations(
    select_cols: list[SelectCol],
    select_only_tables: set[str],
    essential_tables: set[str],
    schema_graph: SchemaGraph,
    all_intent_tables: set[str] | None = None,
) -> tuple[list[SelectCol], set[str]]:
    """Rewrite aggregations on a parent PK to a child FK column.

    A select-only table qualifies when all of the following hold:

    - Every select column touching it is aggregated.
    - Each such column's ``primary_column`` is the table's first
      primary-key column.
    - Another table in the candidate pool has an FK pointing at that PK.

    The candidate pool is ``all_intent_tables | essential_tables |
    select_only_tables``; passing ``all_intent_tables`` lets bridge tables
    that no expression references yet still be found.

    For a qualifying table:

    - If some other select column is already aggregated, its columns are
      dropped as redundant.
    - Otherwise each one is replaced by ``agg_func(fk_column)``.

    Tables are processed in sorted order. Each rewrite changes the select
    list seen by later tables, so an unordered walk would make the result
    depend on set iteration order.

    Args:
        select_cols: SelectCol list to rewrite (not mutated).
        select_only_tables: Tables referenced only from the SELECT list.
        essential_tables: Tables referenced by non-select clauses.
        schema_graph: SchemaGraph for PK/FK metadata.
        all_intent_tables: Optional extra FK-lookup candidates.

    Returns:
        ``(new_select_cols, eliminated)``, where ``eliminated`` holds the
        tables whose select references were rewritten or dropped.
    """
    eliminated: set[str] = set()
    new_select: list[SelectCol] = list(select_cols)
    # Loop-invariant: depends only on the arguments.
    candidate_pool = (all_intent_tables or set()) | essential_tables | select_only_tables

    for tbl in sorted(select_only_tables):
        tbl_meta = schema_graph.tables.get(tbl)
        if not tbl_meta:
            continue

        pk_col = next((name for name, meta in tbl_meta.columns.items() if meta.is_primary_key), None)
        if not pk_col:
            continue

        prefix = f"{tbl}.".lower()
        pk_ref = f"{tbl}.{pk_col}".lower()
        tbl_indices: set[int] = set()
        all_agg_pk = True
        for idx, sc in enumerate(new_select):
            cols = extract_columns_from_expr(sc.expr)
            if not any(c.lower().startswith(prefix) for c in cols):
                continue
            tbl_indices.add(idx)
            if not sc.is_aggregated or sc.expr.primary_column.lower() != pk_ref:
                all_agg_pk = False
                break

        if not tbl_indices or not all_agg_pk:
            continue

        fk_col = _find_fk_column_for_pk(tbl, pk_col, candidate_pool - {tbl}, schema_graph)
        if not fk_col:
            continue

        other_has_agg = any(sc.is_aggregated for idx, sc in enumerate(new_select) if idx not in tbl_indices)

        rewritten: list[SelectCol] = []
        for idx, sc in enumerate(new_select):
            if idx not in tbl_indices:
                rewritten.append(sc)
                continue
            if other_has_agg:
                debug(f"[prune_unreferenced_tables] dropping redundant agg {sc.expr.primary_term} (other agg exists)")
                continue
            agg_func = sc.expr.agg_func or (sc.expr.add_groups[0].agg_func if sc.expr.add_groups else "")
            # NOTE(review): a fresh SelectCol is built, so any other SelectCol
            # fields (e.g. an alias, if the type has one) are not carried over.
            rewritten.append(SelectCol(expr=NormalizedExpr.from_agg(agg_func, fk_col)))
            debug(f"[prune_unreferenced_tables] rewrote {sc.expr.primary_term} → {agg_func}({fk_col})")

        new_select = rewritten
        eliminated.add(tbl)

    return new_select, eliminated
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def requalify_redundant_pk_references(
    intent: RuntimeIntent,
    schema_graph: SchemaGraph,
) -> RuntimeIntent:
    """Rewrite target-table PK references to source-table FK equivalents.

    For every FK edge whose source and destination tables are both in
    ``intent.tables``, the reference ``dst_table.dst_col`` may be replaced
    by ``src_table.src_col``. This avoids an unnecessary join to the
    target table.

    When the rewrite applies:

    - The destination table must be referenced at most once across the
      select, order-by, group-by, filter and having expressions.
    - Every column reference is counted, not distinct columns. So the
      same PK appearing in both SELECT and GROUP BY blocks the rewrite.

    What is rewritten:

    - Only ``group_by_cols`` and non-aggregated ``select_cols``.
    - Aggregated select expressions, ``order_by_cols``, ``filters_param``
      and ``having_param`` are left untouched, because PK usage inside
      aggregates (e.g. ``COUNT(table.pk)``) is intentional.
    - CTE steps are not processed.

    When several source tables hold an FK to the same PK, the source table
    that sorts first by name wins, so the result is reproducible.

    Args:
        intent: RuntimeIntent to rewrite.
        schema_graph: SchemaGraph providing FK definitions. A falsy value
            disables the rewrite.

    Returns:
        The original ``intent`` when no rewrite applies, otherwise a copy
        with updated ``select_cols`` and ``group_by_cols``.
    """
    if not schema_graph:
        return intent

    all_cols: list[str] = []
    for sc in intent.select_cols or []:
        all_cols.extend(extract_columns_from_expr(sc.expr))
    for obc in intent.order_by_cols or []:
        all_cols.extend(extract_columns_from_expr(obc.expr))
    for g in intent.group_by_cols or []:
        all_cols.extend(extract_columns_from_expr(g))
    for fp in intent.filters_param or []:
        all_cols.extend(extract_columns_from_expr(fp.left_expr))
        if fp.right_expr:
            all_cols.extend(extract_columns_from_expr(fp.right_expr))
    for hp in intent.having_param or []:
        all_cols.extend(extract_columns_from_expr(hp.left_expr))
        if hp.right_expr:
            all_cols.extend(extract_columns_from_expr(hp.right_expr))

    # Number of column references (not distinct columns) per table prefix.
    col_counts: dict[str, int] = {}
    for col_ref in all_cols:
        tbl = col_ref.split(".")[0] if "." in col_ref else ""
        if tbl:
            col_counts[tbl] = col_counts.get(tbl, 0) + 1

    intent_tables = set(intent.tables or [])
    # pk_ref "dst_table.dst_col" -> (fk_ref "src_table.src_col", dst_table)
    fk_lookup: dict[str, tuple[str, str]] = {}
    # Sorted so that the first-wins choice below does not depend on set
    # iteration order (which varies between runs for str).
    for src_table_name in sorted(intent_tables):
        src_table = schema_graph.tables.get(src_table_name)
        if not src_table:
            continue
        for fk in src_table.foreign_keys:
            dst_table = fk.dst_table
            if dst_table not in intent_tables:
                continue
            for src_col, dst_col in zip(fk.src_cols, fk.dst_cols, strict=False):
                pk_ref = f"{dst_table}.{dst_col}"
                fk_ref = f"{src_table_name}.{src_col}"
                fk_lookup.setdefault(pk_ref, (fk_ref, dst_table))

    rewrite_map: dict[str, str] = {
        pk_ref: fk_ref
        for pk_ref, (fk_ref, pk_table) in fk_lookup.items()
        if col_counts.get(pk_table, 0) <= 1
    }
    if not rewrite_map:
        return intent

    debug(f"[requalify_redundant_pk_references] rewrite_map: {rewrite_map}")

    def _rewrite_col(col_ref: str) -> str:
        return rewrite_map.get(col_ref, col_ref)

    def _rewrite_expr(expr: NormalizedExpr) -> NormalizedExpr:
        return replace_refs_in_expr(expr, _rewrite_col)

    new_group_by = [_rewrite_expr(g) for g in (intent.group_by_cols or [])]
    new_select_cols = [
        sc if sc.is_aggregated else replace(sc, expr=_rewrite_expr(sc.expr))
        for sc in (intent.select_cols or [])
    ]

    return replace(
        intent,
        select_cols=new_select_cols,
        group_by_cols=new_group_by,
    )
|
|
1142
|
+
|
|
1143
|
+
|
|
1144
|
+
def prune_unreferenced_tables(
    intent: RuntimeIntent,
    schema_graph: SchemaGraph | None = None,
) -> RuntimeIntent:
    """Make each ``tables`` list match the tables expressions actually use.

    For the main query and for each CTE step, the new ``tables`` list is
    the sorted set of table prefixes found in select, filter, group-by,
    having and order-by expressions, plus every CTE name. Tables listed in
    the intent but never referenced are removed; referenced tables that
    were missing are added.

    When ``schema_graph`` is provided, main-query select columns that only
    aggregate a select-only table's primary key are first rewritten to an
    FK column on another table (or dropped when another aggregate already
    exists) via ``_rewrite_redundant_pk_aggregations``. This can remove
    that table from the main query.

    NOTE(review): every CTE step's ``tables`` also receives all CTE names,
    including the step's own name and later steps' names. Confirm the SQL
    generator ignores these entries.

    Args:
        intent: RuntimeIntent to synchronize.
        schema_graph: Optional SchemaGraph enabling the PK-aggregation
            rewrite.

    Returns:
        A new RuntimeIntent with updated ``tables``, ``select_cols`` and
        ``cte_steps``.
    """
    cte_names = {cte.cte_name for cte in (intent.cte_steps or [])}

    def _main_refs(cols: list) -> set[str]:
        # Tables referenced anywhere in the main query, plus all CTE names.
        return (
            _collect_referenced_tables(
                cols,
                intent.order_by_cols,
                intent.group_by_cols,
                intent.filters_param,
                intent.having_param,
            )
            | cte_names
        )

    select_cols = list(intent.select_cols or [])
    referenced = _main_refs(select_cols)
    essential = (
        _collect_essential_tables(
            intent.order_by_cols,
            intent.group_by_cols,
            intent.filters_param,
            intent.having_param,
        )
        | cte_names
    )

    select_only = referenced - essential
    if schema_graph and select_only:
        select_cols, eliminated = _rewrite_redundant_pk_aggregations(
            select_cols,
            select_only,
            essential,
            schema_graph,
            all_intent_tables=set(intent.tables or []),
        )
        if eliminated:
            # The rewrite may have dropped or re-targeted select references.
            referenced = _main_refs(select_cols)

    original = set(intent.tables or [])
    added = referenced - original - cte_names
    removed = original - referenced
    main_tables = sorted(referenced)
    if added:
        debug(f"[prune_unreferenced_tables] added {sorted(added)} to tables")
    if removed:
        debug(f"[prune_unreferenced_tables] removed {sorted(removed)} from tables")
    if added or removed:
        debug(f"[prune_unreferenced_tables] final tables: {main_tables}")

    new_cte_steps = []
    for step in intent.cte_steps or []:
        step_refs = (
            _collect_referenced_tables(
                step.select_cols,
                step.order_by_cols,
                step.group_by_cols,
                step.filters_param,
                step.having_param,
            )
            | cte_names
        )
        step_original = set(step.tables or [])
        step_added = step_refs - step_original - cte_names
        step_removed = step_original - step_refs
        step_tables = sorted(step_refs)
        if step_added:
            debug(f"[prune_unreferenced_tables] CTE '{step.cte_name}' added {sorted(step_added)} to tables")
        if step_removed:
            debug(f"[prune_unreferenced_tables] CTE '{step.cte_name}' removed {sorted(step_removed)} from tables")
        new_cte_steps.append(replace(step, tables=step_tables))

    return replace(
        intent,
        tables=main_tables,
        select_cols=select_cols,
        cte_steps=new_cte_steps,
    )
|
|
1251
|
+
|
|
1252
|
+
|
|
1253
|
+
def _correct_value_case(raw_value: str, top_k: list[str]) -> str | None:
|
|
1254
|
+
"""Return the case-corrected version of a filter value using
|
|
1255
|
+
profiled sample values.
|
|
1256
|
+
|
|
1257
|
+
Performs a case-insensitive comparison of ``raw_value`` against each
|
|
1258
|
+
entry in ``top_k``. When a match is found whose casing differs from
|
|
1259
|
+
the original, the sample value is returned so the filter uses the
|
|
1260
|
+
casing that actually appears in the database.
|
|
1261
|
+
|
|
1262
|
+
Args: raw_value: The filter value string extracted from the user
|
|
1263
|
+
question. top_k: Profiled sample values
|
|
1264
|
+
(``ColumnMetadata.top_k_values``) for the column being
|
|
1265
|
+
filtered.
|
|
1266
|
+
|
|
1267
|
+
Returns: The matching sample string with correct casing, or
|
|
1268
|
+
``None`` when no case-insensitive match is found or the casing
|
|
1269
|
+
already matches.
|
|
1270
|
+
"""
|
|
1271
|
+
if not raw_value or not top_k:
|
|
1272
|
+
return None
|
|
1273
|
+
lower_val = raw_value.lower()
|
|
1274
|
+
for sample in top_k:
|
|
1275
|
+
if sample.lower() == lower_val and sample != raw_value:
|
|
1276
|
+
return sample
|
|
1277
|
+
return None
|
|
1278
|
+
|
|
1279
|
+
|
|
1280
|
+
def _match_enum_value(raw_value: str, col_meta: ColumnMetadata, schema_graph: SchemaGraph) -> str | None:
    """Map a filter value onto the column's enum member, ignoring case.

    The column's ``data_type`` (lowercased; ``None`` treated as ``""``) is
    looked up in ``schema_graph.enum_values``. The first member equal to
    ``raw_value`` ignoring case is returned with its database spelling.

    Args:
        raw_value: Filter value string taken from the user question.
        col_meta: Column metadata of the filtered column.
        schema_graph: Schema graph holding ``enum_values``.

    Returns:
        The correctly-cased enum member, or ``None`` when the column is not
        an enum type or no member matches.
    """
    enum_catalog = schema_graph.enum_values
    if not enum_catalog:
        return None
    members = enum_catalog.get((col_meta.data_type or "").lower())
    if not members:
        return None
    wanted = raw_value.lower()
    return next((member for member in members if member.lower() == wanted), None)
|
|
1306
|
+
|
|
1307
|
+
|
|
1308
|
+
def _extract_question_casing(raw_value: str, question: str) -> str | None:
|
|
1309
|
+
"""Extract the user's original casing for a filter value from the
|
|
1310
|
+
question text.
|
|
1311
|
+
|
|
1312
|
+
Searches *question* for *raw_value* (case-insensitive). When a
|
|
1313
|
+
match is found and the matched substring contains at least one
|
|
1314
|
+
uppercase letter, returns it so the filter preserves the user's
|
|
1315
|
+
intended casing. When the matched substring is entirely lowercase
|
|
1316
|
+
the result is ``None`` so that downstream tiers (e.g. ILIKE
|
|
1317
|
+
fallback) can still apply.
|
|
1318
|
+
|
|
1319
|
+
Args: raw_value: Filter value string to locate. question:
|
|
1320
|
+
Original natural-language question.
|
|
1321
|
+
|
|
1322
|
+
Returns: The matched substring from *question* with its original
|
|
1323
|
+
casing, or ``None`` when no match is found, the match is all-
|
|
1324
|
+
lowercase, or the casing already equals *raw_value*.
|
|
1325
|
+
"""
|
|
1326
|
+
if not raw_value or not question:
|
|
1327
|
+
return None
|
|
1328
|
+
q_lower = question.lower()
|
|
1329
|
+
val_lower = raw_value.lower()
|
|
1330
|
+
idx = q_lower.find(val_lower)
|
|
1331
|
+
if idx < 0:
|
|
1332
|
+
return None
|
|
1333
|
+
matched = question[idx : idx + len(val_lower)]
|
|
1334
|
+
if matched == matched.lower():
|
|
1335
|
+
return None
|
|
1336
|
+
if matched == raw_value:
|
|
1337
|
+
return None
|
|
1338
|
+
return matched
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
def _resolve_filter_list_cascade(
    filters: list[FilterParam],
    schema_graph: SchemaGraph,
    question: str,
) -> tuple[list[FilterParam], bool]:
    """Normalize the casing of string/enum filter values.

    Only filters with a non-``None`` ``raw_value``, a ``value_type`` of
    ``"string"`` or ``"enum"``, a dotted ``left_expr.primary_column`` and
    known column metadata are considered. Each string value (scalar, or
    each string element of a list) resolves as follows:

    1. Tier 1 -- when ``_match_enum_value`` finds an enum member, that
       member is used, preserving exact database casing.
    2. Otherwise the value is lowercased, so the SQL prompt can pair it
       with a ``LOWER(column)`` wrapper for case-insensitive comparison.

    Non-string values are left untouched.

    Args:
        filters: ``FilterParam`` objects to inspect and correct.
        schema_graph: Schema graph with enum and column metadata.
        question: Original natural-language question. It is not consulted
            by this function; it is kept for interface compatibility.

    Returns:
        ``(resolved_filters, changed)``, where ``changed`` is ``True`` when
        at least one filter was replaced.
    """

    def _resolve_value(value: str, col_meta) -> tuple[str, bool]:
        # Returns (resolved value, whether it came from an enum match).
        enum_member = _match_enum_value(value, col_meta, schema_graph)
        if enum_member is not None:
            return enum_member, True
        return value.lower(), False

    resolved: list[FilterParam] = []
    any_changed = False

    for fp in filters:
        if fp.raw_value is None or fp.value_type not in {"string", "enum"}:
            resolved.append(fp)
            continue
        col = fp.left_expr.primary_column
        if "." not in col:
            resolved.append(fp)
            continue
        table_name, column_name = col.split(".", 1)
        col_meta = schema_graph.get_column(table_name, column_name)
        if not col_meta:
            resolved.append(fp)
            continue

        value = fp.raw_value
        if isinstance(value, list):
            new_vals = [_resolve_value(v, col_meta)[0] if isinstance(v, str) else v for v in value]
            list_changed = any(isinstance(old, str) and new != old for old, new in zip(value, new_vals))
            if list_changed:
                resolved.append(replace(fp, raw_value=new_vals))
                any_changed = True
                debug(f"[intent_repair.resolve_filter_list_cascade] resolved list values on {col}")
            else:
                resolved.append(fp)
            continue

        if not isinstance(value, str):
            resolved.append(fp)
            continue

        new_value, from_enum = _resolve_value(value, col_meta)
        if new_value == value:
            resolved.append(fp)
            continue
        resolved.append(replace(fp, raw_value=new_value))
        any_changed = True
        if from_enum:
            debug(f"[intent_repair.resolve_filter_list_cascade] enum {col}: '{fp.raw_value}' -> '{new_value}'")
        else:
            debug(f"[intent_repair.resolve_filter_list_cascade] lower {col}: '{fp.raw_value}' -> '{new_value}'")

    return resolved, any_changed
|
|
1424
|
+
|
|
1425
|
+
|
|
1426
|
+
def resolve_filter_value_case(intent: RuntimeIntent, schema_graph: SchemaGraph, question: str) -> RuntimeIntent:
    """Normalise the casing of string filter values in the main query and every CTE step.

    Each filter list is run through ``_resolve_filter_list_cascade``. Values
    that match a known enum value keep the database's exact casing. Other
    string values are lowercased so the SQL generator can pair them with
    ``LOWER(column)`` for a case-insensitive comparison.

    Args:
        intent: ``RuntimeIntent`` whose filter values may have incorrect casing.
        schema_graph: Schema graph with enum and profiled column data.
        question: Original natural-language question, forwarded to the cascade.

    Returns:
        The intent produced by ``_apply_filters_to_main_and_ctes`` once the
        cascade has been applied to the main filters and each CTE's filters.
    """
    return _apply_filters_to_main_and_ctes(
        intent,
        lambda filter_list: _resolve_filter_list_cascade(filter_list, schema_graph, question),
    )
|
|
1448
|
+
|
|
1449
|
+
|
|
1450
|
+
def _coerce_element(val: Any, data_type: str) -> Any:
    """Cast one IN-list element to a number when the column is numeric.

    Only string elements of a column whose ``data_type`` is listed in
    ``NUMERIC_DATA_TYPES`` are touched. A string containing ``"."`` is
    parsed with ``float``; any other string is parsed with ``int``.
    Surrounding whitespace is ignored. Strings that fail to parse, and
    every non-string element, come back unchanged.

    Args:
        val: A single element taken from an IN-list ``raw_value``.
        data_type: The column's lowercased data type.

    Returns:
        The numeric value on a successful cast, otherwise ``val`` itself.
    """
    # Non-numeric columns and non-string values (ints, floats, anything else)
    # are passed through untouched.
    if data_type not in NUMERIC_DATA_TYPES or not isinstance(val, str):
        return val
    text = val.strip()
    caster = float if "." in text else int
    try:
        return caster(text)
    except (ValueError, OverflowError):
        return val
|
|
1476
|
+
|
|
1477
|
+
|
|
1478
|
+
def _consolidate_in_list(vals: list, data_type: str) -> str:
|
|
1479
|
+
"""Convert a list of IN-values into a formatted SQL-ready string.
|
|
1480
|
+
|
|
1481
|
+
String elements are wrapped in single quotes (``'R', 'PG-13'``),
|
|
1482
|
+
while numeric elements are joined as-is (``1, 2, 3``).
|
|
1483
|
+
|
|
1484
|
+
Args: vals: List of IN-list elements (already type-coerced).
|
|
1485
|
+
data_type: Lowercased column data_type for formatting decisions.
|
|
1486
|
+
|
|
1487
|
+
Returns: Comma-separated string suitable for direct SQL
|
|
1488
|
+
substitution.
|
|
1489
|
+
"""
|
|
1490
|
+
if all(isinstance(v, (int, float)) for v in vals):
|
|
1491
|
+
return ", ".join(str(v) for v in vals)
|
|
1492
|
+
parts: list[str] = []
|
|
1493
|
+
for v in vals:
|
|
1494
|
+
if isinstance(v, str):
|
|
1495
|
+
parts.append(f"'{v}'")
|
|
1496
|
+
else:
|
|
1497
|
+
parts.append(str(v))
|
|
1498
|
+
return ", ".join(parts)
|
|
1499
|
+
|
|
1500
|
+
|
|
1501
|
+
def _normalize_in_types_for_list(
    filters: list[FilterParam],
    schema_graph: SchemaGraph,
) -> tuple[list[FilterParam], bool]:
    """Coerce IN / NOT IN list elements to their column type and consolidate them to strings.

    The filters handled are those with ``op`` equal to ``in`` or ``not in``
    (case-insensitive), a list ``raw_value``, and a ``table.column`` primary
    column. For each one:

    1. every element is passed through ``_coerce_element`` using the
       column's data type (looked up via ``schema_graph.get_column``);
    2. the list is replaced by the string from ``_consolidate_in_list``, so
       ``substitute_params`` can substitute it directly.

    All other filters are returned unchanged.

    Args:
        filters: Filter params to inspect and coerce.
        schema_graph: Schema graph for column type lookup.

    Returns:
        Tuple of ``(coerced_filters, changed)``. ``changed`` is ``True`` when
        at least one IN-list was consolidated.
    """
    new_filters: list[FilterParam] = []
    changed = False
    for fp in filters:
        # ``op`` may be unset; treat that as "not an IN filter" rather than
        # crashing (matches the guard in ``_decompose_in_list``).
        op_lower = (fp.op or "").lower()
        if op_lower not in {"in", "not in"} or not isinstance(fp.raw_value, list):
            new_filters.append(fp)
            continue
        col = fp.left_expr.primary_column
        if not col or "." not in col:
            new_filters.append(fp)
            continue
        table, column = col.split(".", 1)
        col_meta = schema_graph.get_column(table, column)
        dtype = (col_meta.data_type or "").lower() if col_meta else ""
        coerced = [_coerce_element(v, dtype) for v in fp.raw_value]
        consolidated = _consolidate_in_list(coerced, dtype)
        # raw_value is a list and consolidated is a str, so this is always a
        # rewrite. The previous inequality check could never be False.
        new_filters.append(replace(fp, raw_value=consolidated))
        changed = True
        debug(f"[intent_resolve_normalize_in_types_for_list] {col}: {fp.raw_value!r} -> {consolidated!r}")
    return new_filters, changed
|
|
1540
|
+
|
|
1541
|
+
|
|
1542
|
+
def normalize_in_filter_types(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Type-coerce and consolidate IN / NOT IN lists, then decompose small lists.

    Every filter list in the main query and in each CTE step is first passed
    through ``_normalize_in_types_for_list``. That step coerces list elements
    to the column type and renders each list as a single SQL-ready string.
    The resulting intent is then handed to ``decompose_in_not_in_filters``.

    NOTE(review): the consolidation step turns the list ``raw_value`` of every
    IN / NOT IN filter on a ``table.column`` reference into a string. The
    decomposition step only expands list-valued filters. As written, only
    IN lists left as lists (e.g. on unqualified columns) can be decomposed.
    Confirm this ordering is intended.

    Args:
        intent: RuntimeIntent whose IN-list filter values may need type
            coercion.
        schema_graph: Schema graph for column type lookup.

    Returns:
        Updated RuntimeIntent with coerced and consolidated IN-list values.
    """

    def _coerce_and_consolidate(filter_list: list[FilterParam]) -> tuple[list[FilterParam], bool]:
        return _normalize_in_types_for_list(filter_list, schema_graph)

    consolidated_intent = _apply_filters_to_main_and_ctes(intent, _coerce_and_consolidate)
    return decompose_in_not_in_filters(consolidated_intent)
|
|
1559
|
+
|
|
1560
|
+
|
|
1561
|
+
def _decompose_in_list(
    filters: list[FilterParam],
    max_list_size: int = 10,
) -> list[FilterParam]:
    """Expand short IN / NOT IN lists into one primitive comparison per element.

    A filter whose ``op`` is ``in`` / ``not in`` (case-insensitive) and whose
    ``raw_value`` is a non-empty list of at most ``max_list_size`` elements is
    expanded as follows:

    - ``IN`` becomes ``=`` comparisons joined with ``OR``;
    - ``NOT IN`` becomes ``!=`` comparisons joined with ``AND``.

    The first generated filter keeps the original ``bool_op`` (defaulting to
    ``AND``). Every other filter is kept as-is.

    NOTE(review): the generated filters are appended flat, each carrying its
    own ``bool_op``. Whether the SQL generator parenthesises the OR-run
    produced for ``IN`` depends on how ``bool_op`` is rendered downstream.
    Confirm this, because otherwise ``A AND x IN (1, 2)`` could render as
    ``A AND x = 1 OR x = 2``.
    """
    expanded: list[FilterParam] = []
    for original in filters:
        values = original.raw_value
        operator = (original.op or "").lower()
        eligible = (
            isinstance(values, list)
            and operator in {"in", "not in"}
            and 0 < len(values) <= max_list_size
        )
        if not eligible:
            expanded.append(original)
            continue
        is_in = operator == "in"
        joiner = "OR" if is_in else "AND"
        leading_bool_op = original.bool_op or "AND"
        expanded.extend(
            replace(
                original,
                op="=" if is_in else "!=",
                raw_value=item,
                bool_op=joiner if position else leading_bool_op,
            )
            for position, item in enumerate(values)
        )
    return expanded
|
|
1592
|
+
|
|
1593
|
+
|
|
1594
|
+
def decompose_in_not_in_filters(intent: RuntimeIntent) -> RuntimeIntent:
    """Decompose small IN / NOT IN lists in the main query and in every CTE step.

    Both the main filters and each CTE's filters are passed through
    ``_decompose_in_list``. When the intent has no CTE steps, its original
    ``cte_steps`` value (``None`` or empty) is preserved.
    """
    decomposed_ctes: list[RuntimeCteStep] = [
        replace(step, filters_param=_decompose_in_list(step.filters_param or []))
        for step in intent.cte_steps or []
    ]
    return replace(
        intent,
        filters_param=_decompose_in_list(intent.filters_param or []),
        cte_steps=decomposed_ctes or intent.cte_steps,
    )
|
|
1602
|
+
|
|
1603
|
+
|
|
1604
|
+
def _resolve_boolean_value(raw_value: Any, col_meta: ColumnMetadata) -> tuple[Any, str] | None:
    """Map a filter value to a Python ``bool`` for a native boolean column.

    Only columns whose lowercased ``data_type`` contains ``"bool"`` are
    eligible. A value that is already a ``bool`` is returned as-is. Any other
    value is stringified, lowercased and stripped, then looked up in
    ``BOOLEAN_TRUTHY_VALUES`` (checked first) and ``BOOLEAN_FALSY_VALUES``.

    Args:
        raw_value: The current filter value (int, str, bool, or other).
        col_meta: Column metadata for the filter's target column.

    Returns:
        ``(resolved_bool, "boolean")`` on success, or ``None`` when the column
        is not boolean or the value matches neither lookup set.
    """
    if "bool" not in (col_meta.data_type or "").lower():
        return None
    if isinstance(raw_value, bool):
        return raw_value, "boolean"
    normalized = str(raw_value).lower().strip()
    for candidates, resolved in ((BOOLEAN_TRUTHY_VALUES, True), (BOOLEAN_FALSY_VALUES, False)):
        if normalized in candidates:
            return resolved, "boolean"
    return None
|
|
1632
|
+
|
|
1633
|
+
|
|
1634
|
+
def _normalize_boolean_filter_list(
    filters: list[FilterParam], schema_graph: SchemaGraph
) -> tuple[list[FilterParam], bool]:
    """Normalise boolean filter values in a list of ``FilterParam`` objects.

    A filter is rewritten when all of the following hold:

    - its target is a ``table.column`` reference to a native boolean column;
    - its ``raw_value`` is resolvable by ``_resolve_boolean_value``;
    - it is not already canonical (a Python ``bool`` with
      ``value_type == "boolean"``).

    The rewrite sets ``raw_value`` to ``True``/``False`` and ``value_type`` to
    ``"boolean"``. Filters with a ``None`` value, an unqualified or unknown
    column, or an unmappable value are kept unchanged.

    Args:
        filters: List of ``FilterParam`` objects to inspect and correct.
        schema_graph: ``SchemaGraph`` providing column metadata.

    Returns:
        Tuple of ``(normalised_filters, changed)``, where *changed* is ``True``
        only when at least one filter was actually rewritten.
    """
    new_filters: list[FilterParam] = []
    changed = False
    for fp in filters:
        if fp.raw_value is None:
            new_filters.append(fp)
            continue
        col = fp.left_expr.primary_column
        if not col or "." not in col:
            new_filters.append(fp)
            continue
        table, column = col.split(".", 1)
        col_meta = schema_graph.get_column(table, column)
        if not col_meta:
            new_filters.append(fp)
            continue
        resolved = _resolve_boolean_value(fp.raw_value, col_meta)
        if resolved is None:
            new_filters.append(fp)
            continue
        bool_val, vtype = resolved
        # _resolve_boolean_value echoes bool inputs back unchanged. Rewriting an
        # already-canonical filter would report a spurious change.
        if isinstance(fp.raw_value, bool) and fp.raw_value is bool_val and fp.value_type == vtype:
            new_filters.append(fp)
            continue
        new_filters.append(replace(fp, raw_value=bool_val, value_type=vtype))
        changed = True
        debug(
            f"[intent_resolve_normalize_boolean_filter_list] {col}: "
            f"{fp.raw_value!r} ({fp.value_type}) → {bool_val!r} ({vtype})"
        )
    return new_filters, changed
|
|
1678
|
+
|
|
1679
|
+
|
|
1680
|
+
def normalize_boolean_filter_values(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Normalise boolean filter values in the main query and every CTE step.

    LLM-extracted intents often encode boolean filters as integers or
    strings. For native boolean columns, ``_normalize_boolean_filter_list``
    turns such values into Python ``True``/``False`` with
    ``value_type="boolean"``. This lets ``substitute_params`` emit the SQL
    literal ``TRUE`` or ``FALSE`` rather than an integer or quoted string.

    Args:
        intent: ``RuntimeIntent`` whose filter values may need boolean
            normalisation.
        schema_graph: ``SchemaGraph`` providing column data-type metadata.

    Returns:
        The intent produced by ``_apply_filters_to_main_and_ctes`` after
        normalising every filter list.
    """
    return _apply_filters_to_main_and_ctes(
        intent,
        lambda filter_list: _normalize_boolean_filter_list(filter_list, schema_graph),
    )
|
|
1703
|
+
|
|
1704
|
+
|
|
1705
|
+
def _normalize_null_filter_list(
    filters: list[FilterParam],
) -> tuple[list[FilterParam], bool]:
    """Normalise IS NULL / IS NOT NULL filters to canonical form.

    A filter needs fixing when its operator is ``is null`` or
    ``is not null`` (matched case-insensitively, ignoring surrounding
    whitespace) and it does not already carry ``value_type == "null"`` and
    ``raw_value is None``. Such a filter is rewritten with those values.
    The operator itself is left untouched.

    Args:
        filters: Filter params to inspect.

    Returns:
        Tuple of ``(normalised_filters, changed)``, where ``changed`` is
        ``True`` when at least one filter was rewritten.
    """
    result: list[FilterParam] = []
    changed = False
    for fp in filters:
        # Other normalisers in this module lowercase ``op`` before matching.
        # Do the same so ``IS NULL`` is handled like ``is null``.
        op_lower = (fp.op or "").strip().lower()
        if op_lower in ("is null", "is not null") and (fp.value_type != "null" or fp.raw_value is not None):
            result.append(replace(fp, value_type="null", raw_value=None))
            changed = True
            continue
        result.append(fp)
    return result, changed
|
|
1724
|
+
|
|
1725
|
+
|
|
1726
|
+
def normalize_null_filter_values(intent: RuntimeIntent) -> RuntimeIntent:
    """Canonicalise IS NULL / IS NOT NULL filters in the main query and every CTE step.

    Each filter list is run through ``_normalize_null_filter_list``, so that
    null-operator filters carry ``value_type="null"`` and ``raw_value=None``.
    This keeps downstream validation from flagging a spurious type mismatch.
    """
    normalized_intent = _apply_filters_to_main_and_ctes(intent, _normalize_null_filter_list)
    return normalized_intent
|