aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1292 @@
1
+ """Column resolution, normalization, and grain enforcement for intent post-processing.
2
+
3
+ Resolves bare column names to source tables or CTE steps via resolve_column_map and resolve_cte_column_maps, normalizes filter and having operators to canonical forms, deduplicates conditions, and sorts select, order, filter, and having clauses by structural keys.
4
+
5
+ Enforces grain consistency between scalar, grouped, and row_level settings versus aggregation and GROUP BY, validates tables and columns against the schema, applies algebraic simplification such as constant folding and like-term combining to all expressions, and normalizes CTE names and COUNT(1) to COUNT(*).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import re
11
+ from collections import defaultdict
12
+ from collections.abc import Callable
13
+ from dataclasses import replace
14
+
15
+ from .config import REVERSE_OP_MAP, normalize_value_type
16
+ from .contracts_base import SchemaGraph
17
+ from .contracts_core import (
18
+ ExprValue,
19
+ FilterParam,
20
+ HavingParam,
21
+ MulGroup,
22
+ NormalizedExpr,
23
+ OrderByCol,
24
+ RuntimeCteStep,
25
+ RuntimeIntent,
26
+ SelectCol,
27
+ )
28
+ from .core_utils import debug, normalize_op
29
+ from .intent_expr import extract_columns_from_expr, replace_refs_in_expr
30
+ from .intent_repair import best_descriptive_column
31
+ from .sql_gen import _render_expr_sql
32
+
33
+
34
def normalize_count_star(intent: RuntimeIntent) -> RuntimeIntent:
    """Rewrite every ``COUNT(1)`` term of an intent as ``COUNT(*)``.

    Select, order-by, filter and having expressions are rewritten both on the
    main query and on every CTE step. A term is rewritten only when it equals
    ``COUNT(1)`` case-insensitively; every other field is carried over as-is.

    Args:
        intent: RuntimeIntent to normalize.

    Returns:
        A new RuntimeIntent built with ``dataclasses.replace``; the input is not
        mutated. Missing (``None``) lists come back as empty lists.
    """

    def _swap(term: str) -> str:
        return "COUNT(*)" if term.upper() == "COUNT(1)" else term

    def _rewrite(expr):
        return replace_refs_in_expr(expr, _swap)

    def _rewrite_cols(cols):
        # Works for both SelectCol and OrderByCol, which carry their expression in `.expr`.
        return [replace(col, expr=_rewrite(col.expr)) for col in cols]

    def _rewrite_conditions(conds):
        # right_expr is optional; a falsy right side is normalized to None.
        return [
            replace(
                cond,
                left_expr=_rewrite(cond.left_expr),
                right_expr=_rewrite(cond.right_expr) if cond.right_expr else None,
            )
            for cond in conds
        ]

    main_select = _rewrite_cols(intent.select_cols or [])
    main_order = _rewrite_cols(intent.order_by_cols or [])
    main_filters = _rewrite_conditions(intent.filters_param or [])
    main_having = _rewrite_conditions(intent.having_param or [])
    steps = [
        replace(
            step,
            select_cols=_rewrite_cols(step.select_cols or []),
            order_by_cols=_rewrite_cols(step.order_by_cols or []),
            filters_param=_rewrite_conditions(step.filters_param or []),
            having_param=_rewrite_conditions(step.having_param or []),
        )
        for step in intent.cte_steps or []
    ]
    return replace(
        intent,
        select_cols=main_select,
        order_by_cols=main_order,
        filters_param=main_filters,
        having_param=main_having,
        cte_steps=steps,
    )
90
+
91
+
92
def sort_select_cols(cols: list[SelectCol]) -> list[SelectCol]:
    """Return select columns with plain expressions first and aggregates last.

    Inside each of the two partitions columns are ordered by ``signature_key``.
    The input list is not modified.

    Args:
        cols: List of SelectCol objects to sort.

    Returns:
        A new, sorted list of SelectCol objects.
    """
    return sorted(cols, key=lambda col: (int(bool(col.is_aggregated)), col.signature_key))
108
+
109
+
110
def sort_order_by_cols(cols: list[OrderByCol]) -> list[OrderByCol]:
    """Return order-by columns with plain expressions first and aggregates last.

    Inside each of the two partitions columns are ordered by ``signature_key``.
    The input list is not modified.

    Args:
        cols: List of OrderByCol objects to sort.

    Returns:
        A new, sorted list of OrderByCol objects.
    """
    return sorted(cols, key=lambda col: (int(bool(col.is_aggregated)), col.signature_key))
126
+
127
+
128
+ def _filter_structural_key(fp: FilterParam) -> tuple[str, str, str, str]:
129
+ """Return the structural sort key for a single FilterParam.
130
+
131
+ Args:
132
+
133
+ fp: FilterParam to compute the key for.
134
+
135
+ Returns:
136
+
137
+ Tuple of (left_sig, op, right_sig, value_type) with all components lowercased.
138
+ """
139
+ left = fp.left_expr.signature_key if fp.left_expr else ""
140
+ right = fp.right_expr.signature_key if fp.right_expr else ""
141
+ return (left, fp.op.lower(), right, fp.value_type.lower())
142
+
143
+
144
+ def _having_structural_key(hp: HavingParam) -> tuple[str, str, str, str]:
145
+ """Return the structural sort key for a single HavingParam.
146
+
147
+ Args:
148
+
149
+ hp: HavingParam to compute the key for.
150
+
151
+ Returns:
152
+
153
+ Tuple of (left_sig, op, right_sig, value_type) with all components lowercased.
154
+ """
155
+ left = hp.left_expr.signature_key if hp.left_expr else ""
156
+ right = hp.right_expr.signature_key if hp.right_expr else ""
157
+ return (left, hp.op.lower(), right, hp.value_type.lower())
158
+
159
+
160
+ def _canonicalize_condition_order(
161
+ items: list,
162
+ structural_key_fn: Callable,
163
+ ) -> list:
164
+ """Canonicalize a flat condition list by parsing the positional bool_op operators into a precedence tree (AND binds tighter than OR), sorting at each level using the commutativity of AND/OR, and re-serializing with adjusted bool_ops.
165
+
166
+ The last element's bool_op is treated as the inter-group connector and is preserved on whichever element ends up last after sorting.
167
+
168
+ Works for both FilterParam and HavingParam since they share the same bool_op / filter_group field layout.
169
+
170
+ Args:
171
+
172
+ items: List of FilterParam or HavingParam objects.
173
+
174
+ structural_key_fn: Callable that returns a sortable tuple for one item.
175
+
176
+ Returns:
177
+
178
+ New list with the same elements in canonical order and bool_ops reassigned.
179
+ """
180
+ if len(items) <= 1:
181
+ return list(items)
182
+ inter_connector = items[-1].bool_op
183
+ ops: list[str] = [it.bool_op for it in items[:-1]]
184
+ chunks: list[list] = []
185
+ current_chunk: list = [items[0]]
186
+ for i, op in enumerate(ops):
187
+ if op == "OR":
188
+ chunks.append(current_chunk)
189
+ current_chunk = [items[i + 1]]
190
+ else:
191
+ current_chunk.append(items[i + 1])
192
+ chunks.append(current_chunk)
193
+ sorted_chunks: list[list] = []
194
+ for chunk in chunks:
195
+ sorted_chunks.append(sorted(chunk, key=structural_key_fn))
196
+ sorted_chunks.sort(key=lambda ch: structural_key_fn(ch[0]))
197
+ result: list = []
198
+ for ci, chunk in enumerate(sorted_chunks):
199
+ for fi, item in enumerate(chunk):
200
+ is_last_in_chunk = fi == len(chunk) - 1
201
+ is_last_chunk = ci == len(sorted_chunks) - 1
202
+ if is_last_chunk and is_last_in_chunk:
203
+ new_bool_op = inter_connector
204
+ elif is_last_in_chunk:
205
+ new_bool_op = "OR"
206
+ else:
207
+ new_bool_op = "AND"
208
+ result.append(replace(item, bool_op=new_bool_op))
209
+ return result
210
+
211
+
212
def sort_filters(filters: list[FilterParam]) -> list[FilterParam]:
    """Sort filter parameters using precedence-aware group canonicalization.

    Partitions filters by ``filter_group`` and canonicalizes each group with
    ``_canonicalize_condition_order``. The groups are then ordered by running the
    same algorithm over one representative per group: the group's last element,
    which carries the inter-group connector. Representatives that compare equal
    are tie-broken by the full canonicalized contents of their groups, so the
    result does not depend on the order in which groups appeared in the input.

    Args:
        filters: List of FilterParam objects to sort.

    Returns:
        Sorted list of FilterParam objects with bool_ops adjusted to reflect the
        new canonical positions; an empty input yields an empty list.
    """
    if not filters:
        return []
    buckets: dict[int | None, list[FilterParam]] = defaultdict(list)
    for fp in filters:
        buckets[fp.filter_group].append(fp)
    canonicalized_groups: list[tuple[int | None, list[FilterParam]]] = [
        (gid, _canonicalize_condition_order(group, _filter_structural_key)) for gid, group in buckets.items()
    ]
    if len(canonicalized_groups) == 1:
        return canonicalized_groups[0][1]
    representatives: list[FilterParam] = []
    group_map: dict[int, tuple[int | None, list[FilterParam]]] = {}
    for idx, (gid, group) in enumerate(canonicalized_groups):
        # The proxy filter_group (idx) lets each representative be mapped back to its group after sorting.
        representatives.append(replace(group[-1], filter_group=idx))
        group_map[idx] = (gid, group)

    def _rep_key(rep: FilterParam) -> tuple:
        # Primary: the representative's own key (the original ordering).
        # Secondary: the whole group, so equal representatives sort deterministically.
        members = group_map[rep.filter_group][1]
        return (_filter_structural_key(rep), tuple(_filter_structural_key(m) for m in members))

    sorted_reps = _canonicalize_condition_order(representatives, _rep_key)
    result: list[FilterParam] = []
    for rep in sorted_reps:
        proxy_id = rep.filter_group
        assert isinstance(proxy_id, int)
        real_gid, group = group_map[proxy_id]
        last = len(group) - 1
        for fi, fp in enumerate(group):
            if fi == last:
                # The group's last element carries the connector assigned at the inter-group level.
                result.append(replace(fp, bool_op=rep.bool_op, filter_group=real_gid))
            else:
                result.append(replace(fp, filter_group=real_gid))
    return result
254
+
255
+
256
def sort_having(having: list[HavingParam]) -> list[HavingParam]:
    """Sort having parameters using precedence-aware group canonicalization.

    Partitions having conditions by ``filter_group`` and canonicalizes each group
    with ``_canonicalize_condition_order``. The groups are then ordered by running
    the same algorithm over one representative per group: the group's last
    element, which carries the inter-group connector. Representatives that
    compare equal are tie-broken by the full canonicalized contents of their
    groups, so the result does not depend on the input order of the groups.

    Args:
        having: List of HavingParam objects to sort.

    Returns:
        Sorted list of HavingParam objects with bool_ops adjusted to reflect the
        new canonical positions; an empty input yields an empty list.
    """
    if not having:
        return []
    buckets: dict[int | None, list[HavingParam]] = defaultdict(list)
    for hp in having:
        buckets[hp.filter_group].append(hp)
    canonicalized_groups: list[tuple[int | None, list[HavingParam]]] = [
        (gid, _canonicalize_condition_order(group, _having_structural_key)) for gid, group in buckets.items()
    ]
    if len(canonicalized_groups) == 1:
        return canonicalized_groups[0][1]
    representatives: list[HavingParam] = []
    group_map: dict[int, tuple[int | None, list[HavingParam]]] = {}
    for idx, (gid, group) in enumerate(canonicalized_groups):
        # The proxy filter_group (idx) lets each representative be mapped back to its group after sorting.
        representatives.append(replace(group[-1], filter_group=idx))
        group_map[idx] = (gid, group)

    def _rep_key(rep: HavingParam) -> tuple:
        # Primary: the representative's own key (the original ordering).
        # Secondary: the whole group, so equal representatives sort deterministically.
        members = group_map[rep.filter_group][1]
        return (_having_structural_key(rep), tuple(_having_structural_key(m) for m in members))

    sorted_reps = _canonicalize_condition_order(representatives, _rep_key)
    result: list[HavingParam] = []
    for rep in sorted_reps:
        proxy_id = rep.filter_group
        assert isinstance(proxy_id, int)
        real_gid, group = group_map[proxy_id]
        last = len(group) - 1
        for fi, hp in enumerate(group):
            if fi == last:
                # The group's last element carries the connector assigned at the inter-group level.
                result.append(replace(hp, bool_op=rep.bool_op, filter_group=real_gid))
            else:
                result.append(replace(hp, filter_group=real_gid))
    return result
298
+
299
+
300
+ def _is_cte_output_groupable(term: str, cte_steps: list[RuntimeCteStep]) -> bool:
301
+ """Return True if term references a CTE output column."""
302
+ if "." not in term:
303
+ return False
304
+ table_part, col_part = term.split(".", 1)
305
+ table_lower = table_part.strip().lower()
306
+ col_lower = col_part.strip().lower()
307
+ for cte in cte_steps or []:
308
+ if cte.cte_name.lower() == table_lower:
309
+ out_cols = cte.output_columns or []
310
+ return any(c.strip().lower() == col_lower for c in out_cols)
311
+ return False
312
+
313
+
314
def _groupable_label(term: str, cte_steps: list[RuntimeCteStep], schema_graph: SchemaGraph) -> str | None:
    """Decide whether a non-aggregated select term may be placed in GROUP BY.

    Args:
        term: The expression's primary term, bare (``col``) or qualified (``table.col``).
        cte_steps: CTE steps of the intent, used to recognise CTE output columns.
        schema_graph: SchemaGraph used to look up column metadata.

    Returns:
        ``"CTE output col"`` when the term names a CTE output column;
        ``"non-agg select col"`` when the term is bare, unknown to the schema, or
        a schema column flagged ``is_groupable``; None for a known schema column
        that is not groupable. The label is only used in debug messages.
    """
    if "." not in term:
        return "non-agg select col"
    if _is_cte_output_groupable(term, cte_steps):
        return "CTE output col"
    table_name, column_name = term.split(".", 1)
    tbl_meta = schema_graph.tables.get(table_name)
    col_meta = tbl_meta.columns.get(column_name) if tbl_meta else None
    if not col_meta or col_meta.is_groupable:
        return "non-agg select col"
    return None


def enforce_grain_consistency(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Ensure GROUP BY grain matches select columns and augment with descriptive columns.

    Steps:
      1. If ``group_by_cols`` is empty and the select list mixes aggregated and
         non-aggregated columns, infer GROUP BY from the groupable non-aggregated
         columns. If none of them is groupable, the intent is returned unchanged
         rather than being marked ``"grouped"`` with an empty GROUP BY.
      2. When there is aggregation, add any remaining groupable non-aggregated
         select column to GROUP BY.
      3. For primary-key GROUP BY columns, add the table's best descriptive
         column (per ``best_descriptive_column``) to both select and GROUP BY.
         For foreign-key GROUP BY columns whose target table is in
         ``intent.tables``, add that table's descriptive column. A PK column for
         which a descriptive column was added is not also treated as an FK.

    Args:
        intent: RuntimeIntent to enforce grain on.
        schema_graph: SchemaGraph for column role lookups.

    Returns:
        The input intent unchanged when no grouping applies. Otherwise a new
        RuntimeIntent with grain ``"grouped"``, group_by_cols sorted by
        ``signature_key``, and the (possibly extended) select_cols.
    """
    group_by = list(intent.group_by_cols or [])
    select_cols = list(intent.select_cols or [])
    cte_steps = intent.cte_steps or []
    if not group_by:
        has_agg = any(sc.is_aggregated for sc in select_cols)
        non_agg = [sc for sc in select_cols if not sc.is_aggregated]
        if not (has_agg and non_agg):
            return intent
        groupable = [
            sc.expr for sc in non_agg if _groupable_label(sc.expr.primary_term, cte_steps, schema_graph) is not None
        ]
        if not groupable:
            # Nothing can be grouped on; claiming grain="grouped" with an empty
            # GROUP BY would misdescribe the query.
            debug("[intent_resolve.enforce_grain_consistency] no groupable non-agg cols; leaving intent unchanged")
            return intent
        group_by = sorted(groupable, key=lambda g: g.signature_key)
        debug(
            f"[intent_resolve.enforce_grain_consistency] inferred group_by from groupable non-agg cols: {[g.primary_term for g in group_by]}"
        )
    existing_terms = {sc.expr.primary_term for sc in select_cols}
    gb_terms = {g.primary_term for g in group_by}
    if gb_terms and any(sc.is_aggregated for sc in select_cols):
        for sc in select_cols:
            if sc.is_aggregated:
                continue
            term = sc.expr.primary_term
            if term in gb_terms:
                continue
            label = _groupable_label(term, cte_steps, schema_graph)
            if label is None:
                continue
            group_by.append(sc.expr)
            gb_terms.add(term)
            debug(f"[intent_resolve.enforce_grain_consistency] auto-added {label} to group_by: {term}")

    def _add_descriptive(fq: str) -> None:
        # The descriptive column goes into both GROUP BY and SELECT so the grain stays consistent.
        group_by.append(NormalizedExpr.from_column(fq))
        select_cols.append(SelectCol(expr=NormalizedExpr.from_column(fq)))
        existing_terms.add(fq)
        gb_terms.add(fq)

    intent_tables = set(intent.tables or [])
    # Iterate a snapshot: descriptive columns appended below must not be re-examined.
    for gb_expr in list(group_by):
        table_name, dot, column_name = gb_expr.primary_term.partition(".")
        if not dot:
            continue
        tbl_meta = schema_graph.tables.get(table_name)
        col_meta = tbl_meta.columns.get(column_name) if tbl_meta else None
        if not col_meta:
            continue
        if col_meta.is_primary_key:
            desc = best_descriptive_column(table_name, schema_graph, existing_terms | gb_terms)
            if desc:
                fq = f"{table_name}.{desc}"
                _add_descriptive(fq)
                debug(f"[intent_resolve.enforce_grain_consistency] auto-added descriptive column {fq}")
                continue
        if col_meta.is_foreign_key:
            for fk in tbl_meta.foreign_keys or []:
                if column_name not in fk.src_cols or fk.dst_table not in intent_tables:
                    continue
                desc = best_descriptive_column(fk.dst_table, schema_graph, existing_terms | gb_terms)
                if not desc:
                    continue
                fq = f"{fk.dst_table}.{desc}"
                _add_descriptive(fq)
                debug(
                    f"[intent_resolve.enforce_grain_consistency] auto-added FK descriptive column {fq} via {table_name}.{column_name}->{fk.dst_table}"
                )
    return replace(
        intent,
        group_by_cols=sorted(group_by, key=lambda g: g.signature_key),
        select_cols=select_cols,
        grain="grouped",
    )
425
+
426
+
427
+ def _cte_is_aggregated(cte: RuntimeCteStep) -> bool:
428
+ """Return True if the CTE is grouped or scalar by structure."""
429
+ has_agg = any(sc.is_aggregated for sc in (cte.select_cols or []))
430
+ if cte.group_by_cols:
431
+ return True
432
+ if has_agg:
433
+ return True
434
+ return False
435
+
436
+
437
def force_main_grain_when_using_grouped_cte(intent: RuntimeIntent) -> RuntimeIntent:
    """Promote a row-level main query to ``"grouped"`` when it aggregates over a grouped CTE.

    Promotion requires all of the following:
      * the main grain is currently ``"row_level"``;
      * the intent has CTE steps;
      * the main query has at least one aggregated select column (simply
        selecting pre-aggregated CTE output is not enough);
      * one of ``intent.tables`` is a CTE that groups or aggregates
        (per ``_cte_is_aggregated``).

    Args:
        intent: RuntimeIntent with optional cte_steps and main tables.

    Returns:
        A copy with grain ``"grouped"`` when every condition holds, otherwise
        the same intent object.
    """
    if intent.grain != "row_level" or not intent.cte_steps:
        return intent
    if not any(col.is_aggregated for col in intent.select_cols or []):
        return intent
    grouped_names = {step.cte_name for step in intent.cte_steps if _cte_is_aggregated(step)}
    if grouped_names.isdisjoint(intent.tables or []):
        return intent
    return replace(intent, grain="grouped")
470
+
471
+
472
def enforce_cte_grain_consistency(cte: RuntimeCteStep) -> RuntimeCteStep:
    """Derive a CTE step's grain from its GROUP BY and aggregation.

    * With GROUP BY columns: grain becomes ``"grouped"`` and the GROUP BY list
      is sorted by ``signature_key`` (matching the main-intent convention).
    * Without GROUP BY but with an aggregated select column: grain becomes
      ``"scalar"`` (a single-row aggregate), unless it already is.
    * Otherwise the step is returned unchanged.

    Args:
        cte: The CTE step to adjust.

    Returns:
        A ``dataclasses.replace`` copy when something changes, else ``cte`` itself.
    """
    if cte.group_by_cols:
        ordered = sorted(cte.group_by_cols, key=lambda g: g.signature_key)
        return replace(cte, grain="grouped", group_by_cols=ordered)
    aggregated = any(col.is_aggregated for col in cte.select_cols or [])
    if aggregated and cte.grain != "scalar":
        return replace(cte, grain="scalar")
    return cte
488
+
489
+
490
def resolve_column_map(columns: list[str], schema_graph: SchemaGraph, tables: list[str]) -> dict[str, str]:
    """Map bare or qualified column references to the table that owns them.

    Qualified references (``table.col``) are accepted only when ``table`` matches
    one of ``tables`` (whole name or its last dotted segment, case-insensitively)
    and that table has the column; the bare column name is then mapped to the
    first such table. Bare references are looked up in every table of ``tables``
    known to the schema; when several tables have the column, the first one is
    used and a debug message is logged. Unresolvable references are omitted.

    Args:
        columns: List of column reference strings (bare or table-qualified).
        schema_graph: SchemaGraph containing table/column metadata.
        tables: Allowed table names to resolve against.

    Returns:
        Dictionary mapping bare column name to its source table name.
    """
    known_cols: dict[str, set[str]] = {
        tbl: {name.lower() for name in schema_graph.tables[tbl].columns}
        for tbl in tables
        if tbl in schema_graph.tables
    }
    resolved: dict[str, str] = {}
    for raw in columns:
        ref = raw.strip()
        table_part, dot, column_part = ref.partition(".")
        if dot:
            wanted_tbl = table_part.strip().lower()
            wanted_col = column_part.strip().lower()
            match = next(
                (
                    tbl
                    for tbl in tables
                    if (tbl.lower() == wanted_tbl or tbl.split(".")[-1].lower() == wanted_tbl)
                    and wanted_col in known_cols.get(tbl, set())
                ),
                None,
            )
            if match is not None:
                resolved[column_part.strip()] = match
            continue
        owners = [tbl for tbl in tables if ref.lower() in known_cols.get(tbl, set())]
        if not owners:
            continue
        resolved[ref] = owners[0]
        if len(owners) > 1:
            debug(f"[intent_resolve.resolve_column_map] ambiguous column '{raw}': {owners}, using {owners[0]}")
    return resolved
534
+
535
+
536
def resolve_cte_column_maps(cte_steps: list[RuntimeCteStep]) -> list[RuntimeCteStep]:
    """Attach a ``column_map`` to every CTE step, naming the source of each column.

    Steps are processed in order. Each step's output columns are its declared
    ``output_columns`` plus the last dotted segment of each select column's
    ``primary_column``. While a step is processed, the output columns of the
    other steps recorded so far (i.e. earlier steps) are available as sources;
    when two earlier steps expose the same column (case-insensitively), the
    later one wins.

    For every column referenced by the step's select, order-by, filter and
    having expressions:
      * a qualified ``src.col`` maps ``col`` to ``src``;
      * a bare ``col`` exposed by an earlier step maps to that step's name.

    Args:
        cte_steps: Ordered list of RuntimeCteStep objects.

    Returns:
        New list of RuntimeCteStep objects with column_map populated.
    """
    outputs_by_cte: dict[str, set[str]] = {}
    updated: list[RuntimeCteStep] = []
    for step in cte_steps:
        produced = set(step.output_columns or [])
        for sc in step.select_cols or []:
            primary = sc.expr.primary_column
            if primary:
                produced.add(primary.split(".")[-1])
        outputs_by_cte[step.cte_name] = produced
        lookup: dict[str, str] = {
            col.lower(): name
            for name, cols in outputs_by_cte.items()
            if name != step.cte_name
            for col in cols
        }
        refs: list[str] = []
        for holder in (*(step.select_cols or []), *(step.order_by_cols or [])):
            refs.extend(extract_columns_from_expr(holder.expr))
        for cond in (*(step.filters_param or []), *(step.having_param or [])):
            refs.extend(extract_columns_from_expr(cond.left_expr))
            if cond.right_expr:
                refs.extend(extract_columns_from_expr(cond.right_expr))
        mapping: dict[str, str] = {}
        for ref in refs:
            ref = ref.strip()
            source, dot, bare = ref.partition(".")
            if dot:
                mapping[bare.strip()] = source.strip()
            elif ref.lower() in lookup:
                mapping[ref] = lookup[ref.lower()]
        updated.append(replace(step, column_map=mapping))
    return updated
590
+
591
+
592
def normalize_cte_names(intent: RuntimeIntent) -> RuntimeIntent:
    """Rename all CTE steps to canonical names (cte1, cte2, ...) and update all references.

    Old CTE names are replaced as whole words, case-insensitively, in:
      * each CTE step's tables list, expression terms, column_map keys and
        values, output column names, and output_column_metadata keys;
      * the main query's tables list, expression terms, and column_map.

    All old names are substituted in a single regex pass. Renaming one name at a
    time chains substitutions whenever an old name equals another step's new
    name. For example, steps named ``cte2``/``cte1`` become ``cte1``/``cte2``,
    and a reference first rewritten ``cte2`` -> ``cte1`` would then be rewritten
    again to ``cte2``. Empty CTE names are ignored when substituting; an empty
    pattern would match at every word boundary.

    Args:
        intent: RuntimeIntent with CTE steps to normalize.

    Returns:
        New RuntimeIntent with CTE names and all cross-references updated, or
        the input intent itself when it has no CTE steps.
    """
    cte_steps = intent.cte_steps or []
    if not cte_steps:
        return intent
    old_to_new: dict[str, str] = {}
    for i, cte in enumerate(cte_steps, start=1):
        old_to_new[cte.cte_name] = f"cte{i}"

    # Case-insensitive lookup. When two old names differ only by case, the one
    # inserted first into old_to_new wins, as with the sequential substitution.
    lookup: dict[str, str] = {}
    for old, new in old_to_new.items():
        if old:
            lookup.setdefault(old.lower(), new)
    pattern = None
    if lookup:
        # Longest names first so that a shorter name never shadows a longer alternative.
        alternation = "|".join(re.escape(name) for name in sorted(lookup, key=len, reverse=True))
        pattern = re.compile(rf"\b(?:{alternation})\b", re.IGNORECASE)

    def replace_cte_refs(s: str) -> str:
        if pattern is None:
            return s
        return pattern.sub(lambda m: lookup.get(m.group(0).lower(), m.group(0)), s)

    def _update_expr(expr: NormalizedExpr) -> NormalizedExpr:
        return replace_refs_in_expr(expr, replace_cte_refs)

    def _update_cols(cols):
        # SelectCol / OrderByCol carry their expression in `.expr`.
        return [replace(c, expr=_update_expr(c.expr)) for c in cols or []]

    def _update_conditions(params):
        # FilterParam / HavingParam: right_expr is optional; a falsy one becomes None.
        return [
            replace(
                p,
                left_expr=_update_expr(p.left_expr),
                right_expr=_update_expr(p.right_expr) if p.right_expr else None,
            )
            for p in params or []
        ]

    def _update_column_map(mapping):
        return {replace_cte_refs(k): replace_cte_refs(v) for k, v in (mapping or {}).items()}

    new_cte_steps = [
        replace(
            cte,
            cte_name=old_to_new[cte.cte_name],
            tables=[replace_cte_refs(t) for t in cte.tables or []],
            select_cols=_update_cols(cte.select_cols),
            group_by_cols=[_update_expr(g) for g in cte.group_by_cols or []],
            order_by_cols=_update_cols(cte.order_by_cols),
            filters_param=_update_conditions(cte.filters_param),
            having_param=_update_conditions(cte.having_param),
            column_map=_update_column_map(cte.column_map),
            output_columns=[replace_cte_refs(oc) for oc in cte.output_columns or []],
            output_column_metadata={
                replace_cte_refs(k): v for k, v in (cte.output_column_metadata or {}).items()
            },
        )
        for cte in cte_steps
    ]
    return replace(
        intent,
        tables=[replace_cte_refs(t) for t in intent.tables or []],
        select_cols=_update_cols(intent.select_cols),
        group_by_cols=[_update_expr(g) for g in intent.group_by_cols or []],
        order_by_cols=_update_cols(intent.order_by_cols),
        filters_param=_update_conditions(intent.filters_param),
        having_param=_update_conditions(intent.having_param),
        column_map=_update_column_map(intent.column_map),
        cte_steps=new_cte_steps,
    )
699
+
700
+
701
+ def _normalize_expr_ref_for_alias(rendered: str) -> str:
702
+ """Normalize a rendered expression for alias-map matching.
703
+
704
+ Strips spaces inside comparison/operator tokens and normalizes
705
+ aggregation function casing so small formatting differences still match.
706
+ """
707
+ s = rendered.strip()
708
+ for op in (" >= ", " <= ", " != ", " = ", " > ", " < ", " + ", " - ", " * ", " / "):
709
+ s = s.replace(op, op.replace(" ", ""))
710
+ s = re.sub(r"\b(count|sum|avg|min|max)\s*\(", r"\1(", s, flags=re.IGNORECASE)
711
+ return s
712
+
713
+
714
+ def _cte_output_alias_map(intent: RuntimeIntent) -> dict[str, str]:
715
+ """Build a map from CTE-qualified expression form to CTE-qualified output column name.
716
+
717
+ For each CTE step, each select expression is rendered; the key is cte_name.rendered_expr
718
+ and the value is cte_name.output_columns[i]. Also adds a stripped/normalized key so
719
+ small formatting changes (spaces in operators, function casing) still match.
720
+ """
721
+ alias_map: dict[str, str] = {}
722
+ for cte in intent.cte_steps or []:
723
+ output_cols = cte.output_columns or []
724
+ for i, sc in enumerate(cte.select_cols or []):
725
+ if i >= len(output_cols):
726
+ continue
727
+ rendered = _render_expr_sql(sc.expr)
728
+ from_ref = f"{cte.cte_name}.{rendered}"
729
+ to_ref = f"{cte.cte_name}.{output_cols[i]}"
730
+ if from_ref != to_ref:
731
+ alias_map[from_ref] = to_ref
732
+ stripped = _normalize_expr_ref_for_alias(rendered)
733
+ if stripped != rendered:
734
+ alias_map[f"{cte.cte_name}.{stripped}"] = to_ref
735
+ return alias_map
736
+
737
+
738
def rewrite_cte_output_refs_to_aliases(intent: RuntimeIntent) -> RuntimeIntent:
    """Rewrite references to CTE outputs from expression form to output column alias.

    After CTE names are normalized to cte1, cte2, ..., references such as
    ``cte1.COUNT(table_1.column_1)`` in the main query and in CTE steps are
    replaced with the output column name from ``_cte_output_alias_map`` (e.g.
    ``cte1.count_column_1``). Validation and SQL generation then see consistent
    column names.

    Each term is first looked up exactly. On a miss it is looked up again after
    ``_normalize_expr_ref_for_alias``, so that references differing only in
    operator spacing or function-name formatting still hit the normalized keys
    the alias map contains. Previously those keys could only match terms that
    were already in normalized form.

    Covers select, group-by, order-by, filter and having expressions; tables and
    column maps are left unchanged.

    Args:
        intent: RuntimeIntent whose CTE names are already normalized.

    Returns:
        A new RuntimeIntent with rewritten references, or the input intent
        itself when no alias applies.
    """
    alias_map = _cte_output_alias_map(intent)
    if not alias_map:
        return intent

    def replacer(s: str) -> str:
        hit = alias_map.get(s)
        if hit is not None:
            return hit
        return alias_map.get(_normalize_expr_ref_for_alias(s), s)

    def _update_expr(expr: NormalizedExpr) -> NormalizedExpr:
        return replace_refs_in_expr(expr, replacer)

    def _update_cols(cols):
        # SelectCol / OrderByCol carry their expression in `.expr`.
        return [replace(c, expr=_update_expr(c.expr)) for c in cols or []]

    def _update_conditions(params):
        # FilterParam / HavingParam: right_expr is optional; a falsy one becomes None.
        return [
            replace(
                p,
                left_expr=_update_expr(p.left_expr),
                right_expr=_update_expr(p.right_expr) if p.right_expr else None,
            )
            for p in params or []
        ]

    new_cte_steps = [
        replace(
            cte,
            select_cols=_update_cols(cte.select_cols),
            group_by_cols=[_update_expr(g) for g in cte.group_by_cols or []],
            order_by_cols=_update_cols(cte.order_by_cols),
            filters_param=_update_conditions(cte.filters_param),
            having_param=_update_conditions(cte.having_param),
        )
        for cte in intent.cte_steps or []
    ]
    return replace(
        intent,
        select_cols=_update_cols(intent.select_cols),
        group_by_cols=[_update_expr(g) for g in intent.group_by_cols or []],
        order_by_cols=_update_cols(intent.order_by_cols),
        filters_param=_update_conditions(intent.filters_param),
        having_param=_update_conditions(intent.having_param),
        cte_steps=new_cte_steps,
    )
799
+
800
+
801
def enforce_schema(intent: RuntimeIntent, schema_graph: SchemaGraph) -> tuple[RuntimeIntent, list[str]]:
    """Validate intent table and column references against the schema graph.

    Checks that every table referenced by the intent (and by each CTE step) exists in the schema or is a CTE name, and that every qualified ``table.column`` reference whose table is a schema table names a known column of that table. Qualified references whose table part is not a schema table (for example a CTE name) and bare column names are not checked.

    Args:

        intent: RuntimeIntent to validate.

        schema_graph: SchemaGraph providing the authoritative table/column set.

    Returns:

        Tuple of (intent, errors) where errors is a list of human-readable violation strings and the intent is returned unchanged.
    """
    errors: list[str] = []
    valid_tables = set(schema_graph.tables.keys())
    cte_names = {cte.cte_name for cte in (intent.cte_steps or [])}

    def _check_col(col: str, label: str) -> None:
        """Record an error when a qualified reference names a missing column of a schema table."""
        if "." not in col:
            return
        tbl_ref, col_ref = col.split(".", 1)
        if tbl_ref in valid_tables and col_ref not in schema_graph.tables[tbl_ref].columns:
            errors.append(f"Unknown {label} column: {col}")

    def _check_expr_cols(exprs: list, label: str) -> None:
        """Check every column of wrapped (``.expr``) or bare expressions."""
        for item in exprs:
            expr = item.expr if hasattr(item, "expr") else item
            for col in extract_columns_from_expr(expr):
                _check_col(col, label)

    def _check_filter_cols(params: list, label: str) -> None:
        """Check the columns on both sides of filter/having conditions."""
        for fp in params:
            for col in extract_columns_from_expr(fp.left_expr):
                _check_col(col, label)
            if fp.right_expr:
                for col in extract_columns_from_expr(fp.right_expr):
                    _check_col(col, label)

    def _check_bare_cols(cols: list, label: str) -> None:
        """Check GROUP BY entries via their ``primary_term`` (falling back to ``str``)."""
        for g in cols:
            _check_col(g.primary_term if hasattr(g, "primary_term") else str(g), label)

    def _check_clauses(owner: RuntimeIntent | RuntimeCteStep, prefix: str) -> None:
        """Check all column-bearing clauses of the intent or a CTE step."""
        # Fixed clause order keeps the resulting error list deterministic.
        _check_expr_cols(owner.select_cols or [], f"{prefix}select")
        _check_expr_cols(owner.order_by_cols or [], f"{prefix}order_by")
        _check_filter_cols(owner.filters_param or [], f"{prefix}filter")
        _check_filter_cols(owner.having_param or [], f"{prefix}having")
        _check_bare_cols(owner.group_by_cols or [], f"{prefix}group_by")

    for tbl in intent.tables or []:
        if tbl not in valid_tables and tbl not in cte_names:
            errors.append(f"Unknown table: {tbl}")
    _check_clauses(intent, "")
    for cte in intent.cte_steps or []:
        ctx = f"CTE '{cte.cte_name}'"
        for tbl in cte.tables or []:
            if tbl not in valid_tables and tbl not in cte_names:
                errors.append(f"{ctx} unknown table: {tbl}")
        _check_clauses(cte, f"{ctx} ")
    if errors:
        debug(f"[intent_resolve.enforce_schema] validation errors: {errors}")
    return intent, errors
880
+
881
+
882
def _simplify_expr(expr: NormalizedExpr) -> NormalizedExpr:
    """Apply algebraic simplifications to a NormalizedExpr.

    Performs the following steps:
    - constant folding that accumulates numeric literal terms;
    - like-term combining that merges groups with identical structural keys;
    - zero-coefficient elimination;
    - negative coefficient normalization that moves negative-coeff add_groups to sub_groups;
    - coefficient-to-value collapse, where groups with no operands become ExprValue offsets.

    Values carrying a param_key are never folded and stay on their original side.

    Combined coefficients and the folded constant are compared with zero using a small absolute tolerance. Floating-point residue such as ``0.1*x + 0.2*x - 0.3*x`` is therefore eliminated instead of surviving as a ~1e-17 term.

    Args:

        expr: NormalizedExpr to simplify.

    Returns:

        New NormalizedExpr in simplified canonical form.
    """
    # Absolute tolerance under which a combined coefficient or folded constant counts as zero.
    zero_tol = 1e-12

    def _is_constant_group(g: MulGroup) -> bool:
        """Return True when the group has no operands or functions and contributes only its coefficient."""
        return not g.multiply and not g.divide and not g.agg_func and not g.scalar_func and not g.inner_scalar_func

    def _with_coefficient(ref: MulGroup, coefficient: float) -> MulGroup:
        """Build a fresh MulGroup with ref's structure (lists copied) and the given coefficient."""
        return MulGroup(
            coefficient=coefficient,
            multiply=list(ref.multiply),
            divide=list(ref.divide),
            agg_func=ref.agg_func,
            scalar_func=ref.scalar_func,
            inner_scalar_func=ref.inner_scalar_func,
            scalar_func_args=list(ref.scalar_func_args),
            inner_scalar_func_args=list(ref.inner_scalar_func_args),
        )

    add_vals: list[ExprValue] = []
    sub_vals: list[ExprValue] = []
    parameterized_add: list[ExprValue] = []
    parameterized_sub: list[ExprValue] = []
    for v in expr.add_values:
        (parameterized_add if v.param_key else add_vals).append(v)
    for v in expr.sub_values:
        (parameterized_sub if v.param_key else sub_vals).append(v)
    net_const = sum(v.value for v in add_vals) - sum(v.value for v in sub_vals)

    # Operand-free groups are pure numbers: fold them into the constant.
    add_groups: list[MulGroup] = []
    sub_groups: list[MulGroup] = []
    for g in expr.add_groups:
        if _is_constant_group(g):
            net_const += g.coefficient
        else:
            add_groups.append(g)
    for g in expr.sub_groups:
        if _is_constant_group(g):
            net_const -= g.coefficient
        else:
            sub_groups.append(g)

    # Like-term combining: signed coefficient sum per structural key; the first group seen supplies the structure.
    bucket: dict[str, float] = {}
    group_map: dict[str, MulGroup] = {}
    for g in add_groups:
        key = g.structural_key
        bucket[key] = bucket.get(key, 0.0) + g.coefficient
        group_map.setdefault(key, g)
    for g in sub_groups:
        key = g.structural_key
        bucket[key] = bucket.get(key, 0.0) - g.coefficient
        group_map.setdefault(key, g)

    final_add: list[MulGroup] = []
    final_sub: list[MulGroup] = []
    for key, coeff in bucket.items():
        if abs(coeff) <= zero_tol:
            continue  # terms cancelled out (up to float residue)
        ref = group_map[key]
        if coeff > 0:
            final_add.append(_with_coefficient(ref, coeff))
        else:
            final_sub.append(_with_coefficient(ref, abs(coeff)))

    final_add_vals: list[ExprValue] = list(parameterized_add)
    final_sub_vals: list[ExprValue] = list(parameterized_sub)
    if net_const > zero_tol:
        final_add_vals.append(ExprValue(value=net_const))
    elif net_const < -zero_tol:
        final_sub_vals.append(ExprValue(value=abs(net_const)))
    return NormalizedExpr(
        add_groups=final_add,
        sub_groups=final_sub,
        add_values=final_add_vals,
        sub_values=final_sub_vals,
        agg_func=expr.agg_func,
        scalar_func=expr.scalar_func,
        inner_scalar_func=expr.inner_scalar_func,
        scalar_func_args=list(expr.scalar_func_args),
        inner_scalar_func_args=list(expr.inner_scalar_func_args),
    )
977
+
978
+
979
def _simplify_filter(fp: FilterParam) -> FilterParam:
    """Run _simplify_expr over the expressions of a filter condition.

    Args:

        fp: FilterParam whose expressions should be simplified.

    Returns:

        Copy of fp with a simplified left_expr and, when one is present, a simplified right_expr.
    """
    simplified_left = _simplify_expr(fp.left_expr)
    simplified_right = None
    if fp.right_expr:
        simplified_right = _simplify_expr(fp.right_expr)
    return replace(fp, left_expr=simplified_left, right_expr=simplified_right)
993
+
994
+
995
def _simplify_having(hp: HavingParam) -> HavingParam:
    """Run _simplify_expr over the expressions of a having condition.

    Args:

        hp: HavingParam whose expressions should be simplified.

    Returns:

        Copy of hp with a simplified left_expr and, when one is present, a simplified right_expr.
    """
    simplified_left = _simplify_expr(hp.left_expr)
    simplified_right = None
    if hp.right_expr:
        simplified_right = _simplify_expr(hp.right_expr)
    return replace(hp, left_expr=simplified_left, right_expr=simplified_right)
1009
+
1010
+
1011
def simplify_exprs(intent: RuntimeIntent) -> RuntimeIntent:
    """Apply algebraic simplification to every NormalizedExpr across all intent clauses.

    Simplifies the select, group-by, order-by, filter and having clauses of the main intent and of every CTE step. GROUP BY expressions are simplified alongside the select expressions so that both clauses carry the same canonical form of an expression. GROUP BY entries that are not NormalizedExpr instances are passed through unchanged.

    Args:

        intent: RuntimeIntent whose expressions should be simplified.

    Returns:

        New RuntimeIntent with all expressions in simplified form.
    """
    debug("[intent_resolve.simplify_exprs] simplifying all expressions")

    def _simplify_group_by(g):
        """Simplify a GROUP BY entry when it is a NormalizedExpr; leave anything else as-is."""
        return _simplify_expr(g) if isinstance(g, NormalizedExpr) else g

    def _simplified_clauses(owner: RuntimeIntent | RuntimeCteStep) -> dict:
        """Return dataclasses.replace() keyword arguments with every expression clause of owner simplified."""
        return {
            "select_cols": [replace(sc, expr=_simplify_expr(sc.expr)) for sc in (owner.select_cols or [])],
            "group_by_cols": [_simplify_group_by(g) for g in (owner.group_by_cols or [])],
            "order_by_cols": [replace(obc, expr=_simplify_expr(obc.expr)) for obc in (owner.order_by_cols or [])],
            "filters_param": [_simplify_filter(fp) for fp in (owner.filters_param or [])],
            "having_param": [_simplify_having(hp) for hp in (owner.having_param or [])],
        }

    new_cte_steps: list[RuntimeCteStep] = [
        replace(cte, **_simplified_clauses(cte)) for cte in (intent.cte_steps or [])
    ]
    return replace(intent, cte_steps=new_cte_steps, **_simplified_clauses(intent))
1050
+
1051
+
1052
def _normalize_filter_scalar_on_left(fp: FilterParam) -> FilterParam:
    """Swap sides when left is scalar and right is column, flipping the operator.

    Ensures column or table.column expressions are on the left for validation and SQL generation. Only expressions that reference at least one qualified (``table.column``) column count as "column" sides.

    The swap uses ``dataclasses.replace``, so every other field of the filter is carried over unchanged. This includes value_type, param_key, raw_value, bool_op and filter_group.

    Args:

        fp: FilterParam to normalize.

    Returns:

        FilterParam with column on the left when applicable.
    """
    if not fp.right_expr:
        return fp
    left_cols = [c for c in extract_columns_from_expr(fp.left_expr) if "." in c]
    right_cols = [c for c in extract_columns_from_expr(fp.right_expr) if "." in c]
    # Only swap when the right side has qualified columns and the left side has none.
    if left_cols or not right_cols:
        return fp
    return replace(
        fp,
        left_expr=fp.right_expr,
        op=REVERSE_OP_MAP.get(fp.op, fp.op),
        right_expr=fp.left_expr,
    )
1082
+
1083
+
1084
def _normalize_filter_canonical(fp: FilterParam) -> FilterParam:
    """Normalize a filter to canonical form with a non-empty expression on the left.

    When the left_expr has neither add_groups nor add_values but a right_expr exists, swaps the sides and reverses the comparison operator.

    The swap uses ``dataclasses.replace``, so every other field is preserved, including raw_value, bool_op and filter_group. Rebuilding the FilterParam by hand would reset those fields to their defaults, and bool_op/filter_group are part of the key used by _dedup_filters.

    Args:

        fp: FilterParam to normalize.

    Returns:

        FilterParam with the heavier side on the left.
    """
    if not fp.left_expr.add_groups and not fp.left_expr.add_values and fp.right_expr:
        return replace(
            fp,
            left_expr=fp.right_expr,
            op=REVERSE_OP_MAP.get(fp.op, fp.op),
            right_expr=fp.left_expr,
        )
    return fp
1107
+
1108
+
1109
def _normalize_having_canonical(hp: HavingParam) -> HavingParam:
    """Normalize a having condition to canonical form with a non-empty expression on the left.

    When the left_expr has neither add_groups nor add_values but a right_expr exists, swaps the sides and reverses the comparison operator.

    The swap uses ``dataclasses.replace``, so every other field is preserved, including bool_op and filter_group. Both are part of the key used by _dedup_having and would otherwise reset to their defaults.

    Args:

        hp: HavingParam to normalize.

    Returns:

        HavingParam with the heavier side on the left.
    """
    if not hp.left_expr.add_groups and not hp.left_expr.add_values and hp.right_expr:
        return replace(
            hp,
            left_expr=hp.right_expr,
            op=REVERSE_OP_MAP.get(hp.op, hp.op),
            right_expr=hp.left_expr,
        )
    return hp
1130
+
1131
+
1132
def _normalize_col_to_col_filter(fp: FilterParam) -> FilterParam:
    """Normalize an expr-vs-expr filter so the lexicographically smaller signature is on the left.

    Only applies to filters with a right_expr and no param_key. The swap uses ``dataclasses.replace``, so raw_value, bool_op, filter_group and any other fields survive it.

    Args:

        fp: FilterParam with a right_expr (col-vs-col filter).

    Returns:

        FilterParam with sides swapped and operator reversed if needed.
    """
    if fp.right_expr and not fp.param_key:
        if fp.left_expr.signature_key > fp.right_expr.signature_key:
            return replace(
                fp,
                left_expr=fp.right_expr,
                op=REVERSE_OP_MAP.get(fp.op, fp.op),
                right_expr=fp.left_expr,
            )
    return fp
1156
+
1157
+
1158
def _normalize_agg_to_agg_having(hp: HavingParam) -> HavingParam:
    """Normalize an expr-vs-expr having condition so the lexicographically smaller signature is on the left.

    Only applies to conditions with a right_expr and no param_key. The swap uses ``dataclasses.replace``, so bool_op, filter_group and any other fields survive it.

    Args:

        hp: HavingParam with a right_expr (agg-vs-agg having).

    Returns:

        HavingParam with sides swapped and operator reversed if needed.
    """
    if hp.right_expr and not hp.param_key:
        if hp.left_expr.signature_key > hp.right_expr.signature_key:
            return replace(
                hp,
                left_expr=hp.right_expr,
                op=REVERSE_OP_MAP.get(hp.op, hp.op),
                right_expr=hp.left_expr,
            )
    return hp
1182
+
1183
+
1184
def _normalize_filter(fp: FilterParam) -> FilterParam:
    """Run the full filter normalization pipeline on one condition.

    The steps are applied in order:
    1. scalar-on-left swap;
    2. canonical (non-empty left) form;
    3. col-vs-col signature ordering;
    4. operator and value type normalization.

    Args:

        fp: FilterParam to normalize.

    Returns:

        Fully normalized FilterParam.
    """
    pipeline = (
        _normalize_filter_scalar_on_left,
        _normalize_filter_canonical,
        _normalize_col_to_col_filter,
    )
    current = fp
    for step in pipeline:
        current = step(current)
    canonical_op = normalize_op(current.op)
    canonical_type = normalize_value_type(current.value_type)
    return replace(current, op=canonical_op, value_type=canonical_type)
1201
+
1202
+
1203
def _normalize_having(hp: HavingParam) -> HavingParam:
    """Run the full having normalization pipeline on one condition.

    The steps are applied in order:
    1. canonical (non-empty left) form;
    2. agg-vs-agg signature ordering;
    3. operator and value type normalization.

    Args:

        hp: HavingParam to normalize.

    Returns:

        Fully normalized HavingParam.
    """
    ordered = _normalize_agg_to_agg_having(_normalize_having_canonical(hp))
    canonical_op = normalize_op(ordered.op)
    canonical_type = normalize_value_type(ordered.value_type)
    return replace(ordered, op=canonical_op, value_type=canonical_type)
1219
+
1220
+
1221
def _dedup_filters(filters: list[FilterParam]) -> list[FilterParam]:
    """Drop filters repeating an earlier (signature_key, bool_op, filter_group) triple.

    Input order is preserved and the first occurrence of each triple wins. Every dropped filter is reported through debug().

    Args:

        filters: List of FilterParam objects to deduplicate.

    Returns:

        New list holding the first filter for each distinct triple.
    """
    kept: list[FilterParam] = []
    seen_keys: set[tuple[str, str, int | None]] = set()
    for candidate in filters:
        key = (candidate.signature_key, candidate.bool_op, candidate.filter_group)
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(candidate)
        else:
            debug(f"[intent_resolve.dedup_filters] dropping duplicate filter: {key}")
    return kept
1242
+
1243
+
1244
def _dedup_having(having: list[HavingParam]) -> list[HavingParam]:
    """Drop having conditions repeating an earlier (signature_key, bool_op, filter_group) triple.

    Input order is preserved and the first occurrence of each triple wins. Every dropped condition is reported through debug().

    Args:

        having: List of HavingParam objects to deduplicate.

    Returns:

        New list holding the first having condition for each distinct triple.
    """
    kept: list[HavingParam] = []
    seen_keys: set[tuple[str, str, int | None]] = set()
    for candidate in having:
        key = (candidate.signature_key, candidate.bool_op, candidate.filter_group)
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(candidate)
        else:
            debug(f"[intent_resolve.dedup_having] dropping duplicate having: {key}")
    return kept
1265
+
1266
+
1267
def normalize_filters_havings(intent: RuntimeIntent) -> RuntimeIntent:
    """Normalize, sort and deduplicate the filter and having conditions of an intent.

    Each condition is normalized individually. Each resulting list is then sorted and deduplicated. This is done for the main intent and for every CTE step.

    Args:

        intent: RuntimeIntent whose filters and having lists should be normalized.

    Returns:

        New RuntimeIntent with all filters and having conditions normalized, deduplicated, and sorted.
    """

    def _tidy_filters(items: list[FilterParam]) -> list[FilterParam]:
        """Sort then deduplicate a list of already-normalized filters."""
        return _dedup_filters(sort_filters(items))

    def _tidy_having(items: list[HavingParam]) -> list[HavingParam]:
        """Sort then deduplicate a list of already-normalized having conditions."""
        return _dedup_having(sort_having(items))

    main_filters = [_normalize_filter(f) for f in (intent.filters_param or [])]
    main_having = [_normalize_having(h) for h in (intent.having_param or [])]
    cte_steps = [
        replace(
            step,
            filters_param=_tidy_filters([_normalize_filter(f) for f in (step.filters_param or [])]),
            having_param=_tidy_having([_normalize_having(h) for h in (step.having_param or [])]),
        )
        for step in (intent.cte_steps or [])
    ]
    return replace(
        intent,
        filters_param=_tidy_filters(main_filters),
        having_param=_tidy_having(main_having),
        cte_steps=cte_steps,
    )