aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1733 @@
1
+ """Structural intent repairs and value normalization.
2
+
3
+ Repairs foreign key filter type mismatches where an integer column is compared to a string value by rewriting filters to use descriptive columns and expands foreign key selects to descriptive columns.
4
+
5
+ Strips spurious GROUP BY clauses, impossible HAVING conditions such as COUNT < 0, and hallucinated SQL keywords in table names.
6
+
7
+ Strips foreign key equi-join conditions from filters, prunes unreferenced tables, and normalizes boolean filter values, IN-list types, and filter value casing against schema statistics and question text.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import re
13
+ from collections.abc import Callable
14
+ from dataclasses import replace
15
+ from typing import Any
16
+
17
+ from .config import (
18
+ BOOLEAN_FALSY_VALUES,
19
+ BOOLEAN_TRUTHY_VALUES,
20
+ DISTINCT_RE,
21
+ IMPOSSIBLE_HAVING_RE,
22
+ NUMERIC_DATA_TYPES,
23
+ NUMERIC_LITERAL_RE,
24
+ RANGE_OPS,
25
+ SQL_KEYWORDS,
26
+ TOP_N_RE,
27
+ PolicyConfig,
28
+ )
29
+ from .contracts_base import ColumnMetadata, SchemaGraph, TableMetadata
30
+ from .contracts_core import (
31
+ FilterParam,
32
+ HavingParam,
33
+ MulGroup,
34
+ NormalizedExpr,
35
+ RuntimeCteStep,
36
+ RuntimeIntent,
37
+ SelectCol,
38
+ )
39
+ from .core_utils import debug
40
+ from .intent_expr import extract_columns_from_expr, replace_refs_in_expr
41
+
42
+
43
+ def _english_plurals(word: str) -> list[str]:
44
+ """Return *word* together with its common English plural forms.
45
+
46
+ Covers consonant-y to -ies, sibilant endings to -es, and the default -s suffix and always returns the original word as the first element so callers can iterate a single list.
47
+ """
48
+ forms = [word]
49
+ w = word.lower()
50
+ if w.endswith("y") and len(w) > 2 and w[-2] not in "aeiou":
51
+ forms.append(w[:-1] + "ies")
52
+ elif w.endswith(("s", "sh", "ch", "x", "z")):
53
+ forms.append(w + "es")
54
+ else:
55
+ forms.append(w + "s")
56
+ return forms
57
+
58
+
59
+ def _apply_filters_to_main_and_ctes(
60
+ intent: RuntimeIntent,
61
+ process_fn: Callable[[list[FilterParam]], tuple[list[FilterParam], bool]],
62
+ ) -> RuntimeIntent:
63
+ """Apply a filter processor to the main intent and each CTE, merging results."""
64
+ new_fp, main_changed = process_fn(intent.filters_param or [])
65
+ if not intent.cte_steps:
66
+ return replace(intent, filters_param=new_fp) if main_changed else intent
67
+ new_cte_steps = []
68
+ cte_changed = False
69
+ for cte in intent.cte_steps:
70
+ cte_fp, c = process_fn(cte.filters_param or [])
71
+ if c:
72
+ cte_changed = True
73
+ new_cte_steps.append(replace(cte, filters_param=cte_fp))
74
+ if not main_changed and not cte_changed:
75
+ return intent
76
+ result = replace(intent, filters_param=new_fp)
77
+ if cte_changed:
78
+ result = replace(result, cte_steps=new_cte_steps)
79
+ return result
80
+
81
+
82
def _dedup_contradictory_filters_list(
    filters: list[FilterParam],
) -> tuple[list[FilterParam], bool]:
    """Drop range filters on columns that also carry an equality filter.

    A column constrained with ``=`` makes any additional range operator
    (membership in ``RANGE_OPS``) on that same column either redundant or
    contradictory, so the range filter is discarded.

    Args:
        filters: Filter parameters to inspect.

    Returns:
        ``(kept_filters, changed)``. When no column has an equality filter
        the input list object itself is returned with ``changed=False``.
    """

    def _column_of(param: FilterParam) -> str:
        return param.left_expr.primary_column or ""

    pinned: set[str] = set()
    for param in filters:
        column = _column_of(param)
        if column and param.op == "=":
            pinned.add(column)

    if not pinned:
        return filters, False

    survivors: list[FilterParam] = []
    for param in filters:
        column = _column_of(param)
        if column in pinned and param.op in RANGE_OPS:
            debug(f"[intent_repair.dedup_contradictory_filters] dropping {param.op} on '{column}' that contradicts =")
        else:
            survivors.append(param)
    return survivors, len(survivors) != len(filters)
108
+
109
+
110
def dedup_contradictory_filters(intent: RuntimeIntent) -> RuntimeIntent:
    """Drop range filters that clash with an equality on the same column.

    Applies ``_dedup_contradictory_filters_list`` to the main query's filters
    and to the filters of every CTE step.
    """
    return _apply_filters_to_main_and_ctes(
        intent,
        _dedup_contradictory_filters_list,
    )
113
+
114
+
115
+ def _is_null_value(raw_value: Any) -> bool:
116
+ """Return True if the raw filter value represents NULL."""
117
+ if raw_value is None:
118
+ return True
119
+ if isinstance(raw_value, str) and raw_value.strip().lower() == "null":
120
+ return True
121
+ return False
122
+
123
+
124
def repair_null_equality_filters(intent: RuntimeIntent) -> RuntimeIntent:
    """Turn ``= NULL`` / ``!= NULL`` filters into IS NULL / IS NOT NULL.

    A filter with op ``=`` and a null value is rewritten to ``is null``;
    ``!=`` or ``<>`` against a null value becomes ``is not null``. The
    rewrite covers the main query and all CTE steps.

    Args:
        intent: RuntimeIntent to inspect.

    Returns:
        RuntimeIntent with corrected null operators, or the original intent
        when no filter needed fixing.
    """
    fixer = _repair_null_equality_list
    return _apply_filters_to_main_and_ctes(intent, fixer)
139
+
140
+
141
def _repair_null_equality_list(
    filters: list[FilterParam],
) -> tuple[list[FilterParam], bool]:
    """Rewrite equality/inequality comparisons against NULL in one filter list.

    Args:
        filters: Filter parameters to inspect.

    Returns:
        ``(repaired_filters, changed)`` where every ``=`` / ``!=`` / ``<>``
        filter whose raw value represents NULL has been replaced by a copy
        with op ``is null`` / ``is not null``, ``raw_value=None`` and
        ``value_type="null"``. Other filters are passed through untouched.
    """
    fixed: list[FilterParam] = []
    touched = False
    for param in filters:
        if param.op == "=":
            null_op = "is null"
        elif param.op in ("!=", "<>"):
            null_op = "is not null"
        else:
            null_op = None

        if null_op is None or not _is_null_value(param.raw_value):
            fixed.append(param)
            continue

        fixed.append(replace(param, op=null_op, raw_value=None, value_type="null"))
        touched = True
    return fixed, touched
170
+
171
+
172
+ def _infer_cte_output_columns(cte: Any) -> list[str]:
173
+ """Derive output column names from a CTE's select_cols.
174
+
175
+ When the LLM omits output_columns this falls back to extracting the trailing column identifier from each select expression and prepends the aggregation function name for aggregated columns to avoid ambiguity.
176
+
177
+ Args:
178
+
179
+ cte: A RuntimeCteStep with populated select_cols.
180
+
181
+ Returns:
182
+
183
+ List of bare column-name strings suitable for use as CTE output aliases.
184
+ """
185
+ names: list[str] = []
186
+ for sc in cte.select_cols or []:
187
+ col = sc.expr.primary_column if sc.expr else ""
188
+ if not col:
189
+ continue
190
+ bare = col.split(".")[-1].strip().lower()
191
+ if sc.is_aggregated and sc.expr.agg_func:
192
+ bare = f"{sc.expr.agg_func.lower()}_{bare}"
193
+ if bare and bare not in names:
194
+ names.append(bare)
195
+ return names
196
+
197
+
198
+ def _qualify_term(term: str, output_to_cte: dict[str, str]) -> str:
199
+ """Prefix an unqualified column reference with its CTE source name.
200
+
201
+ If term, which is a single MulGroup.multiply or MulGroup.divide entry, contains an unqualified column name that matches a CTE output column, the column portion is rewritten to cte_name.column while already qualified terms containing a dot are returned unchanged, and function-wrapped columns such as SUM(total_amount) are handled by replacing the innermost identifier.
202
+
203
+ Args:
204
+
205
+ term: Raw expression term.
206
+
207
+ output_to_cte: Mapping of lowered bare output column name to the owning CTE name.
208
+
209
+ Returns:
210
+
211
+ The possibly rewritten term string.
212
+ """
213
+ for col_lower, cte_name in output_to_cte.items():
214
+ pat = re.compile(
215
+ r"(?<!\.)(?<![A-Za-z0-9_])" + re.escape(col_lower) + r"(?![A-Za-z0-9_])",
216
+ re.IGNORECASE,
217
+ )
218
+ if pat.search(term):
219
+ term = pat.sub(f"{cte_name}.{col_lower}", term)
220
+ return term
221
+
222
+
223
def _qualify_expr(expr: NormalizedExpr, output_to_cte: dict[str, str]) -> NormalizedExpr:
    """Return a copy of *expr* with every term CTE-qualified.

    Each MulGroup in ``add_groups`` and ``sub_groups`` is rebuilt with its
    ``multiply`` and ``divide`` terms passed through ``_qualify_term``.
    """

    def _requalify(terms: list[str]) -> list[str]:
        return [_qualify_term(entry, output_to_cte) for entry in terms]

    def _rebuild(groups: list[MulGroup]) -> list[MulGroup]:
        return [
            replace(grp, multiply=_requalify(grp.multiply), divide=_requalify(grp.divide))
            for grp in groups
        ]

    return replace(
        expr,
        add_groups=_rebuild(expr.add_groups),
        sub_groups=_rebuild(expr.sub_groups),
    )
238
+
239
+
240
def qualify_cte_output_columns(intent: RuntimeIntent) -> RuntimeIntent:
    """Qualify main-query references to CTE output columns with the CTE name.

    When the main query refers to a CTE output column without the CTE-name
    prefix, the reference is rewritten to ``cte_name.column``. Only the main
    query's ``select_cols``, ``group_by_cols`` and ``order_by_cols`` are
    touched; CTE steps themselves are not rewritten. An entry whose primary
    column is already qualified with one of the main query's tables is left
    as is.

    If several CTEs expose the same output column name, the last CTE in
    ``cte_steps`` wins.

    Args:
        intent: RuntimeIntent with CTE steps whose output columns may be
            referenced without qualification in the main query.

    Returns:
        Updated RuntimeIntent with qualified main-query expressions, or the
        original intent if nothing changed.
    """
    cte_steps = intent.cte_steps or []
    if not cte_steps:
        return intent

    output_to_cte: dict[str, str] = {}
    for cte in cte_steps:
        # Fall back to names derived from select_cols when the step does not
        # declare its outputs explicitly.
        outputs = cte.output_columns or _infer_cte_output_columns(cte)
        for oc in outputs:
            bare = oc.split(".")[-1].strip().lower()
            if bare:
                output_to_cte[bare] = cte.cte_name
    if not output_to_cte:
        return intent

    # Built once: the set of real main-query table names, lower-cased.
    main_tables_lower = {t.lower() for t in (intent.tables or [])}

    def _should_skip(term: str | None) -> bool:
        """Return True if the term is already qualified with a real table."""
        if not term or "." not in term:
            return False
        return term.split(".")[0].lower() in main_tables_lower

    new_select_cols = [
        sc if _should_skip(sc.expr.primary_column) else replace(sc, expr=_qualify_expr(sc.expr, output_to_cte))
        for sc in (intent.select_cols or [])
    ]
    new_group_by = [
        g if _should_skip(g.primary_column) else _qualify_expr(g, output_to_cte)
        for g in (intent.group_by_cols or [])
    ]
    new_order_by = [
        obc if _should_skip(obc.expr.primary_column) else replace(obc, expr=_qualify_expr(obc.expr, output_to_cte))
        for obc in (intent.order_by_cols or [])
    ]

    if (
        new_select_cols == intent.select_cols
        and new_group_by == intent.group_by_cols
        and new_order_by == intent.order_by_cols
    ):
        return intent

    debug("[qualify_cte_output_columns] qualified unqualified CTE output references in main query")
    return replace(
        intent,
        select_cols=new_select_cols,
        group_by_cols=new_group_by,
        order_by_cols=new_order_by,
    )
315
+
316
+
317
# Column value types that may be picked as a descriptive column.
DESCRIPTIVE_ALLOWED_VALUE_TYPES = frozenset(("string", "integer"))
# Column value types that are rejected outright as descriptive columns.
DESCRIPTIVE_EXCLUDED_VALUE_TYPES = frozenset(("date", "boolean", "number"))
319
+
320
+
321
+ def _descriptive_column_score(col_name: str, col_meta: ColumnMetadata) -> tuple[int, int]:
322
+ """Score a column for use as a descriptive column; higher is better.
323
+
324
+ Prefers name-like columns (name, title, first_name, last_name) and
325
+ higher distinct_count. No maximum cardinality cap.
326
+ """
327
+ name_lower = col_name.lower()
328
+ name_score = 0
329
+ if "name" in name_lower or "title" in name_lower:
330
+ name_score = 2
331
+ elif "first_name" in name_lower or "last_name" in name_lower:
332
+ name_score = 3
333
+ dc = col_meta.distinct_count or 0
334
+ return (name_score, dc)
335
+
336
+
337
def best_descriptive_columns(
    table: str,
    schema_graph: SchemaGraph,
    exclude: set[str],
    max_count: int = 2,
) -> list[str]:
    """Return up to *max_count* best descriptive columns for the table.

    A column qualifies when it is neither a primary nor a foreign key, its
    ``table.column`` form is not in *exclude*, its value type is one of
    ``DESCRIPTIVE_ALLOWED_VALUE_TYPES``, and its ``distinct_ratio`` is not
    below 0.95 (a missing ratio is accepted). Qualifying columns are ranked
    by ``_descriptive_column_score``, best first.

    When *max_count* >= 2 and at least two columns qualify, a name-like pair
    reported by ``_best_composite_name_pair`` takes precedence and is
    returned on its own.

    Args:
        table: Table name to look up in ``schema_graph.tables``.
        schema_graph: Schema metadata.
        exclude: Fully-qualified columns that must not be returned.
        max_count: Maximum number of columns to return.

    Returns:
        Bare column names, or an empty list when the table is unknown or no
        column qualifies.
    """
    tbl_meta = schema_graph.tables.get(table)
    if not tbl_meta:
        return []

    def _eligible(col_name: str, col_meta: ColumnMetadata) -> bool:
        if col_meta.is_primary_key or col_meta.is_foreign_key:
            return False
        if f"{table}.{col_name}" in exclude:
            return False
        value_type = (col_meta.value_type or "").lower()
        if value_type in DESCRIPTIVE_EXCLUDED_VALUE_TYPES:
            return False
        if value_type not in DESCRIPTIVE_ALLOWED_VALUE_TYPES:
            return False
        ratio = col_meta.distinct_ratio
        return ratio is None or not ratio < 0.95

    ranked = sorted(
        ((name, meta) for name, meta in tbl_meta.columns.items() if _eligible(name, meta)),
        key=lambda entry: _descriptive_column_score(entry[0], entry[1]),
        reverse=True,
    )
    if not ranked:
        return []

    if max_count >= 2 and len(ranked) >= 2:
        pair = _best_composite_name_pair(tbl_meta, ranked)
        if pair is not None:
            return list(pair)
    return [name for name, _ in ranked[:max_count]]
380
+
381
+
382
def _best_composite_name_pair(
    tbl_meta: TableMetadata,
    candidates: list[tuple[str, ColumnMetadata]],
) -> tuple[str, str] | None:
    """Find a name-like column pair whose combined uniqueness beats any single column.

    Only candidates with a name score of at least 2 are paired. Pairs are
    tried in candidate order, and the first one whose entry in
    ``tbl_meta.composite_descriptive_ratios`` (looked up in either key order)
    exceeds the highest individual ``distinct_ratio`` among *all* candidates
    is returned.

    Returns:
        The ``(first, second)`` column names, or ``None`` when fewer than two
        name-like candidates exist or no pair qualifies.
    """
    name_like = [
        name
        for name, meta in candidates
        if _descriptive_column_score(name, meta)[0] >= 2
    ]
    if len(name_like) < 2:
        return None

    single_best = max((meta.distinct_ratio or 0.0) for _, meta in candidates)
    ratios = tbl_meta.composite_descriptive_ratios
    for pos, first in enumerate(name_like):
        for second in name_like[pos + 1:]:
            combined = ratios.get((first, second)) or ratios.get((second, first))
            if combined is not None and combined > single_best:
                return (first, second)
    return None
412
+
413
+
414
def best_descriptive_column(table: str, schema_graph: SchemaGraph, exclude: set[str]) -> str | None:
    """Return the single best descriptive column for the table, if any.

    Thin wrapper around ``best_descriptive_columns`` with ``max_count=1``;
    the same eligibility rules apply.

    Returns:
        The bare column name, or ``None`` when no column qualifies.
    """
    top = best_descriptive_columns(table, schema_graph, exclude, max_count=1)
    if not top:
        return None
    return top[0]
422
+
423
+
424
def _repair_fk_filters(
    filters: list[FilterParam],
    select_cols: list,
    tables: list[str],
    schema_graph: SchemaGraph,
    label: str = "",
) -> tuple[list[FilterParam], list[str], bool]:
    """Detect foreign-key filters that should use descriptive columns.

    The filters are never rewritten. The function only reports whether some
    filter compares a string/enum value against an integer or number
    foreign-key column whose target table offers a descriptive column.

    Args:
        filters: Filter parameters to inspect.
        select_cols: Select columns; their primary terms are excluded from
            the descriptive-column lookup.
        tables: Table names, returned as a copy.
        schema_graph: Schema metadata used for column and FK lookups.
        label: Suffix inserted into the debug message (e.g. the CTE name).

    Returns:
        ``(filters_copy, tables_copy, detected)``.
    """
    tables_copy = list(tables)
    selected_terms = {sc.expr.primary_term for sc in select_cols or []}
    detected = False

    for param in filters:
        if param.value_type not in {"string", "enum"} or param.raw_value is None:
            continue
        column = param.left_expr.primary_column
        if "." not in column:
            continue
        owner, column_name = column.split(".", 1)
        meta = schema_graph.get_column(owner, column_name)
        if not (meta and meta.is_foreign_key and meta.value_type in {"integer", "number"}):
            continue
        fk_target = meta.fk_target
        if not fk_target:
            continue
        target_table, _ = fk_target
        if best_descriptive_column(target_table, schema_graph, selected_terms):
            detected = True
            debug(f"[intent_resolve.repair_fk_filter_type_mismatch{label}] detected fk filter {column} needing descriptive column")

    return list(filters), tables_copy, detected
466
+
467
+
468
def repair_fk_filter_type_mismatch(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Run the foreign-key filter mismatch check on the main query and CTEs.

    Delegates to ``_repair_fk_filters`` for the main query and for each CTE
    step. That helper returns the filters unmodified, so the returned intent
    carries the same filter content as the input.

    Args:
        intent: RuntimeIntent to inspect.
        schema_graph: Schema metadata used for column and FK lookups.

    Returns:
        The original intent when nothing was detected, otherwise a copy
        rebuilt from the helper's output.
    """
    main_filters, _, main_flagged = _repair_fk_filters(
        intent.filters_param or [],
        intent.select_cols or [],
        list(intent.tables or []),
        schema_graph,
    )

    steps = []
    any_cte_flagged = False
    for step in intent.cte_steps or []:
        step_filters, _, step_flagged = _repair_fk_filters(
            step.filters_param or [],
            step.select_cols or [],
            list(step.tables or []),
            schema_graph,
            label=f" CTE '{step.cte_name}'",
        )
        if not step_flagged:
            steps.append(step)
            continue
        steps.append(replace(step, filters_param=step_filters))
        any_cte_flagged = True

    if not (main_flagged or any_cte_flagged):
        return intent

    updated = intent
    if main_flagged:
        updated = replace(updated, filters_param=main_filters)
    if any_cte_flagged:
        updated = replace(updated, cte_steps=steps)
    return updated
498
+
499
+
500
def expand_fk_select_to_descriptive(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Swap foreign-key columns in ``select_cols`` for descriptive columns of the FK target.

    A non-aggregated SelectCol whose primary column is an integer/number
    foreign key is replaced by up to two descriptive columns of the
    referenced table (see ``best_descriptive_columns``). The referenced table
    is added to ``intent.tables`` so join enumeration can find the path.
    Aggregated columns, unqualified columns and foreign keys whose target has
    no descriptive column are kept unchanged.

    Args:
        intent: RuntimeIntent whose select_cols may reference foreign-key
            integer columns.
        schema_graph: SchemaGraph for foreign-key relationship and
            descriptive column lookups.

    Returns:
        Updated RuntimeIntent (with ``tables`` sorted) when at least one
        column was expanded, otherwise the original intent.
    """
    joined_tables = list(intent.tables or [])
    seen_terms = {sc.expr.primary_term for sc in intent.select_cols or []}
    rewritten: list[SelectCol] = []
    expanded_any = False

    def _fk_target_table(column: str) -> str | None:
        """Return the FK target table for an integer/number FK column, else None."""
        if "." not in column:
            return None
        owner, column_name = column.split(".", 1)
        meta = schema_graph.get_column(owner, column_name)
        if not (meta and meta.is_foreign_key and meta.value_type in {"integer", "number"}):
            return None
        fk_target = meta.fk_target
        if not fk_target:
            return None
        table_name, _ = fk_target
        return table_name

    for select in intent.select_cols or []:
        if select.is_aggregated:
            rewritten.append(select)
            continue

        source_col = select.expr.primary_column
        target_table = _fk_target_table(source_col)
        descs = (
            []
            if target_table is None
            else best_descriptive_columns(target_table, schema_graph, seen_terms, max_count=2)
        )
        if not descs:
            rewritten.append(select)
            continue

        qualified = [f"{target_table}.{desc}" for desc in descs]
        for fq in qualified:
            rewritten.append(SelectCol(expr=NormalizedExpr.from_column(fq)))
            seen_terms.add(fq)
        if target_table not in joined_tables:
            joined_tables.append(target_table)
        expanded_any = True
        debug(
            f"[intent_resolve.expand_fk_select_to_descriptive] "
            f"rewired select {source_col} -> {qualified}"
        )

    if not expanded_any:
        return intent
    return replace(intent, select_cols=rewritten, tables=sorted(joined_tables))
557
+
558
+
559
def strip_spurious_group_by(intent: RuntimeIntent) -> RuntimeIntent:
    """Clear ``group_by_cols`` wherever nothing is actually aggregated.

    A GROUP BY is considered spurious when neither ``select_cols`` contains
    an aggregated column nor ``having_param`` contains an aggregated left
    expression. In that case the group-by list is emptied and a ``"grouped"``
    grain is downgraded to ``"row_level"``. The main query and each CTE step
    are checked independently.

    Args:
        intent: RuntimeIntent to inspect.

    Returns:
        Updated RuntimeIntent with spurious group-by lists cleared, or the
        original intent unchanged.
    """

    def _has_aggregate(node) -> bool:
        if any(sc.is_aggregated for sc in (node.select_cols or [])):
            return True
        return any(hp.left_expr.has_aggregation for hp in (node.having_param or []))

    grain = intent.grain
    group_by = intent.group_by_cols or []
    main_spurious = bool(intent.group_by_cols) and not _has_aggregate(intent)
    if main_spurious:
        debug(
            f"[intent_resolve.strip_spurious_group_by] group_by_cols present without aggregation — stripping {[g.primary_term for g in intent.group_by_cols]}"
        )
        if grain == "grouped":
            grain = "row_level"
        group_by = []

    steps = []
    any_cte_spurious = False
    for step in intent.cte_steps or []:
        if not (step.group_by_cols or []) or _has_aggregate(step):
            steps.append(step)
            continue
        debug(
            f"[intent_resolve.strip_spurious_group_by] CTE '{step.cte_name}' group_by_cols present without aggregation — stripping {[g.primary_term for g in step.group_by_cols]}"
        )
        step_grain = "row_level" if step.grain == "grouped" else step.grain
        steps.append(replace(step, group_by_cols=[], grain=step_grain))
        any_cte_spurious = True

    if not (main_spurious or any_cte_spurious):
        return intent
    return replace(
        intent,
        group_by_cols=group_by,
        grain=grain,
        cte_steps=steps if any_cte_spurious else (intent.cte_steps or []),
    )
612
+
613
+
614
def _is_impossible_having(hp: HavingParam) -> bool:
    """Return True when a HAVING condition on COUNT can never be satisfied.

    COUNT never yields a negative number, so the following comparisons are
    unsatisfiable: ``COUNT(...) < v`` with ``v <= 0``, ``COUNT(...) <= v``
    with ``v < 0``, and ``COUNT(...) = v`` with ``v < 0``. Note that
    ``COUNT(...) <= 0`` is *not* impossible (it holds when the count is 0),
    so it is kept. Only COUNT is considered, since SUM can legitimately be
    negative.

    COUNT is recognised either through ``IMPOSSIBLE_HAVING_RE`` matching the
    primary term or through ``agg_func`` on the first add-group.

    Args:
        hp: A single HavingParam to inspect.

    Returns:
        True if the condition can never be true; False otherwise, including
        when the value is missing or not numeric.
    """
    left_expr = hp.left_expr
    if not left_expr:
        return False

    agg_func = ""
    if left_expr.add_groups and left_expr.add_groups[0].agg_func:
        agg_func = left_expr.add_groups[0].agg_func.upper()
    # ``or ""`` keeps the regex from choking on a missing primary term.
    is_count = bool(IMPOSSIBLE_HAVING_RE.match(left_expr.primary_term or "")) or agg_func == "COUNT"
    if not is_count:
        return False

    val = hp.raw_value
    if val is None:
        return False
    try:
        numeric_val = val if isinstance(val, (int, float)) else float(val)
    except (ValueError, TypeError):
        return False

    op = (hp.op or "").strip().lower()
    if op == "<":
        return numeric_val <= 0
    if op in ("<=", "="):
        return numeric_val < 0
    return False
650
+
651
+
652
def strip_impossible_having(intent: RuntimeIntent) -> RuntimeIntent:
    """Drop HAVING conditions that can never be satisfied.

    Every HavingParam for which ``_is_impossible_having`` returns True (for
    example ``COUNT(...) < 0``) is removed from the main query and,
    independently, from each CTE step. Only ``having_param`` lists are
    modified; the grain is left untouched.

    Args:
        intent: RuntimeIntent to inspect.

    Returns:
        Updated RuntimeIntent with impossible HAVING params removed, or the
        original intent unchanged.
    """

    def _satisfiable(params: list[HavingParam]) -> list[HavingParam]:
        return [hp for hp in params if not _is_impossible_having(hp)]

    main_before = intent.having_param or []
    main_after = _satisfiable(main_before)
    dropped = len(main_before) - len(main_after)
    if dropped:
        debug(f"[strip_impossible_having] removed {dropped} impossible HAVING condition(s)")

    steps = []
    any_cte_dropped = False
    for step in intent.cte_steps or []:
        step_before = step.having_param or []
        step_after = _satisfiable(step_before)
        if len(step_after) == len(step_before):
            steps.append(step)
            continue
        steps.append(replace(step, having_param=step_after))
        any_cte_dropped = True

    if not dropped and not any_cte_dropped:
        return intent
    return replace(
        intent,
        having_param=main_after,
        cte_steps=steps if any_cte_dropped else (intent.cte_steps or []),
    )
692
+
693
+
694
def sanitize_table_names(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Repair table names that carry stray SQL keywords.

    LLMs sometimes emit table references such as ``"FROM orders"`` or
    ``"JOIN products"``. A name that is not a known table (compared
    case-insensitively) is replaced by the schema's spelling of its last
    whitespace-separated word when that word is a known table and at least
    one preceding word is in ``SQL_KEYWORDS``. All other names are kept
    verbatim.

    Args:
        intent: RuntimeIntent whose ``tables`` list may contain polluted
            names.
        schema_graph: SchemaGraph providing the set of valid table names.

    Returns:
        Updated RuntimeIntent with sanitized table names, or the original
        intent when no changes are needed.
    """
    canonical = {name.lower(): name for name in schema_graph.tables}
    cleaned: list[str] = []
    corrected_any = False

    for raw in intent.tables or []:
        fixed = raw
        if raw.lower() not in canonical:
            tokens = raw.split()
            tail = tokens[-1].lower() if tokens else ""
            if tail in canonical and any(tok.lower() in SQL_KEYWORDS for tok in tokens[:-1]):
                fixed = canonical[tail]
                debug(f"[sanitize_table_names] corrected '{raw}' → '{fixed}'")
                corrected_any = True
        cleaned.append(fixed)

    return replace(intent, tables=cleaned) if corrected_any else intent
729
+
730
+
731
def _strip_join_condition_filters(filters: list[FilterParam], schema_graph: SchemaGraph) -> list[FilterParam]:
    """Drop filters that merely restate a foreign-key equi-join.

    A filter is dropped when its op is ``=``, it has a right-hand
    expression, and the pair of primary terms (left, right) equals a
    single-column foreign-key edge of the schema in either direction.
    Multi-column foreign keys are ignored.

    Args:
        filters: List of FilterParam objects to process.
        schema_graph: SchemaGraph providing FK edge definitions.

    Returns:
        New list with FK join conditions removed.
    """
    join_edges: set[tuple[str, str]] = set()
    for table_meta in schema_graph.tables.values():
        for fk in table_meta.foreign_keys:
            if len(fk.src_cols) != 1 or len(fk.dst_cols) != 1:
                continue
            src = f"{fk.src_table}.{fk.src_cols[0]}"
            dst = f"{fk.dst_table}.{fk.dst_cols[0]}"
            join_edges.update(((src, dst), (dst, src)))

    kept: list[FilterParam] = []
    for param in filters:
        if param.right_expr is not None and param.op == "=":
            left_term = param.left_expr.primary_term
            right_term = param.right_expr.primary_term
            if (left_term, right_term) in join_edges:
                debug(f"[intent_resolve.strip_join_condition_filters] dropping FK join filter: {left_term} = {right_term}")
                continue
        kept.append(param)
    return kept
763
+
764
+
765
def strip_join_conditions(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Strip FK equi-join filters from the main query and every CTE step.

    Args:
        intent: RuntimeIntent whose filter lists should be stripped.
        schema_graph: SchemaGraph providing FK edge definitions.

    Returns:
        New RuntimeIntent with FK join conditions removed. A copy is always
        produced, even when nothing was dropped.
    """

    def _without_joins(params: list[FilterParam] | None) -> list[FilterParam]:
        return _strip_join_condition_filters(params or [], schema_graph)

    main_filters = _without_joins(intent.filters_param)
    stripped_steps = []
    for step in intent.cte_steps or []:
        stripped_steps.append(replace(step, filters_param=_without_joins(step.filters_param)))
    return replace(intent, filters_param=main_filters, cte_steps=stripped_steps)
784
+
785
+
786
+ def _is_pk_column(col_ref: str, schema_graph: SchemaGraph) -> bool:
787
+ """Check whether a fully-qualified column reference points to a
788
+ primary key.
789
+
790
+ Args: col_ref: Column reference string in 'table.column' format.
791
+ schema_graph: SchemaGraph for metadata lookups.
792
+
793
+ Returns: True when the referenced column is marked as a primary
794
+ key.
795
+ """
796
+ if "." not in col_ref:
797
+ return False
798
+ tbl, col = col_ref.split(".", 1)
799
+ tbl_meta = schema_graph.tables.get(tbl)
800
+ if not tbl_meta:
801
+ return False
802
+ col_meta = tbl_meta.columns.get(col)
803
+ return col_meta.is_primary_key if col_meta else False
804
+
805
+
806
+ def _strip_distinct_prefix(term: str) -> str:
807
+ """Remove a DISTINCT prefix that leaked into an expression term.
808
+
809
+ Args: term: Expression term string that may start with 'DISTINCT
810
+ '.
811
+
812
+ Returns: Term string with the prefix removed, or the original
813
+ string if absent.
814
+ """
815
+ if term.upper().startswith("DISTINCT "):
816
+ return term[9:].strip()
817
+ return term
818
+
819
+
820
def _normalize_sc_pk_distinct(sc: SelectCol, schema_graph: SchemaGraph) -> SelectCol:
    """Strip a redundant DISTINCT from a COUNT over a primary-key column.

    The aggregate is read from the expression's ``agg_func`` or, failing
    that, from its first add-group, and is compared case-insensitively.
    When the primary term carries a ``DISTINCT `` prefix and the remaining
    term is a primary key, the prefix is removed from the matching
    ``multiply`` entries of the first add-group.

    Args:
        sc: Select column to normalize.
        schema_graph: SchemaGraph for PK lookups.

    Returns:
        A rewritten copy of *sc*, or *sc* itself when nothing applies.
    """
    e = sc.expr
    agg = e.agg_func or (e.add_groups[0].agg_func if e.add_groups and e.add_groups[0].agg_func else None)
    # Compare case-insensitively so "COUNT" and "count" are treated alike.
    if (agg or "").lower() != "count":
        return sc

    term = e.primary_term
    clean_term = _strip_distinct_prefix(term)
    if clean_term == term or not _is_pk_column(clean_term, schema_graph):
        return sc

    new_groups = list(e.add_groups or [])
    if new_groups:
        first = new_groups[0]
        new_mul = [clean_term if _strip_distinct_prefix(m) == clean_term else m for m in first.multiply]
        new_groups[0] = replace(first, multiply=new_mul)
    debug(f"[normalize_pk_distinct] stripped DISTINCT prefix from PK term: {term} → {clean_term}")
    return replace(sc, expr=replace(e, add_groups=new_groups))
842
+
843
+
844
def normalize_pk_distinct(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Remove redundant DISTINCT from COUNT expressions on primary keys.

    ``COUNT(DISTINCT pk)`` and ``COUNT(pk)`` give the same result for a
    primary key, so the prefix is dropped from the select columns of the
    main query and of every CTE step.

    Args:
        intent: Intent to normalize.
        schema_graph: Schema metadata used for primary-key lookups.

    Returns:
        A new ``RuntimeIntent`` carrying the cleaned select columns and
        CTE steps (both always materialised as lists).
    """

    def _clean(cols):
        # ``None`` is treated as "no columns".
        return [_normalize_sc_pk_distinct(col, schema_graph) for col in (cols or [])]

    main_select = _clean(intent.select_cols)
    steps = [
        replace(step, select_cols=_clean(step.select_cols))
        for step in (intent.cte_steps or [])
    ]
    return replace(intent, select_cols=main_select, cte_steps=steps)
863
+
864
+
865
+ def _tables_from_columns(cols: list[str]) -> set[str]:
866
+ """Extract unique table names from a list of fully-qualified column
867
+ references.
868
+
869
+ Args: cols: List of column reference strings in 'table.column'
870
+ format.
871
+
872
+ Returns: Set of table name strings found before the '.'
873
+ separator.
874
+ """
875
+ tables: set[str] = set()
876
+ for col in cols:
877
+ if "." in col:
878
+ tables.add(col.split(".")[0])
879
+ return tables
880
+
881
+
882
def _collect_referenced_tables(
    select_cols: list,
    order_by_cols: list,
    group_by_cols: list,
    filters_param: list,
    having_param: list,
) -> set[str]:
    """Gather every table named by a column reference in the given clauses.

    Args:
        select_cols: Select columns; each item's ``expr`` is scanned.
        order_by_cols: Order-by entries; each item's ``expr`` is scanned.
        group_by_cols: Group-by expressions, scanned directly.
        filters_param: Filters; ``left_expr`` and, when set, ``right_expr``.
        having_param: Having conditions; ``left_expr`` and, when set,
            ``right_expr``.

    Returns:
        Table names taken from the qualified column references found.
        Any argument may be ``None`` or empty.
    """
    exprs = [item.expr for item in select_cols or []]
    exprs += [item.expr for item in order_by_cols or []]
    exprs += list(group_by_cols or [])
    # Filters and HAVING conditions share the left/right expression shape.
    for cond in [*(filters_param or []), *(having_param or [])]:
        exprs.append(cond.left_expr)
        if cond.right_expr:
            exprs.append(cond.right_expr)

    columns: list[str] = [ref for expr in exprs for ref in extract_columns_from_expr(expr)]
    return _tables_from_columns(columns)
915
+
916
+
917
def _collect_essential_tables(
    order_by_cols: list,
    group_by_cols: list,
    filters_param: list,
    having_param: list,
) -> set[str]:
    """Gather the tables referenced outside the SELECT list.

    Tables used for ordering, grouping, filtering or HAVING carry query
    semantics, so callers treat them as tables that must not be pruned.

    Args:
        order_by_cols: Order-by entries; each item's ``expr`` is scanned.
        group_by_cols: Group-by expressions, scanned directly.
        filters_param: Filters; ``left_expr`` and, when set, ``right_expr``.
        having_param: Having conditions; ``left_expr`` and, when set,
            ``right_expr``.

    Returns:
        Table names taken from the qualified column references found.
    """

    def _expressions():
        for ordering in order_by_cols or []:
            yield ordering.expr
        yield from group_by_cols or []
        for clause in (filters_param, having_param):
            for cond in clause or []:
                yield cond.left_expr
                if cond.right_expr:
                    yield cond.right_expr

    refs: list[str] = []
    for expr in _expressions():
        refs.extend(extract_columns_from_expr(expr))
    return _tables_from_columns(refs)
942
+
943
+
944
+ def _find_fk_column_for_pk(
945
+ parent_table: str,
946
+ pk_column: str,
947
+ candidate_tables: set[str],
948
+ schema_graph: SchemaGraph,
949
+ ) -> str | None:
950
+ """Find an FK column on a *candidate_table* that references
951
+ *parent_table*.*pk_column*.
952
+
953
+ Returns the fully qualified FK column or ``None`` when no candidate
954
+ holds a matching foreign key.
955
+ """
956
+ target_key = (parent_table.lower(), pk_column.lower())
957
+ for tbl in candidate_tables:
958
+ tbl_meta = schema_graph.tables.get(tbl)
959
+ if not tbl_meta:
960
+ continue
961
+ for col_name, col_meta in tbl_meta.columns.items():
962
+ if not col_meta.is_foreign_key or not col_meta.fk_target:
963
+ continue
964
+ fk_tgt = (col_meta.fk_target[0].lower(), col_meta.fk_target[1].lower())
965
+ if fk_tgt == target_key:
966
+ return f"{tbl}.{col_name}"
967
+ return None
968
+
969
+
970
def _rewrite_redundant_pk_aggregations(
    select_cols: list[SelectCol],
    select_only_tables: set[str],
    essential_tables: set[str],
    schema_graph: SchemaGraph,
    all_intent_tables: set[str] | None = None,
) -> tuple[list[SelectCol], set[str]]:
    """Rewrite aggregations on a parent PK to a child FK column.

    For each select-only table, if every select column that references
    the table is an aggregation whose primary column is the table's
    primary key, and another table holds an FK pointing at that key, the
    aggregation is rebuilt on the FK column. When some select column not
    tied to the table is already aggregated, the table's columns are
    dropped instead of rewritten.

    Args:
        select_cols: Current select columns; the list is copied, not
            mutated.
        select_only_tables: Tables referenced only from the SELECT list.
        essential_tables: Tables referenced from other clauses.
        schema_graph: Schema metadata for PK / FK lookups.
        all_intent_tables: Optional extra candidates for the FK lookup,
            so tables not referenced by any expression can still supply
            the FK column.

    Returns:
        ``(new_select_cols, eliminated)`` where *eliminated* holds the
        tables whose select references were rewritten or dropped.
    """
    eliminated: set[str] = set()
    new_select: list[SelectCol] = list(select_cols)

    for tbl in list(select_only_tables):
        tbl_meta = schema_graph.tables.get(tbl)
        if not tbl_meta:
            continue

        # Only the first column flagged as primary key is considered.
        # NOTE(review): composite primary keys are therefore matched on
        # one column only; confirm that is acceptable.
        pk_col: str | None = None
        for col_name, col_meta in tbl_meta.columns.items():
            if col_meta.is_primary_key:
                pk_col = col_name
                break
        if not pk_col:
            continue

        # Find the select columns touching this table and check that each
        # one is an aggregation over exactly ``tbl.pk_col``.
        prefix = f"{tbl}.".lower()
        tbl_indices: list[int] = []
        all_agg_pk = True
        for idx, sc in enumerate(new_select):
            cols = extract_columns_from_expr(sc.expr)
            refs_tbl = any(c.lower().startswith(prefix) for c in cols)
            if not refs_tbl:
                continue
            tbl_indices.append(idx)
            col_ref = sc.expr.primary_column.lower()
            if not sc.is_aggregated or col_ref != f"{tbl}.{pk_col}".lower():
                all_agg_pk = False
                break

        if not tbl_indices or not all_agg_pk:
            continue

        # Any other known table may supply the FK that replaces the PK.
        candidate_pool = (all_intent_tables or set()) | essential_tables | select_only_tables
        other_tables = candidate_pool - {tbl}
        fk_col = _find_fk_column_for_pk(tbl, pk_col, other_tables, schema_graph)
        if not fk_col:
            continue

        other_has_agg = any(sc.is_aggregated for idx, sc in enumerate(new_select) if idx not in tbl_indices)

        rewritten: list[SelectCol] = []
        for idx, sc in enumerate(new_select):
            if idx not in tbl_indices:
                rewritten.append(sc)
                continue
            if other_has_agg:
                # Another aggregation already exists, so this one is dropped.
                debug(f"[prune_unreferenced_tables] dropping redundant agg {sc.expr.primary_term} (other agg exists)")
                continue
            agg_func = sc.expr.agg_func or (sc.expr.add_groups[0].agg_func if sc.expr.add_groups else "")
            # NOTE(review): the replacement SelectCol is built from ``expr``
            # alone, so any other fields set on ``sc`` are not carried
            # over; confirm against the SelectCol definition.
            rewritten.append(SelectCol(expr=NormalizedExpr.from_agg(agg_func, fk_col)))
            debug(f"[prune_unreferenced_tables] rewrote {sc.expr.primary_term} → {agg_func}({fk_col})")

        new_select = rewritten
        eliminated.add(tbl)

    return new_select, eliminated
1050
+
1051
+
1052
def requalify_redundant_pk_references(
    intent: RuntimeIntent,
    schema_graph: SchemaGraph,
) -> RuntimeIntent:
    """Rewrite target-table PK references to source-table FK equivalents.

    When a table is referenced by at most one column across all clauses
    and a *different* intent table holds an FK pointing at that column,
    ``target.pk_col`` is rewritten to ``source.fk_col`` in
    ``group_by_cols`` and in non-aggregated ``select_cols``. This removes
    the need to join the target table.

    Aggregated select columns, filters, ordering and ``having_param`` are
    left untouched.

    Self-referential foreign keys (source and target are the same table,
    e.g. ``employee.manager_id -> employee.id``) are ignored: rewriting
    ``employee.id`` to ``employee.manager_id`` would change the meaning
    of the query without removing any join.

    Args:
        intent: Intent whose references may be requalified.
        schema_graph: Schema metadata; a falsy value disables the pass.

    Returns:
        The original *intent* when nothing qualifies, otherwise a copy
        with rewritten ``select_cols`` and ``group_by_cols``.
    """
    if not schema_graph:
        return intent

    # Collect every column reference to count how often each table is used.
    all_cols: list[str] = []
    for sc in intent.select_cols or []:
        all_cols.extend(extract_columns_from_expr(sc.expr))
    for obc in intent.order_by_cols or []:
        all_cols.extend(extract_columns_from_expr(obc.expr))
    for g in intent.group_by_cols or []:
        all_cols.extend(extract_columns_from_expr(g))
    for cond in [*(intent.filters_param or []), *(intent.having_param or [])]:
        all_cols.extend(extract_columns_from_expr(cond.left_expr))
        if cond.right_expr:
            all_cols.extend(extract_columns_from_expr(cond.right_expr))

    col_counts: dict[str, int] = {}
    for col_ref in all_cols:
        tbl = col_ref.split(".")[0] if "." in col_ref else ""
        if tbl:
            col_counts[tbl] = col_counts.get(tbl, 0) + 1

    intent_tables = set(intent.tables or [])
    fk_lookup: dict[str, tuple[str, str]] = {}
    # Sorted so that, when several source tables reference the same key,
    # the chosen FK does not depend on set iteration order.
    for src_table_name in sorted(intent_tables):
        src_table = schema_graph.tables.get(src_table_name)
        if not src_table:
            continue
        for fk in src_table.foreign_keys:
            dst_table = fk.dst_table
            # Skip self-references: the FK must live on another table.
            if dst_table == src_table_name or dst_table not in intent_tables:
                continue
            for src_col, dst_col in zip(fk.src_cols, fk.dst_cols, strict=False):
                pk_ref = f"{dst_table}.{dst_col}"
                fk_ref = f"{src_table_name}.{src_col}"
                # First FK found for a given key wins.
                fk_lookup.setdefault(pk_ref, (fk_ref, dst_table))

    # Only requalify when the target table contributes at most one column
    # reference overall.
    rewrite_map: dict[str, str] = {
        pk_ref: fk_ref
        for pk_ref, (fk_ref, pk_table) in fk_lookup.items()
        if col_counts.get(pk_table, 0) <= 1
    }

    if not rewrite_map:
        return intent

    debug(f"[requalify_redundant_pk_references] rewrite_map: {rewrite_map}")

    def _rewrite_col(col_ref: str) -> str:
        return rewrite_map.get(col_ref, col_ref)

    def _rewrite_expr(expr: NormalizedExpr) -> NormalizedExpr:
        return replace_refs_in_expr(expr, _rewrite_col)

    new_group_by = [_rewrite_expr(g) for g in (intent.group_by_cols or [])]

    # Aggregated select columns keep their PK reference.
    new_select_cols = [
        sc if sc.is_aggregated else replace(sc, expr=_rewrite_expr(sc.expr))
        for sc in intent.select_cols or []
    ]

    return replace(
        intent,
        select_cols=new_select_cols,
        group_by_cols=new_group_by,
    )
1142
+
1143
+
1144
def prune_unreferenced_tables(
    intent: RuntimeIntent,
    schema_graph: SchemaGraph | None = None,
) -> RuntimeIntent:
    """Make the tables list match the tables actually used in expressions.

    Every table referenced from select, filter, group-by, having or
    order-by is kept (and added if missing); tables listed on the intent
    but referenced nowhere are removed. All CTE names are always included.

    When a schema graph is supplied, redundant PK aggregations on
    select-only tables are first rewritten to FK equivalents, which can
    remove a table reference and therefore the table.

    Each CTE step gets the same synchronisation of its own ``tables``.

    Args:
        intent: Intent to synchronise.
        schema_graph: Optional schema metadata enabling the PK-to-FK
            aggregation rewrite.

    Returns:
        A new ``RuntimeIntent`` with sorted ``tables``, possibly rewritten
        ``select_cols`` and updated ``cte_steps``.
    """
    cte_names = {step.cte_name for step in (intent.cte_steps or [])}

    def _referenced_with(cols: list) -> set[str]:
        # Tables used by the main query for the given select list.
        found = _collect_referenced_tables(
            cols,
            intent.order_by_cols,
            intent.group_by_cols,
            intent.filters_param,
            intent.having_param,
        )
        return found | cte_names

    select_cols = list(intent.select_cols or [])
    referenced = _referenced_with(select_cols)
    essential = (
        _collect_essential_tables(
            intent.order_by_cols,
            intent.group_by_cols,
            intent.filters_param,
            intent.having_param,
        )
        | cte_names
    )
    select_only = referenced - essential

    if schema_graph and select_only:
        select_cols, eliminated = _rewrite_redundant_pk_aggregations(
            select_cols,
            select_only,
            essential,
            schema_graph,
            all_intent_tables=set(intent.tables or []),
        )
        if eliminated:
            # The select list changed, so recompute what it references.
            referenced = _referenced_with(select_cols)

    declared = set(intent.tables or [])
    newly_added = (referenced - declared) - cte_names
    dropped = declared - referenced
    main_tables = sorted(referenced)
    if newly_added:
        debug(f"[prune_unreferenced_tables] added {sorted(newly_added)} to tables")
    if dropped:
        debug(f"[prune_unreferenced_tables] removed {sorted(dropped)} from tables")
    if newly_added or dropped:
        debug(f"[prune_unreferenced_tables] final tables: {main_tables}")

    synced_steps = []
    for step in intent.cte_steps or []:
        # NOTE(review): every CTE name, including the step's own, ends up
        # in each step's tables; confirm this is intended.
        step_refs = (
            _collect_referenced_tables(
                step.select_cols,
                step.order_by_cols,
                step.group_by_cols,
                step.filters_param,
                step.having_param,
            )
            | cte_names
        )
        step_declared = set(step.tables or [])
        step_added = (step_refs - step_declared) - cte_names
        step_dropped = step_declared - step_refs
        step_tables = sorted(step_refs)
        if step_added:
            debug(f"[prune_unreferenced_tables] CTE '{step.cte_name}' added {sorted(step_added)} to tables")
        if step_dropped:
            debug(f"[prune_unreferenced_tables] CTE '{step.cte_name}' removed {sorted(step_dropped)} from tables")
        synced_steps.append(replace(step, tables=step_tables))

    return replace(
        intent,
        tables=main_tables,
        select_cols=select_cols,
        cte_steps=synced_steps,
    )
1251
+
1252
+
1253
+ def _correct_value_case(raw_value: str, top_k: list[str]) -> str | None:
1254
+ """Return the case-corrected version of a filter value using
1255
+ profiled sample values.
1256
+
1257
+ Performs a case-insensitive comparison of ``raw_value`` against each
1258
+ entry in ``top_k``. When a match is found whose casing differs from
1259
+ the original, the sample value is returned so the filter uses the
1260
+ casing that actually appears in the database.
1261
+
1262
+ Args: raw_value: The filter value string extracted from the user
1263
+ question. top_k: Profiled sample values
1264
+ (``ColumnMetadata.top_k_values``) for the column being
1265
+ filtered.
1266
+
1267
+ Returns: The matching sample string with correct casing, or
1268
+ ``None`` when no case-insensitive match is found or the casing
1269
+ already matches.
1270
+ """
1271
+ if not raw_value or not top_k:
1272
+ return None
1273
+ lower_val = raw_value.lower()
1274
+ for sample in top_k:
1275
+ if sample.lower() == lower_val and sample != raw_value:
1276
+ return sample
1277
+ return None
1278
+
1279
+
1280
+ def _match_enum_value(raw_value: str, col_meta: ColumnMetadata, schema_graph: SchemaGraph) -> str | None:
1281
+ """Case-insensitive match of *raw_value* against the column's enum
1282
+ type values.
1283
+
1284
+ Looks up ``col_meta.data_type`` in ``schema_graph.enum_values``.
1285
+ When the column belongs to a defined enum type, returns the enum
1286
+ member whose casing matches the database definition.
1287
+
1288
+ Args: raw_value: Filter value string extracted from the user
1289
+ question. col_meta: Column metadata for the filter target
1290
+ column. schema_graph: Schema graph holding ``enum_values``.
1291
+
1292
+ Returns: The correctly-cased enum member, or ``None`` when no
1293
+ match is found or the column is not an enum type.
1294
+ """
1295
+ if not schema_graph.enum_values:
1296
+ return None
1297
+ dtype_lower = (col_meta.data_type or "").lower()
1298
+ enum_vals = schema_graph.enum_values.get(dtype_lower)
1299
+ if not enum_vals:
1300
+ return None
1301
+ raw_lower = raw_value.lower()
1302
+ for ev in enum_vals:
1303
+ if ev.lower() == raw_lower:
1304
+ return ev
1305
+ return None
1306
+
1307
+
1308
+ def _extract_question_casing(raw_value: str, question: str) -> str | None:
1309
+ """Extract the user's original casing for a filter value from the
1310
+ question text.
1311
+
1312
+ Searches *question* for *raw_value* (case-insensitive). When a
1313
+ match is found and the matched substring contains at least one
1314
+ uppercase letter, returns it so the filter preserves the user's
1315
+ intended casing. When the matched substring is entirely lowercase
1316
+ the result is ``None`` so that downstream tiers (e.g. ILIKE
1317
+ fallback) can still apply.
1318
+
1319
+ Args: raw_value: Filter value string to locate. question:
1320
+ Original natural-language question.
1321
+
1322
+ Returns: The matched substring from *question* with its original
1323
+ casing, or ``None`` when no match is found, the match is all-
1324
+ lowercase, or the casing already equals *raw_value*.
1325
+ """
1326
+ if not raw_value or not question:
1327
+ return None
1328
+ q_lower = question.lower()
1329
+ val_lower = raw_value.lower()
1330
+ idx = q_lower.find(val_lower)
1331
+ if idx < 0:
1332
+ return None
1333
+ matched = question[idx : idx + len(val_lower)]
1334
+ if matched == matched.lower():
1335
+ return None
1336
+ if matched == raw_value:
1337
+ return None
1338
+ return matched
1339
+
1340
+
1341
+ def _resolve_filter_list_cascade(
1342
+ filters: list[FilterParam],
1343
+ schema_graph: SchemaGraph,
1344
+ question: str,
1345
+ ) -> tuple[list[FilterParam], bool]:
1346
+ """Resolve string filter values to database-safe casing.
1347
+
1348
+ For each eligible string/enum filter the function first checks for
1349
+ an enum-type match (tier 1) which preserves the exact database
1350
+ casing. When no enum match exists, the raw_value is lowercased so
1351
+ the SQL prompt can pair it with a ``LOWER(column)`` wrapper for
1352
+ case-insensitive comparison.
1353
+
1354
+ Args: filters: List of ``FilterParam`` objects to inspect and
1355
+ correct. schema_graph: Schema graph with enum and profiled
1356
+ column data. question: Original natural-language question.
1357
+
1358
+ Returns: Tuple of ``(resolved_filters, changed)`` where
1359
+ *changed* is ``True`` when at least one filter was modified.
1360
+ """
1361
+ new_filters: list[FilterParam] = []
1362
+ changed = False
1363
+ for fp in filters:
1364
+ if fp.raw_value is None or fp.value_type not in {"string", "enum"}:
1365
+ new_filters.append(fp)
1366
+ continue
1367
+ col = fp.left_expr.primary_column
1368
+ parts = col.split(".", 1) if "." in col else None
1369
+ if not parts:
1370
+ new_filters.append(fp)
1371
+ continue
1372
+ col_meta = schema_graph.get_column(parts[0], parts[1])
1373
+ if not col_meta:
1374
+ new_filters.append(fp)
1375
+ continue
1376
+
1377
+ if isinstance(fp.raw_value, list):
1378
+ new_vals: list = []
1379
+ list_changed = False
1380
+ for v in fp.raw_value:
1381
+ if not isinstance(v, str):
1382
+ new_vals.append(v)
1383
+ continue
1384
+ enum_match = _match_enum_value(v, col_meta, schema_graph)
1385
+ if enum_match is not None:
1386
+ if enum_match != v:
1387
+ list_changed = True
1388
+ new_vals.append(enum_match)
1389
+ else:
1390
+ lowered = v.lower()
1391
+ if lowered != v:
1392
+ list_changed = True
1393
+ new_vals.append(lowered)
1394
+ if list_changed:
1395
+ new_filters.append(replace(fp, raw_value=new_vals))
1396
+ changed = True
1397
+ debug(f"[intent_repair.resolve_filter_list_cascade] resolved list values on {col}")
1398
+ else:
1399
+ new_filters.append(fp)
1400
+ continue
1401
+
1402
+ if not isinstance(fp.raw_value, str):
1403
+ new_filters.append(fp)
1404
+ continue
1405
+
1406
+ enum_match = _match_enum_value(fp.raw_value, col_meta, schema_graph)
1407
+ if enum_match is not None:
1408
+ if enum_match != fp.raw_value:
1409
+ new_filters.append(replace(fp, raw_value=enum_match))
1410
+ changed = True
1411
+ debug(f"[intent_repair.resolve_filter_list_cascade] enum {col}: '{fp.raw_value}' -> '{enum_match}'")
1412
+ else:
1413
+ new_filters.append(fp)
1414
+ continue
1415
+
1416
+ lowered = fp.raw_value.lower()
1417
+ if lowered != fp.raw_value:
1418
+ new_filters.append(replace(fp, raw_value=lowered))
1419
+ changed = True
1420
+ debug(f"[intent_repair.resolve_filter_list_cascade] lower {col}: '{fp.raw_value}' -> '{lowered}'")
1421
+ else:
1422
+ new_filters.append(fp)
1423
+ return new_filters, changed
1424
+
1425
+
1426
def resolve_filter_value_case(intent: RuntimeIntent, schema_graph: SchemaGraph, question: str) -> RuntimeIntent:
    """Resolve string filter value casing on the main query and CTE steps.

    Delegates the per-list work to ``_resolve_filter_list_cascade``: enum
    matches keep the database casing, every other string value is
    lowercased.

    Args:
        intent: Intent whose filter values may need casing fixes.
        schema_graph: Schema graph providing enum and column metadata.
        question: Original natural-language question, forwarded as-is.

    Returns:
        Whatever ``_apply_filters_to_main_and_ctes`` produces for the
        intent and the per-list resolver.
    """
    return _apply_filters_to_main_and_ctes(
        intent,
        lambda filter_list: _resolve_filter_list_cascade(filter_list, schema_graph, question),
    )
1448
+
1449
+
1450
def _coerce_element(val: Any, data_type: str) -> Any:
    """Coerce one IN-list element to match a numeric column type.

    Only string values on columns whose type is in
    ``NUMERIC_DATA_TYPES`` are touched: text containing a ``.`` is parsed
    with ``float``, anything else with ``int``.

    Args:
        val: Single element of an IN-list ``raw_value``.
        data_type: Lowercased column data type.

    Returns:
        The parsed number, or *val* unchanged when the column is not
        numeric, the value is not a string, or parsing fails.
    """
    if data_type not in NUMERIC_DATA_TYPES or not isinstance(val, str):
        return val
    text = val.strip()
    parser = float if "." in text else int
    try:
        return parser(text)
    except (ValueError, OverflowError):
        return val
1476
+
1477
+
1478
+ def _consolidate_in_list(vals: list, data_type: str) -> str:
1479
+ """Convert a list of IN-values into a formatted SQL-ready string.
1480
+
1481
+ String elements are wrapped in single quotes (``'R', 'PG-13'``),
1482
+ while numeric elements are joined as-is (``1, 2, 3``).
1483
+
1484
+ Args: vals: List of IN-list elements (already type-coerced).
1485
+ data_type: Lowercased column data_type for formatting decisions.
1486
+
1487
+ Returns: Comma-separated string suitable for direct SQL
1488
+ substitution.
1489
+ """
1490
+ if all(isinstance(v, (int, float)) for v in vals):
1491
+ return ", ".join(str(v) for v in vals)
1492
+ parts: list[str] = []
1493
+ for v in vals:
1494
+ if isinstance(v, str):
1495
+ parts.append(f"'{v}'")
1496
+ else:
1497
+ parts.append(str(v))
1498
+ return ", ".join(parts)
1499
+
1500
+
1501
+ def _normalize_in_types_for_list(
1502
+ filters: list[FilterParam],
1503
+ schema_graph: SchemaGraph,
1504
+ ) -> tuple[list[FilterParam], bool]:
1505
+ """Coerce IN / NOT IN list elements to match their column types and
1506
+ consolidate to strings.
1507
+
1508
+ For each filter with ``op`` in (``in``, ``not in``) and a list
1509
+ ``raw_value``, each element is coerced to the column's native type.
1510
+ The list is then consolidated into a formatted SQL string so
1511
+ ``substitute_params`` can perform direct substitution.
1512
+
1513
+ Args: filters: Filter params to inspect and coerce.
1514
+ schema_graph: Schema graph for column type lookup.
1515
+
1516
+ Returns: Tuple of ``(coerced_filters, changed)``.
1517
+ """
1518
+ new_filters: list[FilterParam] = []
1519
+ changed = False
1520
+ for fp in filters:
1521
+ if fp.op.lower() not in {"in", "not in"} or not isinstance(fp.raw_value, list):
1522
+ new_filters.append(fp)
1523
+ continue
1524
+ col = fp.left_expr.primary_column
1525
+ parts = col.split(".", 1) if "." in col else None
1526
+ if not parts:
1527
+ new_filters.append(fp)
1528
+ continue
1529
+ col_meta = schema_graph.get_column(parts[0], parts[1])
1530
+ dtype = (col_meta.data_type or "").lower() if col_meta else ""
1531
+ coerced = [_coerce_element(v, dtype) for v in fp.raw_value]
1532
+ consolidated = _consolidate_in_list(coerced, dtype)
1533
+ if consolidated != fp.raw_value:
1534
+ new_filters.append(replace(fp, raw_value=consolidated))
1535
+ changed = True
1536
+ debug(f"[intent_resolve_normalize_in_types_for_list] {col}: {fp.raw_value!r} -> {consolidated!r}")
1537
+ else:
1538
+ new_filters.append(fp)
1539
+ return new_filters, changed
1540
+
1541
+
1542
def normalize_in_filter_types(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Coerce and consolidate IN / NOT IN lists, then decompose what remains.

    The per-list coercion runs over the main query and every CTE step via
    ``_apply_filters_to_main_and_ctes``; the result is then passed to
    ``decompose_in_not_in_filters``.

    NOTE(review): consolidation replaces list values with strings, so the
    decomposition step only sees lists the consolidation skipped (for
    example filters on unqualified columns); confirm this is intended.

    Args:
        intent: Intent whose IN-list filter values may need coercion.
        schema_graph: Schema graph used for the column type lookup.

    Returns:
        The intent returned by ``decompose_in_not_in_filters``.
    """
    coerced_intent = _apply_filters_to_main_and_ctes(
        intent,
        lambda filter_list: _normalize_in_types_for_list(filter_list, schema_graph),
    )
    return decompose_in_not_in_filters(coerced_intent)
1559
+
1560
+
1561
+ def _decompose_in_list(
1562
+ filters: list[FilterParam],
1563
+ max_list_size: int = 10,
1564
+ ) -> list[FilterParam]:
1565
+ """Expand small IN / NOT IN lists into primitive comparisons.
1566
+
1567
+ For ``IN`` with a short raw_value list, creates one ``=`` filter per
1568
+ element combined with ``OR``. For ``NOT IN`` with a short list,
1569
+ creates ``!=`` filters combined with ``AND``. Large lists are left
1570
+ unchanged.
1571
+ """
1572
+ new_filters: list[FilterParam] = []
1573
+ for fp in filters:
1574
+ raw = fp.raw_value
1575
+ op_lower = (fp.op or "").lower()
1576
+ if not isinstance(raw, list) or op_lower not in {"in", "not in"} or len(raw) == 0 or len(raw) > max_list_size:
1577
+ new_filters.append(fp)
1578
+ continue
1579
+ elems = list(raw)
1580
+ bool_op = "OR" if op_lower == "in" else "AND"
1581
+ new_group = []
1582
+ for idx, val in enumerate(elems):
1583
+ new_fp = replace(
1584
+ fp,
1585
+ op="=" if op_lower == "in" else "!=",
1586
+ raw_value=val,
1587
+ bool_op=bool_op if idx > 0 else (fp.bool_op or "AND"),
1588
+ )
1589
+ new_group.append(new_fp)
1590
+ new_filters.extend(new_group)
1591
+ return new_filters
1592
+
1593
+
1594
def decompose_in_not_in_filters(intent: RuntimeIntent) -> RuntimeIntent:
    """Decompose small IN / NOT IN lists on the main query and CTE steps.

    Args:
        intent: Intent whose filters should be expanded.

    Returns:
        A copy of *intent* with decomposed ``filters_param`` on the main
        query and on each CTE step. When there are no CTE steps the
        original ``cte_steps`` value is kept.
    """
    main_filters = _decompose_in_list(intent.filters_param or [])
    rebuilt_steps: list[RuntimeCteStep] = [
        replace(step, filters_param=_decompose_in_list(step.filters_param or []))
        for step in (intent.cte_steps or [])
    ]
    steps = rebuilt_steps if rebuilt_steps else intent.cte_steps
    return replace(intent, filters_param=main_filters, cte_steps=steps)
1602
+
1603
+
1604
+ def _resolve_boolean_value(raw_value: Any, col_meta: ColumnMetadata) -> tuple[Any, str] | None:
1605
+ """Resolve a filter raw_value to a Python ``bool`` for a native
1606
+ boolean column.
1607
+
1608
+ Only applies to columns whose ``data_type`` contains ``"bool"``.
1609
+ Converts common truthy/falsy representations (integers, strings,
1610
+ Python bools) to ``True``/``False`` and sets the value_type to
1611
+ ``"boolean"`` so that ``substitute_params`` emits the SQL literal
1612
+ ``TRUE`` or ``FALSE``.
1613
+
1614
+ Args: raw_value: The current filter value (int, str, bool, or
1615
+ other). col_meta: Column metadata for the filter target column.
1616
+
1617
+ Returns: Tuple of ``(resolved_value, "boolean")`` when
1618
+ conversion succeeds, or ``None`` when the column is not a native
1619
+ boolean or the value cannot be mapped.
1620
+ """
1621
+ dtype_lower = (col_meta.data_type or "").lower()
1622
+ if "bool" not in dtype_lower:
1623
+ return None
1624
+ if isinstance(raw_value, bool):
1625
+ return raw_value, "boolean"
1626
+ val_str = str(raw_value).lower().strip()
1627
+ if val_str in BOOLEAN_TRUTHY_VALUES:
1628
+ return True, "boolean"
1629
+ if val_str in BOOLEAN_FALSY_VALUES:
1630
+ return False, "boolean"
1631
+ return None
1632
+
1633
+
1634
+ def _normalize_boolean_filter_list(
1635
+ filters: list[FilterParam], schema_graph: SchemaGraph
1636
+ ) -> tuple[list[FilterParam], bool]:
1637
+ """Normalise boolean filter values in a list of ``FilterParam``
1638
+ objects.
1639
+
1640
+ For each filter targeting a native boolean column whose
1641
+ ``raw_value`` is not already a Python ``bool``, converts the value
1642
+ and sets ``value_type`` to ``"boolean"``.
1643
+
1644
+ Args: filters: List of ``FilterParam`` objects to inspect and
1645
+ correct. schema_graph: ``SchemaGraph`` providing column
1646
+ metadata.
1647
+
1648
+ Returns: Tuple of ``(normalised_filters, changed)`` where
1649
+ *changed* is ``True`` when at least one filter was rewritten.
1650
+ """
1651
+ new_filters: list[FilterParam] = []
1652
+ changed = False
1653
+ for fp in filters:
1654
+ if fp.raw_value is None:
1655
+ new_filters.append(fp)
1656
+ continue
1657
+ col = fp.left_expr.primary_column
1658
+ parts = col.split(".", 1) if "." in col else None
1659
+ if not parts:
1660
+ new_filters.append(fp)
1661
+ continue
1662
+ col_meta = schema_graph.get_column(parts[0], parts[1])
1663
+ if not col_meta:
1664
+ new_filters.append(fp)
1665
+ continue
1666
+ resolved = _resolve_boolean_value(fp.raw_value, col_meta)
1667
+ if resolved is None:
1668
+ new_filters.append(fp)
1669
+ continue
1670
+ bool_val, vtype = resolved
1671
+ new_filters.append(replace(fp, raw_value=bool_val, value_type=vtype))
1672
+ changed = True
1673
+ debug(
1674
+ f"[intent_resolve_normalize_boolean_filter_list] {col}: "
1675
+ f"{fp.raw_value!r} ({fp.value_type}) → {bool_val!r} ({vtype})"
1676
+ )
1677
+ return new_filters, changed
1678
+
1679
+
1680
def normalize_boolean_filter_values(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Normalise boolean filter values on the main query and CTE steps.

    Runs ``_normalize_boolean_filter_list`` over each filter list via
    ``_apply_filters_to_main_and_ctes`` so that filters on native boolean
    columns carry a Python ``bool`` and ``value_type="boolean"``.

    Args:
        intent: Intent whose filter values may need boolean normalisation.
        schema_graph: Schema graph providing column data types.

    Returns:
        Whatever ``_apply_filters_to_main_and_ctes`` produces for the
        intent and the per-list normaliser.
    """
    return _apply_filters_to_main_and_ctes(
        intent,
        lambda filter_list: _normalize_boolean_filter_list(filter_list, schema_graph),
    )
1703
+
1704
+
1705
+ def _normalize_null_filter_list(
1706
+ filters: list[FilterParam],
1707
+ ) -> tuple[list[FilterParam], bool]:
1708
+ """Normalise IS NULL / IS NOT NULL filters to canonical form.
1709
+
1710
+ Ensures ``value_type`` is ``"null"`` and ``raw_value`` is ``None``
1711
+ for any filter whose operator is ``"is null"`` or ``"is not null"``.
1712
+ """
1713
+ result: list[FilterParam] = []
1714
+ changed = False
1715
+ for fp in filters:
1716
+ if fp.op in ("is null", "is not null"):
1717
+ needs_fix = fp.value_type != "null" or fp.raw_value is not None
1718
+ if needs_fix:
1719
+ result.append(replace(fp, value_type="null", raw_value=None))
1720
+ changed = True
1721
+ continue
1722
+ result.append(fp)
1723
+ return result, changed
1724
+
1725
+
1726
def normalize_null_filter_values(intent: RuntimeIntent) -> RuntimeIntent:
    """Normalise null-operator filters on the main query and CTE steps.

    Every ``IS NULL`` / ``IS NOT NULL`` filter ends up with
    ``value_type="null"`` and ``raw_value=None``, so later validation
    does not report a spurious type mismatch.

    Args:
        intent: Intent whose null filters should be canonicalised.

    Returns:
        Whatever ``_apply_filters_to_main_and_ctes`` produces for the
        intent and ``_normalize_null_filter_list``.
    """
    null_normalizer = _normalize_null_filter_list
    return _apply_filters_to_main_and_ctes(intent, null_normalizer)