aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
|
@@ -0,0 +1,1292 @@
|
|
|
1
|
+
"""Column resolution, normalization, and grain enforcement for intent post-processing.
|
|
2
|
+
|
|
3
|
+
Resolves bare column names to source tables or CTE steps via resolve_column_map and resolve_cte_column_maps, normalizes filter and having operators to canonical forms, deduplicates conditions, and sorts select, order, filter, and having clauses by structural keys.
|
|
4
|
+
|
|
5
|
+
Enforces grain consistency between scalar, grouped, and row_level settings versus aggregation and GROUP BY, validates tables and columns against the schema, applies algebraic simplification such as constant folding and like-term combining to all expressions, and normalizes CTE names and COUNT(1) to COUNT(*).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import re
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from collections.abc import Callable
|
|
13
|
+
from dataclasses import replace
|
|
14
|
+
|
|
15
|
+
from .config import REVERSE_OP_MAP, normalize_value_type
|
|
16
|
+
from .contracts_base import SchemaGraph
|
|
17
|
+
from .contracts_core import (
|
|
18
|
+
ExprValue,
|
|
19
|
+
FilterParam,
|
|
20
|
+
HavingParam,
|
|
21
|
+
MulGroup,
|
|
22
|
+
NormalizedExpr,
|
|
23
|
+
OrderByCol,
|
|
24
|
+
RuntimeCteStep,
|
|
25
|
+
RuntimeIntent,
|
|
26
|
+
SelectCol,
|
|
27
|
+
)
|
|
28
|
+
from .core_utils import debug, normalize_op
|
|
29
|
+
from .intent_expr import extract_columns_from_expr, replace_refs_in_expr
|
|
30
|
+
from .intent_repair import best_descriptive_column
|
|
31
|
+
from .sql_gen import _render_expr_sql
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def normalize_count_star(intent: RuntimeIntent) -> RuntimeIntent:
    """Rewrite every ``COUNT(1)`` term of an intent as ``COUNT(*)``.

    The rewrite is applied to the select, order-by, filter and having expressions of the main query and of every CTE step. A term is rewritten only when its upper-cased text is exactly ``COUNT(1)``.

    Args:

        intent: RuntimeIntent to normalize.

    Returns:

        New RuntimeIntent whose expressions use COUNT(*) instead of COUNT(1).
    """

    def _star(term: str) -> str:
        return "COUNT(*)" if term.upper() == "COUNT(1)" else term

    def _rewrite(expr):
        return replace_refs_in_expr(expr, _star)

    def _rewrite_cols(cols):
        # Works for SelectCol and OrderByCol alike: both carry an ``expr`` field.
        return [replace(col, expr=_rewrite(col.expr)) for col in (cols or [])]

    def _rewrite_conditions(params):
        rewritten = []
        for cond in params or []:
            lhs = _rewrite(cond.left_expr)
            rhs = _rewrite(cond.right_expr) if cond.right_expr else None
            rewritten.append(replace(cond, left_expr=lhs, right_expr=rhs))
        return rewritten

    main_select = _rewrite_cols(intent.select_cols)
    main_order = _rewrite_cols(intent.order_by_cols)
    main_filters = _rewrite_conditions(intent.filters_param)
    main_having = _rewrite_conditions(intent.having_param)
    rewritten_steps = [
        replace(
            step,
            select_cols=_rewrite_cols(step.select_cols),
            order_by_cols=_rewrite_cols(step.order_by_cols),
            filters_param=_rewrite_conditions(step.filters_param),
            having_param=_rewrite_conditions(step.having_param),
        )
        for step in (intent.cte_steps or [])
    ]
    return replace(
        intent,
        select_cols=main_select,
        order_by_cols=main_order,
        filters_param=main_filters,
        having_param=main_having,
        cte_steps=rewritten_steps,
    )
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def sort_select_cols(cols: list[SelectCol]) -> list[SelectCol]:
    """Order select columns with plain expressions first, aggregated ones last.

    Within each of the two classes the columns are ordered by their ``signature_key``.

    Args:

        cols: List of SelectCol objects to sort.

    Returns:

        New sorted list of SelectCol objects; the input list is not modified.
    """
    return sorted(cols, key=lambda col: (int(bool(col.is_aggregated)), col.signature_key))
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def sort_order_by_cols(cols: list[OrderByCol]) -> list[OrderByCol]:
    """Order ORDER BY entries with plain expressions first, aggregated ones last.

    Entries of the same class are ordered by their ``signature_key``.

    Args:

        cols: List of OrderByCol objects to sort.

    Returns:

        New sorted list of OrderByCol objects; the input list is not modified.
    """
    return sorted(cols, key=lambda entry: (int(bool(entry.is_aggregated)), entry.signature_key))
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _filter_structural_key(fp: FilterParam) -> tuple[str, str, str, str]:
|
|
129
|
+
"""Return the structural sort key for a single FilterParam.
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
|
|
133
|
+
fp: FilterParam to compute the key for.
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
|
|
137
|
+
Tuple of (left_sig, op, right_sig, value_type) with all components lowercased.
|
|
138
|
+
"""
|
|
139
|
+
left = fp.left_expr.signature_key if fp.left_expr else ""
|
|
140
|
+
right = fp.right_expr.signature_key if fp.right_expr else ""
|
|
141
|
+
return (left, fp.op.lower(), right, fp.value_type.lower())
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _having_structural_key(hp: HavingParam) -> tuple[str, str, str, str]:
|
|
145
|
+
"""Return the structural sort key for a single HavingParam.
|
|
146
|
+
|
|
147
|
+
Args:
|
|
148
|
+
|
|
149
|
+
hp: HavingParam to compute the key for.
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
|
|
153
|
+
Tuple of (left_sig, op, right_sig, value_type) with all components lowercased.
|
|
154
|
+
"""
|
|
155
|
+
left = hp.left_expr.signature_key if hp.left_expr else ""
|
|
156
|
+
right = hp.right_expr.signature_key if hp.right_expr else ""
|
|
157
|
+
return (left, hp.op.lower(), right, hp.value_type.lower())
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _canonicalize_condition_order(
|
|
161
|
+
items: list,
|
|
162
|
+
structural_key_fn: Callable,
|
|
163
|
+
) -> list:
|
|
164
|
+
"""Canonicalize a flat condition list by parsing the positional bool_op operators into a precedence tree (AND binds tighter than OR), sorting at each level using the commutativity of AND/OR, and re-serializing with adjusted bool_ops.
|
|
165
|
+
|
|
166
|
+
The last element's bool_op is treated as the inter-group connector and is preserved on whichever element ends up last after sorting.
|
|
167
|
+
|
|
168
|
+
Works for both FilterParam and HavingParam since they share the same bool_op / filter_group field layout.
|
|
169
|
+
|
|
170
|
+
Args:
|
|
171
|
+
|
|
172
|
+
items: List of FilterParam or HavingParam objects.
|
|
173
|
+
|
|
174
|
+
structural_key_fn: Callable that returns a sortable tuple for one item.
|
|
175
|
+
|
|
176
|
+
Returns:
|
|
177
|
+
|
|
178
|
+
New list with the same elements in canonical order and bool_ops reassigned.
|
|
179
|
+
"""
|
|
180
|
+
if len(items) <= 1:
|
|
181
|
+
return list(items)
|
|
182
|
+
inter_connector = items[-1].bool_op
|
|
183
|
+
ops: list[str] = [it.bool_op for it in items[:-1]]
|
|
184
|
+
chunks: list[list] = []
|
|
185
|
+
current_chunk: list = [items[0]]
|
|
186
|
+
for i, op in enumerate(ops):
|
|
187
|
+
if op == "OR":
|
|
188
|
+
chunks.append(current_chunk)
|
|
189
|
+
current_chunk = [items[i + 1]]
|
|
190
|
+
else:
|
|
191
|
+
current_chunk.append(items[i + 1])
|
|
192
|
+
chunks.append(current_chunk)
|
|
193
|
+
sorted_chunks: list[list] = []
|
|
194
|
+
for chunk in chunks:
|
|
195
|
+
sorted_chunks.append(sorted(chunk, key=structural_key_fn))
|
|
196
|
+
sorted_chunks.sort(key=lambda ch: structural_key_fn(ch[0]))
|
|
197
|
+
result: list = []
|
|
198
|
+
for ci, chunk in enumerate(sorted_chunks):
|
|
199
|
+
for fi, item in enumerate(chunk):
|
|
200
|
+
is_last_in_chunk = fi == len(chunk) - 1
|
|
201
|
+
is_last_chunk = ci == len(sorted_chunks) - 1
|
|
202
|
+
if is_last_chunk and is_last_in_chunk:
|
|
203
|
+
new_bool_op = inter_connector
|
|
204
|
+
elif is_last_in_chunk:
|
|
205
|
+
new_bool_op = "OR"
|
|
206
|
+
else:
|
|
207
|
+
new_bool_op = "AND"
|
|
208
|
+
result.append(replace(item, bool_op=new_bool_op))
|
|
209
|
+
return result
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def sort_filters(filters: list[FilterParam]) -> list[FilterParam]:
    """Put filter parameters into canonical order, respecting filter groups.

    Filters are bucketed by ``filter_group`` and each bucket is canonicalized on its own. When more than one bucket exists, the last element of each bucket stands in for the whole bucket (it carries the connector to the next group) and the buckets are ordered by running the same canonicalization over those stand-ins.

    Args:

        filters: List of FilterParam objects to sort.

    Returns:

        Sorted list of FilterParam objects with bool_ops adjusted to reflect the new canonical positions.
    """
    if not filters:
        return []

    by_group: dict[int | None, list[FilterParam]] = defaultdict(list)
    for item in filters:
        by_group[item.filter_group].append(item)

    ordered_groups: list[tuple[int | None, list[FilterParam]]] = [
        (group_id, _canonicalize_condition_order(members, _filter_structural_key))
        for group_id, members in by_group.items()
    ]
    if len(ordered_groups) == 1:
        return ordered_groups[0][1]

    # Stand-ins temporarily store the group's position in ``filter_group`` so the
    # group can be found again after the stand-ins have been reordered.
    stand_ins: list[FilterParam] = [
        replace(members[-1], filter_group=slot) for slot, (_, members) in enumerate(ordered_groups)
    ]
    ordered_stand_ins = _canonicalize_condition_order(stand_ins, _filter_structural_key)

    flattened: list[FilterParam] = []
    for stand_in in ordered_stand_ins:
        slot = stand_in.filter_group
        assert isinstance(slot, int)
        group_id, members = ordered_groups[slot]
        *leading, closing = members
        flattened.extend(replace(member, filter_group=group_id) for member in leading)
        # Only the group's last element takes the recomputed inter-group connector.
        flattened.append(replace(closing, bool_op=stand_in.bool_op, filter_group=group_id))
    return flattened
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def sort_having(having: list[HavingParam]) -> list[HavingParam]:
    """Put having parameters into canonical order, respecting filter groups.

    Conditions are bucketed by ``filter_group`` and each bucket is canonicalized on its own. With several buckets, the last element of each bucket represents the bucket at the inter-group level and the buckets are ordered by canonicalizing those representatives.

    Args:

        having: List of HavingParam objects to sort.

    Returns:

        Sorted list of HavingParam objects with bool_ops adjusted to reflect the new canonical positions.
    """
    if not having:
        return []

    grouped: dict[int | None, list[HavingParam]] = defaultdict(list)
    for cond in having:
        grouped[cond.filter_group].append(cond)

    canonical: list[tuple[int | None, list[HavingParam]]] = [
        (gid, _canonicalize_condition_order(conds, _having_structural_key)) for gid, conds in grouped.items()
    ]
    if len(canonical) == 1:
        return canonical[0][1]

    # Each representative remembers its group's position through ``filter_group``.
    reps: list[HavingParam] = [replace(conds[-1], filter_group=pos) for pos, (_, conds) in enumerate(canonical)]
    ordered_reps = _canonicalize_condition_order(reps, _having_structural_key)

    output: list[HavingParam] = []
    for rep in ordered_reps:
        pos = rep.filter_group
        assert isinstance(pos, int)
        gid, conds = canonical[pos]
        tail_index = len(conds) - 1
        for index, cond in enumerate(conds):
            if index != tail_index:
                output.append(replace(cond, filter_group=gid))
                continue
            # The group's final condition carries the connector to the next group.
            output.append(replace(cond, bool_op=rep.bool_op, filter_group=gid))
    return output
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _is_cte_output_groupable(term: str, cte_steps: list[RuntimeCteStep]) -> bool:
|
|
301
|
+
"""Return True if term references a CTE output column."""
|
|
302
|
+
if "." not in term:
|
|
303
|
+
return False
|
|
304
|
+
table_part, col_part = term.split(".", 1)
|
|
305
|
+
table_lower = table_part.strip().lower()
|
|
306
|
+
col_lower = col_part.strip().lower()
|
|
307
|
+
for cte in cte_steps or []:
|
|
308
|
+
if cte.cte_name.lower() == table_lower:
|
|
309
|
+
out_cols = cte.output_columns or []
|
|
310
|
+
return any(c.strip().lower() == col_lower for c in out_cols)
|
|
311
|
+
return False
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def enforce_grain_consistency(intent: RuntimeIntent, schema_graph: SchemaGraph) -> RuntimeIntent:
    """Ensure GROUP BY grain matches select columns and augment with descriptive columns.

    When group_by_cols is empty but select has mixed aggregated/non-aggregated columns, infers the group-by from groupable non-aggregated columns. Non-aggregated select columns missing from an existing group-by are added to it, and for primary-key or foreign-key group-by columns the best descriptive column is auto-added to both select and group_by for readability.

    Args:

        intent: RuntimeIntent to enforce grain on.

        schema_graph: SchemaGraph for column role lookups.

    Returns:

        The unchanged intent when there is no group-by and the select list is not a mix of aggregated and non-aggregated columns; otherwise a new RuntimeIntent with sorted group_by_cols, the (possibly extended) select_cols and grain set to 'grouped'.
    """
    group_by = list(intent.group_by_cols or [])
    select_cols = list(intent.select_cols or [])
    cte_steps = intent.cte_steps or []
    if not group_by:
        # No explicit GROUP BY: only act when aggregated and plain columns are mixed.
        has_agg = any(sc.is_aggregated for sc in select_cols)
        non_agg = [sc for sc in select_cols if not sc.is_aggregated]
        if not (has_agg and non_agg):
            return intent
        groupable: list[NormalizedExpr] = []
        for sc in non_agg:
            term = sc.expr.primary_term
            parts = term.split(".", 1) if "." in term else None
            if not parts:
                # Unqualified term: no schema lookup is possible, accept it.
                groupable.append(sc.expr)
                continue
            if _is_cte_output_groupable(term, cte_steps):
                groupable.append(sc.expr)
                continue
            tbl_meta = schema_graph.tables.get(parts[0])
            col_meta = tbl_meta.columns.get(parts[1]) if tbl_meta else None
            # Columns unknown to the schema are accepted; known ones must be flagged groupable.
            if not col_meta or col_meta.is_groupable:
                groupable.append(sc.expr)
        group_by = sorted(groupable, key=lambda g: g.signature_key)
        debug(
            f"[intent_resolve.enforce_grain_consistency] inferred group_by from groupable non-agg cols: {[g.primary_term for g in group_by]}"
        )
    existing_terms = {sc.expr.primary_term for sc in select_cols}
    gb_terms = {g.primary_term for g in group_by}
    has_agg_check = any(sc.is_aggregated for sc in select_cols)
    if has_agg_check and gb_terms:
        # With aggregation present, every plain select column must also be grouped.
        for sc in select_cols:
            if sc.is_aggregated:
                continue
            term = sc.expr.primary_term
            if term in gb_terms:
                continue
            parts = term.split(".", 1) if "." in term else None
            if not parts:
                group_by.append(sc.expr)
                gb_terms.add(term)
                debug(f"[intent_resolve.enforce_grain_consistency] auto-added non-agg select col to group_by: {term}")
                continue
            if _is_cte_output_groupable(term, cte_steps):
                group_by.append(sc.expr)
                gb_terms.add(term)
                debug(f"[intent_resolve.enforce_grain_consistency] auto-added CTE output col to group_by: {term}")
                continue
            tbl_meta = schema_graph.tables.get(parts[0])
            col_meta = tbl_meta.columns.get(parts[1]) if tbl_meta else None
            if not col_meta or col_meta.is_groupable:
                group_by.append(sc.expr)
                gb_terms.add(term)
                debug(f"[intent_resolve.enforce_grain_consistency] auto-added non-agg select col to group_by: {term}")
    intent_tables = set(intent.tables or [])
    # Iterate over a snapshot: descriptive columns are appended to group_by inside the loop.
    for gb_expr in list(group_by):
        gb_col = gb_expr.primary_term
        parts = gb_col.split(".", 1) if "." in gb_col else None
        if not parts:
            continue
        tbl_meta = schema_graph.tables.get(parts[0])
        col_meta = tbl_meta.columns.get(parts[1]) if tbl_meta else None
        if not col_meta:
            continue
        if col_meta.is_primary_key:
            # Grouping by a key: also expose a descriptive column of the same table.
            desc = best_descriptive_column(parts[0], schema_graph, existing_terms | gb_terms)
            if desc:
                fq = f"{parts[0]}.{desc}"
                group_by.append(NormalizedExpr.from_column(fq))
                select_cols.append(SelectCol(expr=NormalizedExpr.from_column(fq)))
                existing_terms.add(fq)
                gb_terms.add(fq)
                debug(f"[intent_resolve.enforce_grain_consistency] auto-added descriptive column {fq}")
            # NOTE(review): nesting of this `continue` was reconstructed from a listing without
            # indentation; confirm it belongs to the primary-key branch and not to `if desc:`.
            continue
        if col_meta.is_foreign_key:
            for fk in tbl_meta.foreign_keys or []:
                if parts[1] not in fk.src_cols:
                    continue
                # Only follow foreign keys whose target table is already part of the query.
                if fk.dst_table not in intent_tables:
                    continue
                desc = best_descriptive_column(fk.dst_table, schema_graph, existing_terms | gb_terms)
                if not desc:
                    continue
                fq = f"{fk.dst_table}.{desc}"
                group_by.append(NormalizedExpr.from_column(fq))
                select_cols.append(SelectCol(expr=NormalizedExpr.from_column(fq)))
                existing_terms.add(fq)
                gb_terms.add(fq)
                debug(
                    f"[intent_resolve.enforce_grain_consistency] auto-added FK descriptive column {fq} via {parts[0]}.{parts[1]}->{fk.dst_table}"
                )
    return replace(
        intent,
        group_by_cols=sorted(group_by, key=lambda g: g.signature_key),
        select_cols=select_cols,
        grain="grouped",
    )
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _cte_is_aggregated(cte: RuntimeCteStep) -> bool:
|
|
428
|
+
"""Return True if the CTE is grouped or scalar by structure."""
|
|
429
|
+
has_agg = any(sc.is_aggregated for sc in (cte.select_cols or []))
|
|
430
|
+
if cte.group_by_cols:
|
|
431
|
+
return True
|
|
432
|
+
if has_agg:
|
|
433
|
+
return True
|
|
434
|
+
return False
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def force_main_grain_when_using_grouped_cte(intent: RuntimeIntent) -> RuntimeIntent:
    """Switch a row_level main query to grouped when it aggregates over an aggregated CTE.

    The promotion happens only if all of the following hold: the current
    grain is "row_level", the intent has CTE steps, the main select list
    contains at least one aggregated column, and one of the main tables
    is the name of a CTE step that is itself aggregated. A main query
    that merely selects pre-aggregated CTE output is left untouched.

    Args:

        intent: RuntimeIntent with optional cte_steps and main tables.

    Returns:

        Updated RuntimeIntent with grain "grouped" when applicable, otherwise the intent itself.
    """
    if intent.grain != "row_level":
        return intent
    steps = intent.cte_steps or []
    if not steps:
        return intent
    if not any(col.is_aggregated for col in (intent.select_cols or [])):
        return intent
    referenced_tables = set(intent.tables or [])
    aggregated_names = {step.cte_name for step in steps if _cte_is_aggregated(step)}
    if aggregated_names.isdisjoint(referenced_tables):
        return intent
    return replace(intent, grain="grouped")
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def enforce_cte_grain_consistency(cte: RuntimeCteStep) -> RuntimeCteStep:
    """Derive a CTE step's grain from its structure.

    With ``group_by_cols`` present the grain becomes ``"grouped"`` and
    the group-by columns are sorted by ``signature_key``. Without
    GROUP BY, a step that has an aggregated select column becomes
    ``"scalar"``. In every other case the step is returned as-is.
    """
    aggregates_in_select = any(col.is_aggregated for col in (cte.select_cols or []))
    if cte.group_by_cols:
        ordered = sorted(cte.group_by_cols, key=lambda expr: expr.signature_key)
        return replace(cte, grain="grouped", group_by_cols=ordered)
    needs_scalar = aggregates_in_select and cte.grain != "scalar"
    return replace(cte, grain="scalar") if needs_scalar else cte
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def resolve_column_map(columns: list[str], schema_graph: SchemaGraph, tables: list[str]) -> dict[str, str]:
    """Work out which allowed table each column reference belongs to.

    A qualified reference (``table.col``) is accepted when its qualifier matches an allowed table name, or the part of that name after the last dot, and the table has the column. A bare reference is looked up in every allowed table; if several tables have it, the first one in ``tables`` wins and a debug message is logged. All name comparisons are case-insensitive, and references that cannot be resolved are left out.

    Args:

        columns: List of column reference strings (bare or table-qualified).

        schema_graph: SchemaGraph containing table/column metadata.

        tables: Allowed table names to resolve against.

    Returns:

        Dictionary mapping bare column name to its source table name.
    """
    # Lower-cased column names per allowed table; tables unknown to the schema are ignored.
    columns_by_table: dict[str, set[str]] = {
        name: {c.lower() for c in schema_graph.tables[name].columns}
        for name in tables
        if name in schema_graph.tables
    }
    resolved: dict[str, str] = {}
    for col in columns:
        reference = col.strip()
        if "." not in reference:
            wanted = reference.lower()
            candidates = [name for name in tables if wanted in columns_by_table.get(name, set())]
            if candidates:
                resolved[reference] = candidates[0]
            if len(candidates) > 1:
                debug(f"[intent_resolve.resolve_column_map] ambiguous column '{col}': {candidates}, using {candidates[0]}")
            continue
        qualifier, _, column_part = reference.partition(".")
        wanted_table = qualifier.strip().lower()
        wanted_column = column_part.strip().lower()
        owner = next(
            (
                name
                for name in tables
                if (name.lower() == wanted_table or name.split(".")[-1].lower() == wanted_table)
                and wanted_column in columns_by_table.get(name, set())
            ),
            None,
        )
        if owner is not None:
            resolved[column_part.strip()] = owner
    return resolved
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def resolve_cte_column_maps(cte_steps: list[RuntimeCteStep]) -> list[RuntimeCteStep]:
    """Attach a column_map to every CTE step.

    Steps are handled in order. For each step, the columns referenced by its select, order-by, filter and having expressions are collected. A qualified reference (``source.col``) maps the bare column to its qualifier; a bare reference maps to the CTE that lists it among its output columns, where a step's own outputs are never used for itself.

    Args:

        cte_steps: Ordered list of RuntimeCteStep objects.

    Returns:

        New list of RuntimeCteStep objects with column_map populated.
    """
    outputs_by_cte: dict[str, set[str]] = {}
    resolved_steps = []
    for step in cte_steps:
        step_name = step.cte_name
        # Output columns: declared names plus the bare primary column of each select expression.
        produced = set(step.output_columns or [])
        for sel in step.select_cols or []:
            primary = sel.expr.primary_column
            if primary:
                produced.add(primary.split(".")[-1])
        outputs_by_cte[step_name] = produced

        # Lower-cased output column -> CTE name, over all registered CTEs except this one.
        upstream: dict[str, str] = {
            name.lower(): other
            for other, names in outputs_by_cte.items()
            if other != step_name
            for name in names
        }

        expressions = [sel.expr for sel in step.select_cols or []]
        expressions.extend(entry.expr for entry in step.order_by_cols or [])
        for cond in [*(step.filters_param or []), *(step.having_param or [])]:
            expressions.append(cond.left_expr)
            if cond.right_expr:
                expressions.append(cond.right_expr)

        mapping: dict[str, str] = {}
        for expression in expressions:
            for raw in extract_columns_from_expr(expression):
                reference = raw.strip()
                source, dot, bare = reference.partition(".")
                if dot:
                    mapping[bare.strip()] = source.strip()
                elif reference.lower() in upstream:
                    mapping[reference] = upstream[reference.lower()]
        resolved_steps.append(replace(step, column_map=mapping))
    return resolved_steps
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def normalize_cte_names(intent: RuntimeIntent) -> RuntimeIntent:
    """Rename all CTE steps to canonical names (cte1, cte2, ...) and update all references.

    Replaces old CTE name occurrences (whole words, case-insensitive) in tables lists, expression terms, column maps, output column names and output column metadata keys throughout both CTE steps and the main query.

    All renames are applied in one pass over each string. Applying them one after another would rename a reference twice whenever a new canonical name equals another step's old name (for example steps originally called ``cte2`` and ``cte1``).

    Args:

        intent: RuntimeIntent with CTE steps to normalize.

    Returns:

        New RuntimeIntent with CTE names and all cross-references updated, or the intent itself when it has no CTE steps.
    """
    cte_steps = intent.cte_steps or []
    if not cte_steps:
        return intent
    old_to_new: dict[str, str] = {cte.cte_name: f"cte{i}" for i, cte in enumerate(cte_steps, start=1)}

    # Matching is case-insensitive, so look names up by their lower-cased form.
    # When two old names differ only by case, the earlier step wins.
    lowered_to_new: dict[str, str] = {}
    for old, new in old_to_new.items():
        lowered_to_new.setdefault(old.lower(), new)
    # Compiled once for all strings; longer names first so alternation never prefers a shorter prefix.
    alternatives = "|".join(re.escape(old) for old in sorted(old_to_new, key=len, reverse=True))
    pattern = re.compile(rf"\b(?:{alternatives})\b", re.IGNORECASE)

    def _canonical(match: re.Match) -> str:
        found = match.group(0)
        # Fall back to the matched text if case folding by the regex engine and str.lower() disagree.
        return lowered_to_new.get(found.lower(), found)

    def replace_cte_refs(s: str) -> str:
        return pattern.sub(_canonical, s)

    def _update_expr(expr: NormalizedExpr) -> NormalizedExpr:
        return replace_refs_in_expr(expr, replace_cte_refs)

    def _update_cols(cols):
        # SelectCol and OrderByCol both keep their expression in ``expr``.
        return [replace(col, expr=_update_expr(col.expr)) for col in (cols or [])]

    def _update_conditions(params):
        return [
            replace(
                cond,
                left_expr=_update_expr(cond.left_expr),
                right_expr=_update_expr(cond.right_expr) if cond.right_expr else None,
            )
            for cond in (params or [])
        ]

    def _update_column_map(column_map):
        return {replace_cte_refs(k): replace_cte_refs(v) for k, v in (column_map or {}).items()}

    new_cte_steps = [
        replace(
            cte,
            cte_name=old_to_new[cte.cte_name],
            tables=[replace_cte_refs(t) for t in (cte.tables or [])],
            select_cols=_update_cols(cte.select_cols),
            group_by_cols=[_update_expr(g) for g in (cte.group_by_cols or [])],
            order_by_cols=_update_cols(cte.order_by_cols),
            filters_param=_update_conditions(cte.filters_param),
            having_param=_update_conditions(cte.having_param),
            column_map=_update_column_map(cte.column_map),
            output_columns=[replace_cte_refs(oc) for oc in (cte.output_columns or [])],
            output_column_metadata={
                replace_cte_refs(k): v for k, v in (cte.output_column_metadata or {}).items()
            },
        )
        for cte in cte_steps
    ]

    return replace(
        intent,
        tables=[replace_cte_refs(t) for t in (intent.tables or [])],
        select_cols=_update_cols(intent.select_cols),
        group_by_cols=[_update_expr(g) for g in (intent.group_by_cols or [])],
        order_by_cols=_update_cols(intent.order_by_cols),
        filters_param=_update_conditions(intent.filters_param),
        having_param=_update_conditions(intent.having_param),
        column_map=_update_column_map(intent.column_map),
        cte_steps=new_cte_steps,
    )
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
def _normalize_expr_ref_for_alias(rendered: str) -> str:
|
|
702
|
+
"""Normalize a rendered expression for alias-map matching.
|
|
703
|
+
|
|
704
|
+
Strips spaces inside comparison/operator tokens and normalizes
|
|
705
|
+
aggregation function casing so small formatting differences still match.
|
|
706
|
+
"""
|
|
707
|
+
s = rendered.strip()
|
|
708
|
+
for op in (" >= ", " <= ", " != ", " = ", " > ", " < ", " + ", " - ", " * ", " / "):
|
|
709
|
+
s = s.replace(op, op.replace(" ", ""))
|
|
710
|
+
s = re.sub(r"\b(count|sum|avg|min|max)\s*\(", r"\1(", s, flags=re.IGNORECASE)
|
|
711
|
+
return s
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def _cte_output_alias_map(intent: RuntimeIntent) -> dict[str, str]:
    """Map CTE-qualified rendered select expressions to CTE-qualified output column names.

    Each CTE select expression is paired positionally with the CTE's output
    columns; select expressions without a matching output column are ignored.
    For every pair, ``cte_name.<rendered expr>`` maps to
    ``cte_name.<output column>`` unless the two strings are already equal. When
    the compacted form of the rendered expression differs from the rendered
    text, that compacted form is registered as an additional key.

    Args:
        intent: RuntimeIntent whose CTE steps are inspected.

    Returns:
        Dictionary from expression-form reference to output-column reference.
    """
    aliases: dict[str, str] = {}
    for step in intent.cte_steps or []:
        # zip stops at the shorter sequence, which skips select expressions
        # that have no output column at the same position.
        for select_col, out_name in zip(step.select_cols or [], step.output_columns or []):
            expr_text = _render_expr_sql(select_col.expr)
            source_ref = f"{step.cte_name}.{expr_text}"
            target_ref = f"{step.cte_name}.{out_name}"
            if source_ref != target_ref:
                aliases[source_ref] = target_ref
            compact_text = _normalize_expr_ref_for_alias(expr_text)
            if compact_text != expr_text:
                aliases[f"{step.cte_name}.{compact_text}"] = target_ref
    return aliases
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def rewrite_cte_output_refs_to_aliases(intent: RuntimeIntent) -> RuntimeIntent:
    """Replace expression-form references to CTE outputs with their output column aliases.

    A reference such as ``cte1.COUNT(table_1.column_1)`` is rewritten to the
    corresponding output column reference (for example ``cte1.count_column_1``)
    wherever it appears in select, group-by, order-by, filter, or having
    expressions of the main query and of every CTE step. References that have
    no entry in the alias map are left as they are.

    Args:
        intent: RuntimeIntent to rewrite.

    Returns:
        The same intent object when there is nothing to rewrite, otherwise a
        new RuntimeIntent with rewritten clauses and CTE steps.
    """
    aliases = _cte_output_alias_map(intent)
    if not aliases:
        return intent

    def _lookup(ref: str) -> str:
        return aliases.get(ref, ref)

    def _rewrite(expr: NormalizedExpr) -> NormalizedExpr:
        return replace_refs_in_expr(expr, _lookup)

    def _rewrite_condition(cond):
        # Left side first, then the optional right side.
        lhs = _rewrite(cond.left_expr)
        rhs = _rewrite(cond.right_expr) if cond.right_expr else None
        return replace(cond, left_expr=lhs, right_expr=rhs)

    def _rewritten_clauses(node) -> dict:
        # Works for both the intent itself and a CTE step: they expose the same
        # clause attributes.
        return {
            "select_cols": [replace(sc, expr=_rewrite(sc.expr)) for sc in (node.select_cols or [])],
            "group_by_cols": [_rewrite(g) for g in (node.group_by_cols or [])],
            "order_by_cols": [replace(obc, expr=_rewrite(obc.expr)) for obc in (node.order_by_cols or [])],
            "filters_param": [_rewrite_condition(fp) for fp in (node.filters_param or [])],
            "having_param": [_rewrite_condition(hp) for hp in (node.having_param or [])],
        }

    rewritten_steps = [replace(step, **_rewritten_clauses(step)) for step in (intent.cte_steps or [])]
    return replace(intent, **_rewritten_clauses(intent), cte_steps=rewritten_steps)
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
def enforce_schema(intent: RuntimeIntent, schema_graph: SchemaGraph) -> tuple[RuntimeIntent, list[str]]:
    """Collect schema violations for the tables and columns referenced by an intent.

    Every table listed on the intent and on each CTE step must be a schema
    table or a CTE name. Qualified column references (``table.column``) found
    in select, order-by, filter, having, and group-by clauses are checked
    against the schema table's columns. References whose qualifier is not a
    schema table (for example a CTE name) and unqualified references are not
    checked.

    Args:
        intent: RuntimeIntent to validate.
        schema_graph: SchemaGraph providing the authoritative table/column set.

    Returns:
        Tuple of (intent, errors); the intent is returned unchanged and errors
        is a list of human-readable violation strings.
    """
    errors: list[str] = []
    known_tables = set(schema_graph.tables.keys())
    cte_names = {step.cte_name for step in (intent.cte_steps or [])}

    def _flag_column(col: str, label: str) -> None:
        # Only "table.column" references to real schema tables can be verified.
        if "." not in col:
            return
        tbl_ref, col_ref = col.split(".", 1)
        if tbl_ref not in known_tables:
            return
        if col_ref not in schema_graph.tables[tbl_ref].columns:
            errors.append(f"Unknown {label} column: {col}")

    def _flag_expr(expr, label: str) -> None:
        for col in extract_columns_from_expr(expr):
            _flag_column(col, label)

    def _check_clauses(node, prefix: str) -> None:
        # Clause order determines the order of reported errors.
        for item in node.select_cols or []:
            _flag_expr(item.expr if hasattr(item, "expr") else item, f"{prefix}select")
        for item in node.order_by_cols or []:
            _flag_expr(item.expr if hasattr(item, "expr") else item, f"{prefix}order_by")
        for clause, name in ((node.filters_param, "filter"), (node.having_param, "having")):
            for cond in clause or []:
                _flag_expr(cond.left_expr, f"{prefix}{name}")
                if cond.right_expr:
                    _flag_expr(cond.right_expr, f"{prefix}{name}")
        for g in node.group_by_cols or []:
            term = g.primary_term if hasattr(g, "primary_term") else str(g)
            _flag_column(term, f"{prefix}group_by")

    for tbl in intent.tables or []:
        if not (tbl in known_tables or tbl in cte_names):
            errors.append(f"Unknown table: {tbl}")
    _check_clauses(intent, "")

    for step in intent.cte_steps or []:
        ctx = f"CTE '{step.cte_name}'"
        for tbl in step.tables or []:
            if not (tbl in known_tables or tbl in cte_names):
                errors.append(f"{ctx} unknown table: {tbl}")
        _check_clauses(step, f"{ctx} ")

    if errors:
        debug(f"[intent_resolve.enforce_schema] validation errors: {errors}")
    return intent, errors
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
def _simplify_expr(expr: NormalizedExpr) -> NormalizedExpr:
    """Return an algebraically simplified copy of a NormalizedExpr.

    Literal values (those without a param_key) and groups that carry only a
    coefficient are folded into one net constant. Remaining groups are merged
    by structural key, with added coefficients counted positively and
    subtracted ones negatively. Merged groups whose coefficient is zero are
    dropped, positive ones become add_groups, and negative ones become
    sub_groups with the absolute coefficient. Parameterized values are passed
    through untouched, and a non-zero net constant is appended as a single
    ExprValue on the add or sub side according to its sign.

    Args:
        expr: NormalizedExpr to simplify.

    Returns:
        New NormalizedExpr in simplified form; the expression-level function
        fields are copied from the input.
    """

    def _is_pure_coefficient(group: MulGroup) -> bool:
        # A group with no operands and no functions is just a number.
        return not (
            group.multiply or group.divide or group.agg_func or group.scalar_func or group.inner_scalar_func
        )

    def _with_coefficient(template: MulGroup, coefficient: float) -> MulGroup:
        # Copy the list fields so the result does not share them with the input.
        return MulGroup(
            coefficient=coefficient,
            multiply=list(template.multiply),
            divide=list(template.divide),
            agg_func=template.agg_func,
            scalar_func=template.scalar_func,
            inner_scalar_func=template.inner_scalar_func,
            scalar_func_args=list(template.scalar_func_args),
            inner_scalar_func_args=list(template.inner_scalar_func_args),
        )

    literal_add: list[ExprValue] = []
    literal_sub: list[ExprValue] = []
    kept_add_values: list[ExprValue] = []
    kept_sub_values: list[ExprValue] = []
    for value in expr.add_values:
        if value.param_key:
            kept_add_values.append(value)
        else:
            literal_add.append(value)
    for value in expr.sub_values:
        if value.param_key:
            kept_sub_values.append(value)
        else:
            literal_sub.append(value)

    constant = sum(v.value for v in literal_add) - sum(v.value for v in literal_sub)

    # Insertion order of `totals` (add side first, then sub side) fixes the
    # order of the output groups; `exemplars` remembers the first group seen
    # for each key.
    totals: dict[str, float] = {}
    exemplars: dict[str, MulGroup] = {}
    for group in expr.add_groups:
        if _is_pure_coefficient(group):
            constant += group.coefficient
            continue
        key = group.structural_key
        totals[key] = totals.get(key, 0.0) + group.coefficient
        exemplars.setdefault(key, group)
    for group in expr.sub_groups:
        if _is_pure_coefficient(group):
            constant -= group.coefficient
            continue
        key = group.structural_key
        totals[key] = totals.get(key, 0.0) - group.coefficient
        exemplars.setdefault(key, group)

    positive_groups: list[MulGroup] = []
    negative_groups: list[MulGroup] = []
    for key, total in totals.items():
        if total == 0.0:
            continue
        if total > 0:
            positive_groups.append(_with_coefficient(exemplars[key], total))
        else:
            negative_groups.append(_with_coefficient(exemplars[key], abs(total)))

    if constant > 0:
        kept_add_values.append(ExprValue(value=constant))
    elif constant < 0:
        kept_sub_values.append(ExprValue(value=abs(constant)))

    return NormalizedExpr(
        add_groups=positive_groups,
        sub_groups=negative_groups,
        add_values=kept_add_values,
        sub_values=kept_sub_values,
        agg_func=expr.agg_func,
        scalar_func=expr.scalar_func,
        inner_scalar_func=expr.inner_scalar_func,
        scalar_func_args=list(expr.scalar_func_args),
        inner_scalar_func_args=list(expr.inner_scalar_func_args),
    )
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
def _simplify_filter(fp: FilterParam) -> FilterParam:
    """Simplify the expressions on both sides of a filter.

    Args:
        fp: FilterParam to simplify.

    Returns:
        Copy of the filter with a simplified left_expr; right_expr is
        simplified when present and set to None otherwise.
    """
    simplified_left = _simplify_expr(fp.left_expr)
    if fp.right_expr:
        return replace(fp, left_expr=simplified_left, right_expr=_simplify_expr(fp.right_expr))
    return replace(fp, left_expr=simplified_left, right_expr=None)
|
|
993
|
+
|
|
994
|
+
|
|
995
|
+
def _simplify_having(hp: HavingParam) -> HavingParam:
    """Simplify the expressions on both sides of a having condition.

    Args:
        hp: HavingParam to simplify.

    Returns:
        Copy of the condition with a simplified left_expr; right_expr is
        simplified when present and set to None otherwise.
    """
    simplified_left = _simplify_expr(hp.left_expr)
    if hp.right_expr:
        return replace(hp, left_expr=simplified_left, right_expr=_simplify_expr(hp.right_expr))
    return replace(hp, left_expr=simplified_left, right_expr=None)
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+
def simplify_exprs(intent: RuntimeIntent) -> RuntimeIntent:
    """Simplify the select, order-by, filter, and having expressions of an intent.

    The same treatment is applied to the main query and to each CTE step.
    Group-by expressions are not touched by this function.

    Args:
        intent: RuntimeIntent whose expressions should be simplified.

    Returns:
        New RuntimeIntent with the simplified clauses and CTE steps.
    """
    debug("[intent_resolve.simplify_exprs] simplifying all expressions")

    def _simplified_clauses(node) -> dict:
        # The intent and its CTE steps expose the same four clause attributes.
        return {
            "select_cols": [replace(sc, expr=_simplify_expr(sc.expr)) for sc in (node.select_cols or [])],
            "order_by_cols": [replace(obc, expr=_simplify_expr(obc.expr)) for obc in (node.order_by_cols or [])],
            "filters_param": [_simplify_filter(fp) for fp in (node.filters_param or [])],
            "having_param": [_simplify_having(hp) for hp in (node.having_param or [])],
        }

    main_clauses = _simplified_clauses(intent)
    simplified_steps: list[RuntimeCteStep] = [
        replace(step, **_simplified_clauses(step)) for step in (intent.cte_steps or [])
    ]
    return replace(intent, **main_clauses, cte_steps=simplified_steps)
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def _normalize_filter_scalar_on_left(fp: FilterParam) -> FilterParam:
    """Move the column side of a filter to the left when only the right side has one.

    A side counts as a column side when it contains at least one qualified
    (``table.column``) reference. If the left side has none and the right side
    has one, the sides are exchanged and the operator is replaced by its
    REVERSE_OP_MAP entry (or kept when the map has no entry for it).

    Args:
        fp: FilterParam to normalize.

    Returns:
        The original filter when no swap is needed, otherwise a new
        FilterParam with the sides exchanged.
    """
    if not fp.right_expr:
        return fp

    left_is_qualified = any("." in col for col in extract_columns_from_expr(fp.left_expr))
    right_is_qualified = any("." in col for col in extract_columns_from_expr(fp.right_expr))
    if left_is_qualified or not right_is_qualified:
        return fp

    return FilterParam(
        left_expr=fp.right_expr,
        op=REVERSE_OP_MAP.get(fp.op, fp.op),
        right_expr=fp.left_expr,
        value_type=fp.value_type,
        param_key=fp.param_key,
        raw_value=fp.raw_value,
        bool_op=fp.bool_op,
        filter_group=fp.filter_group,
    )
|
|
1082
|
+
|
|
1083
|
+
|
|
1084
|
+
def _normalize_filter_canonical(fp: FilterParam) -> FilterParam:
    """Normalize a filter so that a non-empty expression sits on the left.

    The left side is treated as empty when it has neither add_groups nor
    add_values. If it is empty and a right expression exists, the two sides are
    exchanged and the operator is replaced by its REVERSE_OP_MAP entry (or kept
    when the map has no entry for it).

    The swapped filter is produced with ``dataclasses.replace`` so that every
    other field (for example raw_value, bool_op and filter_group) is carried
    over unchanged. Rebuilding the filter from a handful of fields would reset
    those fields to their defaults and change how the filter is grouped and
    deduplicated.

    Args:
        fp: FilterParam to normalize.

    Returns:
        The original filter when no swap is needed, otherwise a copy with the
        sides exchanged and the operator reversed.
    """
    left = fp.left_expr
    if left.add_groups or left.add_values or not fp.right_expr:
        return fp
    return replace(
        fp,
        left_expr=fp.right_expr,
        op=REVERSE_OP_MAP.get(fp.op, fp.op),
        right_expr=left,
    )
|
|
1107
|
+
|
|
1108
|
+
|
|
1109
|
+
def _normalize_having_canonical(hp: HavingParam) -> HavingParam:
    """Normalize a having condition so that a non-empty expression sits on the left.

    The left side is treated as empty when it has neither add_groups nor
    add_values. If it is empty and a right expression exists, the two sides are
    exchanged and the operator is replaced by its REVERSE_OP_MAP entry (or kept
    when the map has no entry for it).

    The swapped condition is produced with ``dataclasses.replace`` so that
    every other field (for example bool_op and filter_group) is carried over
    unchanged instead of being reset to its default.

    Args:
        hp: HavingParam to normalize.

    Returns:
        The original condition when no swap is needed, otherwise a copy with
        the sides exchanged and the operator reversed.
    """
    left = hp.left_expr
    if left.add_groups or left.add_values or not hp.right_expr:
        return hp
    return replace(
        hp,
        left_expr=hp.right_expr,
        op=REVERSE_OP_MAP.get(hp.op, hp.op),
        right_expr=left,
    )
|
|
1130
|
+
|
|
1131
|
+
|
|
1132
|
+
def _normalize_col_to_col_filter(fp: FilterParam) -> FilterParam:
    """Order an expr-vs-expr filter so the smaller signature key is on the left.

    Only filters that have a right expression and no param_key are considered.
    When the left signature key compares greater than the right one, the sides
    are exchanged and the operator is replaced by its REVERSE_OP_MAP entry (or
    kept when the map has no entry for it).

    The swapped filter is produced with ``dataclasses.replace`` so that every
    other field (for example raw_value, bool_op and filter_group) is carried
    over unchanged instead of being reset to its default.

    Args:
        fp: FilterParam to normalize.

    Returns:
        The original filter when no swap is needed, otherwise a copy with the
        sides exchanged and the operator reversed.
    """
    if not fp.right_expr or fp.param_key:
        return fp
    if fp.left_expr.signature_key > fp.right_expr.signature_key:
        return replace(
            fp,
            left_expr=fp.right_expr,
            op=REVERSE_OP_MAP.get(fp.op, fp.op),
            right_expr=fp.left_expr,
        )
    return fp
|
|
1156
|
+
|
|
1157
|
+
|
|
1158
|
+
def _normalize_agg_to_agg_having(hp: HavingParam) -> HavingParam:
    """Order an expr-vs-expr having condition so the smaller signature key is on the left.

    Only conditions that have a right expression and no param_key are
    considered. When the left signature key compares greater than the right
    one, the sides are exchanged and the operator is replaced by its
    REVERSE_OP_MAP entry (or kept when the map has no entry for it).

    The swapped condition is produced with ``dataclasses.replace`` so that
    every other field (for example bool_op and filter_group) is carried over
    unchanged instead of being reset to its default.

    Args:
        hp: HavingParam to normalize.

    Returns:
        The original condition when no swap is needed, otherwise a copy with
        the sides exchanged and the operator reversed.
    """
    if not hp.right_expr or hp.param_key:
        return hp
    if hp.left_expr.signature_key > hp.right_expr.signature_key:
        return replace(
            hp,
            left_expr=hp.right_expr,
            op=REVERSE_OP_MAP.get(hp.op, hp.op),
            right_expr=hp.left_expr,
        )
    return hp
|
|
1182
|
+
|
|
1183
|
+
|
|
1184
|
+
def _normalize_filter(fp: FilterParam) -> FilterParam:
    """Run the full normalization pipeline on one filter.

    The filter passes through the scalar-on-left swap, the canonical-form swap,
    and the expr-vs-expr ordering, in that order; afterwards its operator and
    value type are normalized.

    Args:
        fp: FilterParam to normalize.

    Returns:
        Fully normalized FilterParam.
    """
    # NOTE(review): the side swaps look up REVERSE_OP_MAP with the operator as
    # received, before normalize_op runs - confirm the map covers every
    # spelling that can reach this point.
    pipeline = (
        _normalize_filter_scalar_on_left,
        _normalize_filter_canonical,
        _normalize_col_to_col_filter,
    )
    for step in pipeline:
        fp = step(fp)
    return replace(fp, op=normalize_op(fp.op), value_type=normalize_value_type(fp.value_type))
|
|
1201
|
+
|
|
1202
|
+
|
|
1203
|
+
def _normalize_having(hp: HavingParam) -> HavingParam:
    """Run the full normalization pipeline on one having condition.

    The condition passes through the canonical-form swap and the expr-vs-expr
    ordering, in that order; afterwards its operator and value type are
    normalized.

    Args:
        hp: HavingParam to normalize.

    Returns:
        Fully normalized HavingParam.
    """
    for step in (_normalize_having_canonical, _normalize_agg_to_agg_having):
        hp = step(hp)
    return replace(hp, op=normalize_op(hp.op), value_type=normalize_value_type(hp.value_type))
|
|
1219
|
+
|
|
1220
|
+
|
|
1221
|
+
def _dedup_filters(filters: list[FilterParam]) -> list[FilterParam]:
    """Drop filters that repeat an earlier (signature_key, bool_op, filter_group) triple.

    Args:
        filters: List of FilterParam objects to deduplicate.

    Returns:
        New list keeping the first filter for each distinct triple, in the
        original order.
    """
    unique: list[FilterParam] = []
    seen_keys: set[tuple[str, str, int | None]] = set()
    for candidate in filters:
        key = (candidate.signature_key, candidate.bool_op, candidate.filter_group)
        if key not in seen_keys:
            seen_keys.add(key)
            unique.append(candidate)
        else:
            debug(f"[intent_resolve.dedup_filters] dropping duplicate filter: {key}")
    return unique
|
|
1242
|
+
|
|
1243
|
+
|
|
1244
|
+
def _dedup_having(having: list[HavingParam]) -> list[HavingParam]:
    """Drop having conditions that repeat an earlier (signature_key, bool_op, filter_group) triple.

    Args:
        having: List of HavingParam objects to deduplicate.

    Returns:
        New list keeping the first condition for each distinct triple, in the
        original order.
    """
    unique: list[HavingParam] = []
    seen_keys: set[tuple[str, str, int | None]] = set()
    for candidate in having:
        key = (candidate.signature_key, candidate.bool_op, candidate.filter_group)
        if key not in seen_keys:
            seen_keys.add(key)
            unique.append(candidate)
        else:
            debug(f"[intent_resolve.dedup_having] dropping duplicate having: {key}")
    return unique
|
|
1265
|
+
|
|
1266
|
+
|
|
1267
|
+
def normalize_filters_havings(intent: RuntimeIntent) -> RuntimeIntent:
    """Normalize, sort, and deduplicate the filters and having conditions of an intent.

    The main query and every CTE step receive the same treatment: each
    condition is normalized individually, the list is sorted, and duplicates
    are removed.

    Args:
        intent: RuntimeIntent whose filters and having lists should be normalized.

    Returns:
        New RuntimeIntent with the processed filters, having conditions, and
        CTE steps.
    """
    # Main-query conditions are normalized up front; sorting and deduplication
    # of the main query happen after the CTE steps have been processed.
    main_filters = [_normalize_filter(fp) for fp in (intent.filters_param or [])]
    main_having = [_normalize_having(hp) for hp in (intent.having_param or [])]

    processed_steps = []
    for step in intent.cte_steps or []:
        step_filters = [_normalize_filter(fp) for fp in (step.filters_param or [])]
        step_filters = _dedup_filters(sort_filters(step_filters))
        step_having = [_normalize_having(hp) for hp in (step.having_param or [])]
        step_having = _dedup_having(sort_having(step_having))
        processed_steps.append(replace(step, filters_param=step_filters, having_param=step_having))

    return replace(
        intent,
        filters_param=_dedup_filters(sort_filters(main_filters)),
        having_param=_dedup_having(sort_having(main_having)),
        cte_steps=processed_steps,
    )
|