aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/intent_expr.py
ADDED
|
@@ -0,0 +1,1759 @@
|
|
|
1
|
+
"""Expression parsing, algebraic structure, and intent response parsing.
|
|
2
|
+
|
|
3
|
+
Parses SQL-like expression strings such as SUM(t.col) or t.a + t.b into NormalizedExpr and MulGroup trees via parse_expr_string. Tags expressions as numeric or non-numeric, assigns structural param keys such as s1 and s2, and ensures scalar function defaults. Parses LLM JSON responses into RuntimeIntent via parse_intent_response and handles BETWEEN decomposition, IN-list normalization, and CTE output column and metadata derivation.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from dataclasses import replace
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import jsonschema
|
|
14
|
+
|
|
15
|
+
from .config import (
|
|
16
|
+
CTE_DEFAULT_AGGS,
|
|
17
|
+
CTE_FULL_AGGS,
|
|
18
|
+
CTE_HAVING_COMPARE_OPS,
|
|
19
|
+
CTE_NUMERIC_FILTER_OPS,
|
|
20
|
+
DATE_UNIT_KEYWORDS,
|
|
21
|
+
IN_OPS,
|
|
22
|
+
IN_STRING_SEPARATORS,
|
|
23
|
+
INTEGER_SCALARS,
|
|
24
|
+
INTENT_SCHEMA,
|
|
25
|
+
NUMERIC_RESULT_AGGS,
|
|
26
|
+
NUMERIC_RESULT_SCALARS,
|
|
27
|
+
SCALAR_FUNC_DEFAULTS,
|
|
28
|
+
SCALAR_FUNCTIONS_LEADING_ARG,
|
|
29
|
+
STRUCTURAL_IDENTITY_VALUES,
|
|
30
|
+
VALID_AGG_FUNCS,
|
|
31
|
+
VALID_DATE_DIFF_UNITS,
|
|
32
|
+
VALID_DATE_WINDOW_UNITS,
|
|
33
|
+
)
|
|
34
|
+
from .contracts_base import CteOutputColumnMeta, SchemaGraph
|
|
35
|
+
from .contracts_core import (
|
|
36
|
+
ExprValue,
|
|
37
|
+
FilterParam,
|
|
38
|
+
HavingParam,
|
|
39
|
+
MulGroup,
|
|
40
|
+
NormalizedExpr,
|
|
41
|
+
OrderByCol,
|
|
42
|
+
RuntimeCteStep,
|
|
43
|
+
RuntimeIntent,
|
|
44
|
+
SelectCol,
|
|
45
|
+
Template,
|
|
46
|
+
)
|
|
47
|
+
from .core_utils import debug, normalize_op, safe_json_loads
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _peel_function(s: str) -> tuple[str | None, int | float | str | None, str]:
|
|
51
|
+
"""Strip the outermost FUNC(...) wrapper from an expression string.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
|
|
55
|
+
s: SQL expression string that may be wrapped in a function call.
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
|
|
59
|
+
Tuple of (func_name, trailing_arg, inner_body). func_name is None when no function wrapper is found; trailing_arg is the extracted numeric or string secondary argument if present; inner_body is the unwrapped content.
|
|
60
|
+
"""
|
|
61
|
+
s = s.strip()
|
|
62
|
+
m = re.match(r"^([A-Za-z_]\w*)\s*\(", s)
|
|
63
|
+
if not m:
|
|
64
|
+
return None, None, s
|
|
65
|
+
func_name = m.group(1)
|
|
66
|
+
open_pos = m.end() - 1
|
|
67
|
+
depth = 0
|
|
68
|
+
for i in range(open_pos, len(s)):
|
|
69
|
+
if s[i] == "(":
|
|
70
|
+
depth += 1
|
|
71
|
+
elif s[i] == ")":
|
|
72
|
+
depth -= 1
|
|
73
|
+
if depth == 0:
|
|
74
|
+
if i != len(s) - 1:
|
|
75
|
+
return None, None, s
|
|
76
|
+
inner = s[open_pos + 1 : i]
|
|
77
|
+
if func_name.lower() == "extract" and re.search(r"\bfrom\b", inner, re.IGNORECASE):
|
|
78
|
+
from_parts = re.split(r"\s+from\s+", inner, maxsplit=1, flags=re.IGNORECASE)
|
|
79
|
+
if len(from_parts) == 2:
|
|
80
|
+
return (
|
|
81
|
+
func_name,
|
|
82
|
+
from_parts[0].strip().strip("'\"").lower(),
|
|
83
|
+
from_parts[1].strip(),
|
|
84
|
+
)
|
|
85
|
+
parts: list[str] = []
|
|
86
|
+
d = 0
|
|
87
|
+
start = 0
|
|
88
|
+
for j, c in enumerate(inner):
|
|
89
|
+
if c == "(":
|
|
90
|
+
d += 1
|
|
91
|
+
elif c == ")":
|
|
92
|
+
d -= 1
|
|
93
|
+
elif c == "," and d == 0:
|
|
94
|
+
parts.append(inner[start:j])
|
|
95
|
+
start = j + 1
|
|
96
|
+
parts.append(inner[start:])
|
|
97
|
+
if len(parts) > 1:
|
|
98
|
+
first = parts[0].strip()
|
|
99
|
+
if (first.startswith("'") and first.endswith("'")) or (first.startswith('"') and first.endswith('"')):
|
|
100
|
+
body = ",".join(parts[1:]).strip()
|
|
101
|
+
return func_name, first.strip("'\""), body
|
|
102
|
+
last = parts[-1].strip()
|
|
103
|
+
try:
|
|
104
|
+
num = float(last)
|
|
105
|
+
arg: int | float = int(num) if num == int(num) else num
|
|
106
|
+
body = ",".join(parts[:-1]).strip()
|
|
107
|
+
return func_name, arg, body
|
|
108
|
+
except ValueError:
|
|
109
|
+
pass
|
|
110
|
+
if (last.startswith("'") and last.endswith("'")) or (last.startswith('"') and last.endswith('"')):
|
|
111
|
+
body = ",".join(parts[:-1]).strip()
|
|
112
|
+
return func_name, last.strip("'\""), body
|
|
113
|
+
return func_name, None, inner.strip()
|
|
114
|
+
return None, None, s
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
_TABLE_COLUMN_PATTERN = re.compile(r"\w+\.\w+")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _split_arithmetic_term_into_column_refs(term: str) -> list[str]:
|
|
121
|
+
"""Return table.column refs from a term that may contain arithmetic.
|
|
122
|
+
|
|
123
|
+
If the term contains top-level " - " or " + ", extracts all
|
|
124
|
+
table.column substrings. Otherwise returns [term] when term
|
|
125
|
+
looks like table.column.
|
|
126
|
+
"""
|
|
127
|
+
stripped = term.strip()
|
|
128
|
+
if not stripped or stripped == "*":
|
|
129
|
+
return []
|
|
130
|
+
if " - " in stripped or " + " in stripped:
|
|
131
|
+
return _TABLE_COLUMN_PATTERN.findall(stripped)
|
|
132
|
+
if _TABLE_COLUMN_PATTERN.fullmatch(stripped):
|
|
133
|
+
return [stripped]
|
|
134
|
+
return [stripped]
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
_DATE_INTERVAL_PATTERNS = (
|
|
138
|
+
"current_date",
|
|
139
|
+
"current_timestamp",
|
|
140
|
+
"now()",
|
|
141
|
+
"sysdate",
|
|
142
|
+
"interval",
|
|
143
|
+
"localtimestamp",
|
|
144
|
+
"localtime",
|
|
145
|
+
"utc_timestamp",
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _is_date_or_interval_expr(s: str) -> bool:
|
|
150
|
+
"""Return True if the string is a date or interval expression that should be kept as right_expr."""
|
|
151
|
+
if not s or not isinstance(s, str):
|
|
152
|
+
return False
|
|
153
|
+
lower = s.strip().lower()
|
|
154
|
+
return any(p in lower for p in _DATE_INTERVAL_PATTERNS)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def extract_columns_from_expr(expr: NormalizedExpr) -> list[str]:
    """Extract all column references from a NormalizedExpr, stripping function wrappers.

    For every multiply and divide operand of every add/sub MulGroup, the text
    between the first "(" and the last ")" is peeled repeatedly. The remainder
    is then split into table.column refs by
    _split_arithmetic_term_into_column_refs, so schema validation accepts
    expression columns.

    When peeling cannot continue cleanly, the whole operand is scanned for
    ``table.column`` patterns instead; dotted numeric literals are skipped.
    This covers unbalanced parentheses such as "(t.a", and interleaved calls
    such as "(ABS(t.a) - ABS(t.b))", where a further peel would leave
    "t.a) - ABS(t.b".

    Args:
        expr: NormalizedExpr to traverse.

    Returns:
        De-duplicated list of column reference strings in first-seen order,
        possibly qualified as table.column.
    """
    cols: list[str] = []
    for group in expr.add_groups + expr.sub_groups:
        for term in group.multiply + group.divide:
            inner = term
            peeled_cleanly = True
            while "(" in inner:
                start = inner.index("(")
                end = inner.rfind(")")
                if end <= start:
                    # No ")" after the first "(": peeling further would raise
                    # (rindex) or discard every reference, so scan the whole term.
                    peeled_cleanly = False
                    break
                inner = inner[start + 1 : end]
            if peeled_cleanly:
                refs = _split_arithmetic_term_into_column_refs(inner)
            else:
                refs = []
                for match in _TABLE_COLUMN_PATTERN.findall(term):
                    try:
                        float(match)  # e.g. "1.5" is a literal, not a column
                    except ValueError:
                        refs.append(match)
            for ref in refs:
                if ref and ref not in cols:
                    cols.append(ref)
    return cols
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def replace_refs_in_expr(expr: NormalizedExpr, replacer: Callable[[str], str]) -> NormalizedExpr:
    """Apply a string replacer to all column reference terms in a NormalizedExpr.

    Only the multiply/divide operand lists of each add/sub MulGroup are
    rewritten. Every other field of the groups and of the expression is carried
    over via ``dataclasses.replace``. Rebuilding the objects from a hand-picked
    field list would silently reset fields such as ``coeff_param_key`` and
    ``is_numeric`` to their defaults. Scalar-function argument lists are copied
    so the result does not alias the input's lists.

    Args:
        expr: NormalizedExpr whose MulGroup terms should be transformed.
        replacer: Function that takes a term string and returns the replacement.

    Returns:
        New NormalizedExpr with all multiply and divide terms transformed; the
        input is not modified.
    """

    def _replace_in_group(g: MulGroup) -> MulGroup:
        return replace(
            g,
            multiply=[replacer(m) for m in g.multiply],
            divide=[replacer(d) for d in g.divide],
            scalar_func_args=list(g.scalar_func_args),
            inner_scalar_func_args=list(g.inner_scalar_func_args),
        )

    return replace(
        expr,
        add_groups=[_replace_in_group(g) for g in expr.add_groups],
        sub_groups=[_replace_in_group(g) for g in expr.sub_groups],
        scalar_func_args=list(expr.scalar_func_args),
        inner_scalar_func_args=list(expr.inner_scalar_func_args),
    )
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _split_additive(s: str) -> list[tuple[str, str]]:
|
|
226
|
+
"""Split an expression string by top-level + and - operators.
|
|
227
|
+
|
|
228
|
+
Respects parenthesis nesting so operators inside function calls are ignored.
|
|
229
|
+
|
|
230
|
+
Args:
|
|
231
|
+
|
|
232
|
+
s: SQL expression string to split.
|
|
233
|
+
|
|
234
|
+
Returns:
|
|
235
|
+
|
|
236
|
+
List of (sign, term) tuples where sign is '+' or '-'.
|
|
237
|
+
"""
|
|
238
|
+
s = s.strip()
|
|
239
|
+
if not s:
|
|
240
|
+
return [("+", "")]
|
|
241
|
+
result: list[tuple[str, str]] = []
|
|
242
|
+
depth = 0
|
|
243
|
+
sign = "+"
|
|
244
|
+
start = 0
|
|
245
|
+
i = 0
|
|
246
|
+
if s[0] in ("+", "-"):
|
|
247
|
+
sign = s[0]
|
|
248
|
+
start = 1
|
|
249
|
+
i = 1
|
|
250
|
+
while i < len(s):
|
|
251
|
+
c = s[i]
|
|
252
|
+
if c == "(":
|
|
253
|
+
depth += 1
|
|
254
|
+
elif c == ")":
|
|
255
|
+
depth -= 1
|
|
256
|
+
elif depth == 0 and c in ("+", "-"):
|
|
257
|
+
term = s[start:i].strip()
|
|
258
|
+
if term:
|
|
259
|
+
result.append((sign, term))
|
|
260
|
+
sign = c
|
|
261
|
+
start = i + 1
|
|
262
|
+
i += 1
|
|
263
|
+
term = s[start:].strip()
|
|
264
|
+
if term:
|
|
265
|
+
result.append((sign, term))
|
|
266
|
+
return result or [("+", s)]
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _split_multiplicative(s: str) -> tuple[list[str], list[str]]:
|
|
270
|
+
"""Split an expression string by top-level * and / operators.
|
|
271
|
+
|
|
272
|
+
Respects parenthesis nesting so operators inside function calls are ignored.
|
|
273
|
+
|
|
274
|
+
Args:
|
|
275
|
+
|
|
276
|
+
s: SQL expression string to split.
|
|
277
|
+
|
|
278
|
+
Returns:
|
|
279
|
+
|
|
280
|
+
Tuple of (multiply_parts, divide_parts) where each element is a list of operand strings on either side of the respective operator.
|
|
281
|
+
"""
|
|
282
|
+
s = s.strip()
|
|
283
|
+
if not s:
|
|
284
|
+
return ([""], [])
|
|
285
|
+
multiply: list[str] = []
|
|
286
|
+
divide: list[str] = []
|
|
287
|
+
depth = 0
|
|
288
|
+
op = "*"
|
|
289
|
+
start = 0
|
|
290
|
+
for i, c in enumerate(s):
|
|
291
|
+
if c == "(":
|
|
292
|
+
depth += 1
|
|
293
|
+
elif c == ")":
|
|
294
|
+
depth -= 1
|
|
295
|
+
elif depth == 0 and c in ("*", "/"):
|
|
296
|
+
part = s[start:i].strip()
|
|
297
|
+
if part:
|
|
298
|
+
(multiply if op == "*" else divide).append(part)
|
|
299
|
+
op = c
|
|
300
|
+
start = i + 1
|
|
301
|
+
part = s[start:].strip()
|
|
302
|
+
if part:
|
|
303
|
+
(multiply if op == "*" else divide).append(part)
|
|
304
|
+
return multiply, divide
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _parse_term(term_str: str) -> MulGroup | ExprValue:
    """Parse a single additive term into a MulGroup or a literal ExprValue.

    Function wrappers are peeled from the outside in, and each one is assigned
    as follows:

    * an aggregate name goes to ``agg_func`` (a later aggregate overwrites an
      earlier one);
    * a non-aggregate seen while neither ``scalar_func`` nor ``agg_func`` is
      set goes to ``scalar_func``;
    * any other non-aggregate goes to ``inner_scalar_func``.

    The remaining body is split on top-level ``*`` and ``/``. Numeric operands
    are folded into ``coefficient``, and the rest become ``multiply`` or
    ``divide`` operands. A literal zero divisor is kept as a ``divide`` operand
    rather than folded, because folding it would raise ZeroDivisionError.

    Args:
        term_str: A single additive term string with no top-level plus or minus operator.

    Returns:
        ExprValue if the term is a plain numeric literal, otherwise a MulGroup
        describing the function layers, multiplicative operands, and
        coefficient. ``multiply`` falls back to ``["*"]`` when no non-numeric
        operand remains.
    """
    term_str = term_str.strip()
    try:
        val = float(term_str)
        return ExprValue(value=val)
    except ValueError:
        pass
    layers: list[tuple[str, int | float | None]] = []
    current = term_str
    while True:
        name, args, body = _peel_function(current)
        if not name:
            break
        layers.append((name.lower(), args))
        current = body
    agg_func = ""
    scalar_func = ""
    scalar_func_args: list | None = None
    inner_scalar_func = ""
    inner_scalar_func_args: list | None = None
    for name, args in layers:
        if name in VALID_AGG_FUNCS:
            agg_func = name
        elif not scalar_func and not agg_func:
            scalar_func = name
            scalar_func_args = [args] if args is not None else []
        else:
            inner_scalar_func = name
            inner_scalar_func_args = [args] if args is not None else []
    mul_parts, div_parts = _split_multiplicative(current)
    multiply: list[str] = []
    divide: list[str] = []
    coefficient = 1.0
    for p in mul_parts:
        try:
            coefficient *= float(p)
        except ValueError:
            multiply.append(p)
    for p in div_parts:
        try:
            divisor = float(p)
        except ValueError:
            divide.append(p)
            continue
        if divisor == 0.0:
            # Keep "x / 0" symbolic instead of crashing the parser; whatever
            # consumes the MulGroup decides how to treat the literal operand.
            divide.append(p)
        else:
            coefficient /= divisor
    return MulGroup(
        multiply=multiply or ["*"],
        divide=divide,
        coefficient=coefficient,
        agg_func=agg_func,
        scalar_func=scalar_func,
        inner_scalar_func=inner_scalar_func,
        scalar_func_args=scalar_func_args or [],
        inner_scalar_func_args=inner_scalar_func_args or [],
    )
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def _is_trivial_zero(term: str) -> bool:
|
|
373
|
+
"""Return True if the term string is a numeric literal equal to zero.
|
|
374
|
+
|
|
375
|
+
Args:
|
|
376
|
+
|
|
377
|
+
term: Additive term string.
|
|
378
|
+
|
|
379
|
+
Returns:
|
|
380
|
+
|
|
381
|
+
True when float(term) is 0.0 and False for non-numeric or non-zero strings.
|
|
382
|
+
"""
|
|
383
|
+
try:
|
|
384
|
+
return float(term.strip()) == 0.0
|
|
385
|
+
except ValueError:
|
|
386
|
+
return False
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def parse_expr_string(expr_str: str | dict) -> NormalizedExpr:
    """Parse an SQL expression string into a structured NormalizedExpr.

    Parsing proceeds in four steps:

    1. Outer function wrappers are peeled first (e.g. ``ROUND(SUM(x), 2)``).
    2. The remainder is split into additive terms; literal-zero terms are
       dropped.
    3. Each term is parsed by _parse_term.
    4. The peeled layers are attached. With a single term they go onto that
       term's MulGroup; with several terms they go onto the NormalizedExpr
       itself. In both cases an aggregate name goes to ``agg_func``, the first
       non-aggregate name to ``scalar_func``, and any further non-aggregate to
       ``inner_scalar_func``.

    ``COUNT`` over a bare numeric literal (``COUNT(1)``, ``COUNT(0)``) is parsed
    exactly like ``COUNT(*)``. This matches how _parse_term already represents
    such a term inside a larger expression.

    Args:
        expr_str: SQL expression string as returned by the LLM, or a dict with an "expr" key.

    Returns:
        NormalizedExpr representing the expression tree; an empty NormalizedExpr
        for empty input.
    """
    if isinstance(expr_str, dict):
        expr_str = expr_str.get("expr", "")
    if not isinstance(expr_str, str):
        expr_str = str(expr_str) if expr_str else ""
    expr_str = expr_str.strip()
    if not expr_str:
        return NormalizedExpr()
    outer_layers: list[tuple[str, int | float | str | None]] = []
    inner = expr_str
    while True:
        name, args, body = _peel_function(inner)
        if not name:
            break
        outer_layers.append((name.lower(), args))
        inner = body
    additive = [(s, t) for s, t in _split_additive(inner) if not _is_trivial_zero(t)]
    if not additive:
        additive = [("+", "0")]
    if len(additive) == 1:
        sign, t = additive[0]
        parsed = _parse_term(t)
        if isinstance(parsed, ExprValue) and any(layer == "count" for layer, _ in outer_layers):
            # COUNT(<numeric literal>) counts rows exactly like COUNT(*): the
            # literal is never NULL, and its sign and value are irrelevant.
            # Without this, the literal branch below would drop the COUNT layer
            # and return a bare constant.
            parsed = _parse_term("*")
            sign = "+"
        if isinstance(parsed, ExprValue):
            # NOTE(review): any other outer layers (e.g. ROUND(3.14159, 2)) are
            # not recorded for a literal-only expression — confirm this is intended.
            if sign == "-":
                return NormalizedExpr(sub_values=[ExprValue(value=abs(parsed.value))])
            return NormalizedExpr(add_values=[parsed])
        group = parsed
        for name, args in outer_layers:
            if name in VALID_AGG_FUNCS:
                group.agg_func = name
            elif not group.scalar_func:
                # NOTE(review): ROUND(SUM(x)) and SUM(ROUND(x)) both end with
                # scalar_func="round", whereas _parse_term files a scalar found
                # inside an aggregate under inner_scalar_func. Confirm how the
                # SQL generator expects nesting order to be represented.
                group.scalar_func = name
                group.scalar_func_args = [args] if args is not None else []
            else:
                group.inner_scalar_func = name
                group.inner_scalar_func_args = [args] if args is not None else []
        if sign == "-":
            return NormalizedExpr(sub_groups=[group])
        return NormalizedExpr(add_groups=[group])
    add_groups: list[MulGroup] = []
    sub_groups: list[MulGroup] = []
    add_values: list[ExprValue] = []
    sub_values: list[ExprValue] = []
    for sign, t in additive:
        parsed = _parse_term(t)
        if isinstance(parsed, ExprValue):
            (add_values if sign == "+" else sub_values).append(parsed)
        else:
            (add_groups if sign == "+" else sub_groups).append(parsed)
    expr_agg = ""
    expr_scalar = ""
    expr_scalar_args: list | None = None
    expr_inner = ""
    expr_inner_args: list | None = None
    for name, args in outer_layers:
        if name in VALID_AGG_FUNCS:
            expr_agg = name
        elif not expr_scalar:
            expr_scalar = name
            expr_scalar_args = [args] if args is not None else []
        else:
            expr_inner = name
            expr_inner_args = [args] if args is not None else []
    return NormalizedExpr(
        add_groups=add_groups,
        sub_groups=sub_groups,
        add_values=add_values,
        sub_values=sub_values,
        agg_func=expr_agg,
        scalar_func=expr_scalar,
        inner_scalar_func=expr_inner,
        scalar_func_args=expr_scalar_args or [],
        inner_scalar_func_args=expr_inner_args or [],
    )
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def _is_expr_numeric(expr: NormalizedExpr, schema: SchemaGraph) -> bool:
    """Determine whether an expression produces a numeric result.

    Checks are applied in order:

    1. Aggregation functions known to return numbers, and scalar functions
       that return numeric types, on the expression or on any of its groups.
    2. Pure literal expressions (values but no groups), which are numeric.
    3. The value_type of the underlying column in the schema.
    4. Otherwise, whether the expression combines more than one group.

    Args:
        expr: NormalizedExpr to classify.
        schema: SchemaGraph used to look up column value_type metadata.

    Returns:
        True when the expression is expected to yield a numeric value.
    """
    if expr.agg_func and expr.agg_func in NUMERIC_RESULT_AGGS:
        return True
    if expr.scalar_func and expr.scalar_func in NUMERIC_RESULT_SCALARS:
        return True
    if expr.inner_scalar_func and expr.inner_scalar_func in NUMERIC_RESULT_SCALARS:
        return True
    groups = expr.add_groups + expr.sub_groups
    for g in groups:
        if g.agg_func and g.agg_func in NUMERIC_RESULT_AGGS:
            return True
        if g.scalar_func and g.scalar_func in NUMERIC_RESULT_SCALARS:
            return True
        if g.inner_scalar_func and g.inner_scalar_func in NUMERIC_RESULT_SCALARS:
            return True
    if not groups and (expr.add_values or expr.sub_values):
        # A literal-only expression (e.g. "100") has no column to inspect.
        # Reporting it as non-numeric would make _tag_single_expr strip its
        # values and leave an empty expression. ExprValues built in this module
        # come from float(...).
        return True
    col = expr.primary_column
    if col and "." in col:
        table, col_name = col.rsplit(".", 1)
        if table in schema.tables:
            meta = schema.tables[table].columns.get(col_name) or schema.tables[table].columns.get(col_name.lower())
            if meta and meta.value_type:
                return meta.value_type in ("integer", "number")
    return len(groups) > 1
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def _tag_single_expr(expr: NormalizedExpr, schema: SchemaGraph, skip_value_injection: bool = False) -> NormalizedExpr:
    """Tag an expression as numeric or non-numeric and normalise it to match.

    Numeric expressions are marked ``is_numeric=True``. Unless
    ``skip_value_injection`` is set, an ``ExprValue(0.0)`` offset is also added
    when the expression has additive groups but no ``add_values`` yet; this
    anchors the signature.

    Non-numeric expressions are marked ``is_numeric=False``. Every add/sub
    group's coefficient is reset to 1.0 with an empty ``coeff_param_key``, and
    both value lists are cleared.

    Args:
        expr: NormalizedExpr to tag.
        schema: SchemaGraph used to determine numeric classification.
        skip_value_injection: When True, numeric expressions are only tagged and
            no offset is added. Used for filter/having left_expr, where values
            are carried by param_key instead.

    Returns:
        A new NormalizedExpr with is_numeric set and the matching normalisation
        applied.
    """

    def _without_coefficient(group):
        # Leave untouched groups as the same object; only rebuild when needed.
        if group.coefficient != 1.0 or group.coeff_param_key:
            return replace(group, coefficient=1.0, coeff_param_key="")
        return group

    if not _is_expr_numeric(expr, schema):
        return replace(
            expr,
            is_numeric=False,
            add_groups=[_without_coefficient(g) for g in expr.add_groups],
            sub_groups=[_without_coefficient(g) for g in expr.sub_groups],
            add_values=[],
            sub_values=[],
        )
    if skip_value_injection:
        return replace(expr, is_numeric=True)
    needs_anchor = expr.add_groups and not expr.add_values
    offsets = [ExprValue(value=0.0)] if needs_anchor else expr.add_values
    return replace(expr, is_numeric=True, add_values=offsets)
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def tag_expr_numeric(intent: RuntimeIntent, schema: SchemaGraph) -> RuntimeIntent:
    """Run _tag_single_expr over every expression held by a RuntimeIntent.

    The following are tagged, both on the intent itself and on each CTE step:

    * select and order-by expressions;
    * group-by expressions;
    * filter and having predicates. Their left side is tagged with
      ``skip_value_injection=True``; their right side is tagged normally when
      present and set to None otherwise.

    Missing (None) lists come back as empty lists.

    Args:
        intent: RuntimeIntent whose select, order_by, group_by, filter, having, and CTE expressions should be tagged.
        schema: SchemaGraph used for column type lookups.

    Returns:
        New RuntimeIntent with all NormalizedExprs tagged and sanitized.
    """

    def tag_cols(cols):
        # SelectCol / OrderByCol both expose the expression as `.expr`.
        return [replace(col, expr=_tag_single_expr(col.expr, schema)) for col in (cols or [])]

    def tag_exprs(exprs):
        return [_tag_single_expr(e, schema) for e in (exprs or [])]

    def tag_predicates(params):
        tagged = []
        for param in params or []:
            left = _tag_single_expr(param.left_expr, schema, skip_value_injection=True)
            right = _tag_single_expr(param.right_expr, schema) if param.right_expr else None
            tagged.append(replace(param, left_expr=left, right_expr=right))
        return tagged

    return replace(
        intent,
        select_cols=tag_cols(intent.select_cols),
        order_by_cols=tag_cols(intent.order_by_cols),
        group_by_cols=tag_exprs(intent.group_by_cols),
        filters_param=tag_predicates(intent.filters_param),
        having_param=tag_predicates(intent.having_param),
        cte_steps=[
            replace(
                step,
                select_cols=tag_cols(step.select_cols),
                order_by_cols=tag_cols(step.order_by_cols),
                group_by_cols=tag_exprs(step.group_by_cols),
                filters_param=tag_predicates(step.filters_param),
                having_param=tag_predicates(step.having_param),
            )
            for step in intent.cte_steps or []
        ],
    )
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def _classify_cte_expr(expr: NormalizedExpr) -> str:
|
|
626
|
+
"""Classify a CTE select expression by its structural type.
|
|
627
|
+
|
|
628
|
+
Args:
|
|
629
|
+
|
|
630
|
+
expr: NormalizedExpr from a CTE SelectCol.
|
|
631
|
+
|
|
632
|
+
Returns:
|
|
633
|
+
|
|
634
|
+
One of 'passthrough', 'aggregation', 'scalar', or 'computed'.
|
|
635
|
+
"""
|
|
636
|
+
agg = expr.agg_func or (expr.add_groups[0].agg_func if expr.add_groups else "")
|
|
637
|
+
has_arithmetic = (
|
|
638
|
+
len(expr.add_groups) + len(expr.sub_groups) > 1
|
|
639
|
+
or expr.add_values
|
|
640
|
+
or expr.sub_values
|
|
641
|
+
or any(g.divide or len(g.multiply) > 1 or g.coefficient != 1.0 for g in expr.add_groups)
|
|
642
|
+
)
|
|
643
|
+
if agg:
|
|
644
|
+
return "aggregation"
|
|
645
|
+
if has_arithmetic:
|
|
646
|
+
return "computed"
|
|
647
|
+
scalar = expr.scalar_func or (expr.add_groups[0].scalar_func if expr.add_groups else "")
|
|
648
|
+
if scalar:
|
|
649
|
+
return "scalar"
|
|
650
|
+
return "passthrough"
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def derive_cte_output_columns(select_cols: list[SelectCol]) -> list[str]:
    """Derive deterministic CTE output column aliases from parsed select expressions.

    Names depend on the kind reported by _classify_cte_expr:

    * aggregation: ``<agg>_<column>``; ``row_count`` for COUNT(*), and
      ``<agg>_star`` for other aggregates over ``*``.
    * scalar: ``<scalar>_<column>``.
    * computed: ``expr1``, ``expr2``, ... in order of appearance.
    * passthrough: the bare column name.

    Table qualifiers are dropped and names are lower-cased. A name that is
    already taken gets a numeric suffix (``_2``, ``_3``, ...). Suffixed names
    are reserved as well, so the returned list never contains duplicates, even
    when a later column is literally named e.g. ``amount_2``.

    Args:
        select_cols: List of SelectCol objects from a CTE step.

    Returns:
        List of unique lowercase alias strings, one per SelectCol.
    """

    def _bare(column: str) -> str:
        # Drop a leading "table." qualifier, if any.
        return column.split(".", 1)[1] if "." in column else column

    derived: list[str] = []
    used: set[str] = set()
    suffix_counts: dict[str, int] = {}
    expr_counter = 0
    for sc in select_cols:
        expr = sc.expr
        kind = _classify_cte_expr(expr)
        agg = expr.agg_func or (expr.add_groups[0].agg_func if expr.add_groups else "")
        scalar = expr.scalar_func or (expr.add_groups[0].scalar_func if expr.add_groups else "")
        base = expr.primary_column
        if kind == "aggregation":
            if base == "*":
                name = "row_count" if agg == "count" else f"{agg}_star"
            else:
                name = f"{agg}_{_bare(base)}"
        elif kind == "scalar":
            name = f"{scalar}_{_bare(base)}"
        elif kind == "computed":
            expr_counter += 1
            name = f"expr{expr_counter}"
        else:
            name = _bare(base)
        name = name.lower()
        candidate = name
        while candidate in used:
            suffix_counts[name] = suffix_counts.get(name, 1) + 1
            candidate = f"{name}_{suffix_counts[name]}"
        used.add(candidate)
        derived.append(candidate)
    return derived
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
def build_cte_output_metadata(
    select_cols: list[SelectCol], output_columns: list[str], schema: SchemaGraph
) -> dict[str, CteOutputColumnMeta]:
    """Build CteOutputColumnMeta for each CTE output column.

    Infers column role, data type, and allowed operations from the expression
    kind (passthrough, aggregation, scalar, or computed). When the base column
    is a qualified "table.column" present in ``schema``, that source column's
    metadata is used as well.

    select_cols and output_columns are paired positionally; surplus entries in
    either list are ignored.

    Args:
        select_cols: CTE select expressions.
        output_columns: Derived alias names aligned with select_cols.
        schema: SchemaGraph for source column lookups.

    Returns:
        Dictionary mapping output column alias to its CteOutputColumnMeta.
    """
    result: dict[str, CteOutputColumnMeta] = {}
    for sc, out_col in zip(select_cols, output_columns):
        expr = sc.expr
        kind = _classify_cte_expr(expr)
        first_group = expr.add_groups[0] if expr.add_groups else None
        agg = expr.agg_func or (first_group.agg_func if first_group is not None else "")
        scalar = expr.scalar_func or (first_group.scalar_func if first_group is not None else "")
        base_col = expr.primary_column

        # Look up the source column (exact name first, then lowercased).
        src_meta = None
        base_type = "unknown"
        if "." in base_col and base_col != "*":
            tbl, col = base_col.split(".", 1)
            if tbl in schema.tables:
                columns = schema.tables[tbl].columns
                src_meta = columns.get(col) or columns.get(col.lower())
        if src_meta and src_meta.data_type:
            # Drop any length/precision suffix: "numeric(10,2)" -> "numeric".
            base_type = src_meta.data_type.lower().split("(")[0].strip()

        if kind == "passthrough":
            # A passthrough column inherits everything from its source column.
            role = src_meta.role if src_meta else None
            data_type = base_type
            filterable = src_meta.is_filterable if src_meta else True
            aggregatable = src_meta.is_aggregatable if src_meta else False
            groupable = src_meta.is_groupable if src_meta else True
            vf_ops = list(src_meta.get_valid_filter_ops()) if src_meta else list(CTE_NUMERIC_FILTER_OPS)
            v_aggs = sorted(src_meta.get_valid_aggregations()) if src_meta else list(CTE_DEFAULT_AGGS)
            vh_ops = list(src_meta.get_valid_having_ops()) if src_meta else list(CTE_HAVING_COMPARE_OPS)
        elif kind == "aggregation":
            role = "numeric_measure"
            if agg == "count":
                data_type = "integer"
            elif agg in ("sum", "min", "max") and base_type != "unknown":
                # SUM/MIN/MAX keep the source column's type when it is known.
                data_type = base_type
            else:
                # AVG, and any aggregate over a column of unknown type.
                data_type = "numeric"
            filterable = True
            aggregatable = True
            groupable = False
            vf_ops = list(CTE_NUMERIC_FILTER_OPS)
            v_aggs = list(CTE_FULL_AGGS)
            vh_ops = list(CTE_HAVING_COMPARE_OPS)
        elif kind == "scalar":
            if scalar in NUMERIC_RESULT_SCALARS:
                role = "numeric_measure"
                data_type = "integer" if scalar in INTEGER_SCALARS else "numeric"
                aggregatable = True
                groupable = False
                vf_ops = list(CTE_NUMERIC_FILTER_OPS)
                v_aggs = list(CTE_FULL_AGGS)
                vh_ops = list(CTE_HAVING_COMPARE_OPS)
            else:
                # Non-numeric scalar results keep the source column's capabilities.
                role = src_meta.role if src_meta else None
                data_type = base_type
                aggregatable = src_meta.is_aggregatable if src_meta else False
                groupable = src_meta.is_groupable if src_meta else True
                vf_ops = list(src_meta.get_valid_filter_ops()) if src_meta else list(CTE_NUMERIC_FILTER_OPS)
                v_aggs = sorted(src_meta.get_valid_aggregations()) if src_meta else list(CTE_DEFAULT_AGGS)
                vh_ops = list(src_meta.get_valid_having_ops()) if src_meta else list(CTE_HAVING_COMPARE_OPS)
            # Set for both scalar sub-branches so it is never left unbound.
            filterable = True
        else:
            # Computed (arithmetic) expressions are treated as generic numeric measures.
            role = "numeric_measure"
            data_type = "numeric"
            filterable = True
            aggregatable = True
            groupable = False
            vf_ops = list(CTE_NUMERIC_FILTER_OPS)
            v_aggs = list(CTE_FULL_AGGS)
            vh_ops = list(CTE_HAVING_COMPARE_OPS)

        result[out_col] = CteOutputColumnMeta(
            source=kind,
            base_column=base_col,
            agg_func=agg,
            role=role,
            filterable=filterable,
            aggregatable=aggregatable,
            data_type=data_type,
            groupable=groupable,
            valid_filter_ops=vf_ops,
            valid_aggregations=v_aggs,
            valid_having_ops=vh_ops,
        )
    return result
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
def _strip_angle_brackets(obj: Any) -> Any:
|
|
811
|
+
"""Recursively strip angle-bracket placeholders from all string values in a parsed LLM output.
|
|
812
|
+
|
|
813
|
+
Args:
|
|
814
|
+
|
|
815
|
+
obj: Parsed JSON value such as a dict, list, string, or other type.
|
|
816
|
+
|
|
817
|
+
Returns:
|
|
818
|
+
|
|
819
|
+
The same structure with <word> placeholders replaced by word.
|
|
820
|
+
"""
|
|
821
|
+
if isinstance(obj, str):
|
|
822
|
+
return re.sub(r"<(\w+)>", r"\1", obj)
|
|
823
|
+
if isinstance(obj, list):
|
|
824
|
+
return [_strip_angle_brackets(item) for item in obj]
|
|
825
|
+
if isinstance(obj, dict):
|
|
826
|
+
return {k: _strip_angle_brackets(v) for k, v in obj.items()}
|
|
827
|
+
return obj
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
def _normalize_order_direction(direction: str) -> str:
|
|
831
|
+
"""Return "asc" or "desc" from a direction string."""
|
|
832
|
+
if not isinstance(direction, str):
|
|
833
|
+
return "asc"
|
|
834
|
+
d = direction.strip().lower()
|
|
835
|
+
return "desc" if "desc" in d else "asc"
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
def _strip_order_direction(expr_str: str) -> tuple[str, str]:
|
|
839
|
+
"""Strip trailing ASC/DESC from an order-by expression and return direction.
|
|
840
|
+
|
|
841
|
+
Returns (expr_without_direction, "asc" or "desc"). Default direction is asc.
|
|
842
|
+
"""
|
|
843
|
+
s = expr_str.strip()
|
|
844
|
+
if not s:
|
|
845
|
+
return ("", "asc")
|
|
846
|
+
upper = s.upper()
|
|
847
|
+
if upper.endswith(" DESC"):
|
|
848
|
+
return (s[: -5].strip(), "desc")
|
|
849
|
+
if upper.endswith(" ASC"):
|
|
850
|
+
return (s[: -4].strip(), "asc")
|
|
851
|
+
return (s, "asc")
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
def _validate_intent_schema(parsed: dict[str, Any]) -> bool:
    """Check a parsed LLM intent dict against INTENT_SCHEMA.

    Only jsonschema.ValidationError is treated as "invalid". Any other error,
    such as a malformed INTENT_SCHEMA, propagates to the caller.

    Args:
        parsed: Parsed JSON dict from an LLM response.

    Returns:
        True when the instance conforms to the schema, False otherwise.
    """
    try:
        jsonschema.validate(instance=parsed, schema=INTENT_SCHEMA)
    except jsonschema.ValidationError:
        return False
    return True
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
def _coerce_intent_list(value: Any) -> list[Any]:
    """Normalize an LLM list field: None -> [], list/tuple -> list, any single value -> [value]."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _coerce_intent_limit(raw: Any) -> Any:
    """Convert a string limit such as "10" to int (None if not an integer); other values pass through."""
    if isinstance(raw, str):
        try:
            return int(raw)
        except ValueError:
            return None
    return raw


def _coerce_filter_group(raw: Any) -> int | None:
    """Convert an LLM filter_group value to int, or None when absent or not integer-like."""
    if raw is None:
        return None
    try:
        return int(raw)
    except (TypeError, ValueError):
        return None


def _parse_intent_select_cols(raw_items: Any) -> tuple[list[SelectCol], list[str | None]]:
    """Parse select_cols entries (strings or {"expr", "alias"} dicts) into SelectCols plus aligned aliases."""
    select_cols: list[SelectCol] = []
    aliases: list[str | None] = []
    for item in _coerce_intent_list(raw_items):
        if isinstance(item, str):
            select_cols.append(SelectCol(expr=parse_expr_string(item)))
            aliases.append(None)
        elif isinstance(item, dict):
            select_cols.append(SelectCol(expr=parse_expr_string(item.get("expr", ""))))
            aliases.append(item.get("alias") or None)
    return select_cols, aliases


def _parse_intent_order_by(raw_items: Any) -> list[OrderByCol]:
    """Parse order_by_cols entries ("expr [ASC|DESC]" strings or {"expr", "direction"} dicts)."""
    order_by: list[OrderByCol] = []
    for item in _coerce_intent_list(raw_items):
        if isinstance(item, str):
            expr_text, direction = _strip_order_direction(item)
        elif isinstance(item, dict):
            expr_text, dir_from_expr = _strip_order_direction(str(item.get("expr") or ""))
            # An explicit "direction" field wins over a keyword embedded in the expression.
            direction = _normalize_order_direction(item.get("direction") or dir_from_expr or "asc")
        else:
            continue
        order_by.append(OrderByCol(expr=parse_expr_string(expr_text), direction=direction))
    return order_by


def _parse_intent_predicates(
    raw_items: Any,
    param_cls: type,
    legacy_suffix: str,
    default_op: str,
    default_value_type: str,
) -> list[Any]:
    """Parse filter/having dicts into ``param_cls`` instances (FilterParam or HavingParam).

    ``legacy_suffix`` ("col" or "agg") names the fallback keys left_<suffix>
    and right_<suffix> accepted in place of left_expr/right_expr. Entries
    without a left operand are skipped.
    """
    params: list[Any] = []
    for item in _coerce_intent_list(raw_items):
        if not isinstance(item, dict):
            continue
        left_str = item.get("left_expr") or item.get(f"left_{legacy_suffix}") or ""
        if not left_str:
            continue
        right_str = item.get("right_expr") or item.get(f"right_{legacy_suffix}") or ""
        if not isinstance(right_str, str):
            right_str = ""
        # Keep a right-hand operand only if it is table-qualified or a date/interval expression;
        # presumably anything else is a literal that belongs in "value".
        if right_str and "." not in right_str and not _is_date_or_interval_expr(right_str):
            right_str = ""
        params.append(
            param_cls(
                left_expr=parse_expr_string(left_str),
                op=normalize_op(item.get("op", default_op)),
                right_expr=parse_expr_string(right_str) if right_str else None,
                value_type=item.get("value_type", default_value_type),
                param_key="",
                raw_value=item.get("value"),
                bool_op=item.get("bool_op", "AND"),
                filter_group=_coerce_filter_group(item.get("filter_group")),
            )
        )
    return params


def _resolve_cte_output_columns(
    select_cols: list[SelectCol],
    explicit_aliases: list[str | None],
    raw_output_columns: list[Any],
) -> list[str]:
    """Choose one output alias per CTE select column.

    Precedence: the select column's own alias, then the positional
    output_columns entry, then a derived name. Names are derived for the whole
    step in one call so generated names (expr1, expr2, sum_x_2, ...) do not
    repeat one another.
    """
    derived = derive_cte_output_columns(select_cols)
    resolved: list[str] = []
    for i in range(len(select_cols)):
        alias = explicit_aliases[i] if i < len(explicit_aliases) else None
        name = str(alias).strip() if alias else ""
        if not name and i < len(raw_output_columns) and raw_output_columns[i]:
            name = str(raw_output_columns[i]).strip()
        if not name:
            name = derived[i] if i < len(derived) else f"col_{i}"
        resolved.append(name)
    return resolved


def _parse_intent_cte_step(cte: dict[str, Any]) -> RuntimeCteStep:
    """Parse one cte_steps entry into a RuntimeCteStep."""
    select_cols, explicit_aliases = _parse_intent_select_cols(cte.get("select_cols", []))
    output_columns = _resolve_cte_output_columns(
        select_cols, explicit_aliases, _coerce_intent_list(cte.get("output_columns", []))
    )
    return RuntimeCteStep(
        cte_name=cte.get("cte_name", ""),
        description=cte.get("description"),
        tables=_coerce_intent_list(cte.get("tables", [])),
        select_cols=select_cols,
        group_by_cols=[parse_expr_string(g) for g in _coerce_intent_list(cte.get("group_by_cols", []))],
        order_by_cols=_parse_intent_order_by(cte.get("order_by_cols", [])),
        filters_param=_parse_intent_predicates(cte.get("filters_param", []), FilterParam, "col", "=", "string"),
        having_param=_parse_intent_predicates(cte.get("having_param", []), HavingParam, "agg", ">", "integer"),
        param_values={},
        output_columns=output_columns,
        grain=cte.get("grain") or "row_level",
        limit=_coerce_intent_limit(cte.get("limit")),
    )


def parse_intent_response(raw: str, question: str) -> RuntimeIntent | None:
    """Parse a raw LLM JSON response into a RuntimeIntent.

    Strips <word> placeholders and validates the structure against
    INTENT_SCHEMA before parsing. Returns None if the JSON is unparseable,
    is not an object, or fails schema validation.

    List-valued fields (tables, select_cols, group_by_cols, order_by_cols,
    filters_param, having_param, cte_steps, and their per-CTE equivalents)
    accept a bare value or null in place of a list. String limits are
    converted to int for both the main query and each CTE step. A malformed
    filter_group becomes None instead of raising.

    Args:
        raw: Raw LLM response string expected to contain intent JSON.
        question: Original question used as fallback for natural_language.

    Returns:
        RuntimeIntent on success, or None if the response is unparseable or invalid.
    """
    parsed = safe_json_loads(raw)
    if not parsed or not isinstance(parsed, dict):
        return None
    parsed = _strip_angle_brackets(parsed)
    if not _validate_intent_schema(parsed):
        debug("[intent_expr.parse_intent_response] schema validation failed")
        return None

    tables = _coerce_intent_list(parsed.get("tables", []))
    # Top-level select aliases are not used; only CTE steps expose named outputs.
    select_cols, _ = _parse_intent_select_cols(parsed.get("select_cols", []))
    group_by_cols = [parse_expr_string(g) for g in _coerce_intent_list(parsed.get("group_by_cols", []))]
    order_by_cols = _parse_intent_order_by(parsed.get("order_by_cols", []))
    filters_param = _parse_intent_predicates(parsed.get("filters_param", []), FilterParam, "col", "=", "string")
    having_param = _parse_intent_predicates(parsed.get("having_param", []), HavingParam, "agg", ">", "integer")
    cte_steps = [
        _parse_intent_cte_step(cte) for cte in _coerce_intent_list(parsed.get("cte_steps", [])) if isinstance(cte, dict)
    ]
    limit = _coerce_intent_limit(parsed.get("limit"))

    natural_language = str(parsed.get("natural_language") or "").strip() or question
    debug(f"[intent_parse.full_intent_parse] extracted natural_language='{natural_language}'")

    has_agg = any(sc.is_aggregated for sc in select_cols)
    if group_by_cols:
        grain = "grouped"
    elif has_agg:
        grain = "scalar"
    else:
        grain = "row_level"

    schema_invalid = str(parsed.get("intent_status") or "").strip().lower() == "schema_invalid"
    return RuntimeIntent(
        tables=tables,
        grain=grain,
        select_cols=select_cols,
        group_by_cols=group_by_cols,
        order_by_cols=order_by_cols,
        filters_param=filters_param,
        having_param=having_param,
        param_values={},
        cte_steps=cte_steps,
        natural_language=natural_language,
        limit=limit,
        schema_invalid=schema_invalid,
    )
|
|
1115
|
+
|
|
1116
|
+
|
|
1117
|
+
def _is_nontrivial_group(g: MulGroup) -> bool:
|
|
1118
|
+
"""Return True if a MulGroup warrants coefficient parameterization.
|
|
1119
|
+
|
|
1120
|
+
Args:
|
|
1121
|
+
|
|
1122
|
+
g: MulGroup to test.
|
|
1123
|
+
|
|
1124
|
+
Returns:
|
|
1125
|
+
|
|
1126
|
+
True when the group has an aggregation, a scalar function, division operands, or a non-unit coefficient.
|
|
1127
|
+
"""
|
|
1128
|
+
return bool(g.agg_func or g.scalar_func or g.inner_scalar_func or g.divide or g.coefficient != 1.0)
|
|
1129
|
+
|
|
1130
|
+
|
|
1131
|
+
def _assign_structural_group(g: MulGroup, idx: int, pv: dict[str, Any], is_numeric: bool = True) -> int:
|
|
1132
|
+
"""Assign structural param keys to a single MulGroup and collect values.
|
|
1133
|
+
|
|
1134
|
+
Args:
|
|
1135
|
+
|
|
1136
|
+
g: The MulGroup to assign param keys to.
|
|
1137
|
+
idx: The current structural param index counter.
|
|
1138
|
+
pv: The param_values dict to populate with key and value pairs.
|
|
1139
|
+
is_numeric: If False, skip coefficient parameterization entirely.
|
|
1140
|
+
|
|
1141
|
+
Returns:
|
|
1142
|
+
|
|
1143
|
+
The updated index after all assignments.
|
|
1144
|
+
"""
|
|
1145
|
+
if is_numeric and _is_nontrivial_group(g):
|
|
1146
|
+
key = f"s{idx}"
|
|
1147
|
+
g.coeff_param_key = key
|
|
1148
|
+
pv[key] = g.coefficient
|
|
1149
|
+
idx += 1
|
|
1150
|
+
for i, v in enumerate(g.scalar_func_args or []):
|
|
1151
|
+
key = f"s{idx}"
|
|
1152
|
+
if len(g.sarg_param_keys) <= i:
|
|
1153
|
+
g.sarg_param_keys.append(key)
|
|
1154
|
+
else:
|
|
1155
|
+
g.sarg_param_keys[i] = key
|
|
1156
|
+
pv[key] = v
|
|
1157
|
+
idx += 1
|
|
1158
|
+
for i, v in enumerate(g.inner_scalar_func_args or []):
|
|
1159
|
+
key = f"s{idx}"
|
|
1160
|
+
if len(g.isarg_param_keys) <= i:
|
|
1161
|
+
g.isarg_param_keys.append(key)
|
|
1162
|
+
else:
|
|
1163
|
+
g.isarg_param_keys[i] = key
|
|
1164
|
+
pv[key] = v
|
|
1165
|
+
idx += 1
|
|
1166
|
+
return idx
|
|
1167
|
+
|
|
1168
|
+
|
|
1169
|
+
def _assign_structural_expr(expr: NormalizedExpr, idx: int, pv: dict[str, Any]) -> int:
|
|
1170
|
+
"""Assign structural param keys to a single NormalizedExpr including ExprValue offsets and collect values.
|
|
1171
|
+
|
|
1172
|
+
Args:
|
|
1173
|
+
|
|
1174
|
+
expr: The NormalizedExpr to assign param keys to.
|
|
1175
|
+
idx: The current structural param index counter.
|
|
1176
|
+
pv: The param_values dict to populate with key and value pairs.
|
|
1177
|
+
|
|
1178
|
+
Returns:
|
|
1179
|
+
|
|
1180
|
+
The updated index after all assignments.
|
|
1181
|
+
"""
|
|
1182
|
+
for g in expr.add_groups:
|
|
1183
|
+
idx = _assign_structural_group(g, idx, pv, is_numeric=expr.is_numeric)
|
|
1184
|
+
for g in expr.sub_groups:
|
|
1185
|
+
idx = _assign_structural_group(g, idx, pv, is_numeric=expr.is_numeric)
|
|
1186
|
+
for i, v in enumerate(expr.scalar_func_args or []):
|
|
1187
|
+
key = f"s{idx}"
|
|
1188
|
+
if len(expr.sarg_param_keys) <= i:
|
|
1189
|
+
expr.sarg_param_keys.append(key)
|
|
1190
|
+
else:
|
|
1191
|
+
expr.sarg_param_keys[i] = key
|
|
1192
|
+
pv[key] = v
|
|
1193
|
+
idx += 1
|
|
1194
|
+
for i, v in enumerate(expr.inner_scalar_func_args or []):
|
|
1195
|
+
key = f"s{idx}"
|
|
1196
|
+
if len(expr.isarg_param_keys) <= i:
|
|
1197
|
+
expr.isarg_param_keys.append(key)
|
|
1198
|
+
else:
|
|
1199
|
+
expr.isarg_param_keys[i] = key
|
|
1200
|
+
pv[key] = v
|
|
1201
|
+
idx += 1
|
|
1202
|
+
if expr.is_numeric:
|
|
1203
|
+
for ev in expr.add_values:
|
|
1204
|
+
key = f"s{idx}"
|
|
1205
|
+
ev.param_key = key
|
|
1206
|
+
pv[key] = ev.value
|
|
1207
|
+
idx += 1
|
|
1208
|
+
for ev in expr.sub_values:
|
|
1209
|
+
key = f"s{idx}"
|
|
1210
|
+
ev.param_key = key
|
|
1211
|
+
pv[key] = ev.value
|
|
1212
|
+
idx += 1
|
|
1213
|
+
return idx
|
|
1214
|
+
|
|
1215
|
+
|
|
1216
|
+
def _infer_date_unit(column: str) -> str:
    """Guess the temporal unit argument for a date function from a column name.

    Checks the (keyword, unit) pairs in DATE_UNIT_KEYWORDS in order and
    returns the unit of the first keyword found in the lowercased name.

    Args:
        column: Fully qualified or bare column name string.

    Returns:
        The matched unit, or 'month' when no keyword occurs in the name.
    """
    lowered = column.lower()
    return next((unit for keyword, unit in DATE_UNIT_KEYWORDS if keyword in lowered), "month")
|
|
1232
|
+
|
|
1233
|
+
|
|
1234
|
+
def _fill_group_defaults(g: MulGroup) -> None:
|
|
1235
|
+
"""Fill default scalar_func_args and inner_scalar_func_args on a MulGroup if they are missing.
|
|
1236
|
+
|
|
1237
|
+
Args:
|
|
1238
|
+
|
|
1239
|
+
g: MulGroup to mutate in-place.
|
|
1240
|
+
"""
|
|
1241
|
+
if g.scalar_func and not g.scalar_func_args:
|
|
1242
|
+
func = g.scalar_func.lower()
|
|
1243
|
+
if func in SCALAR_FUNCTIONS_LEADING_ARG:
|
|
1244
|
+
col = g.multiply[0] if g.multiply and g.multiply[0] != "*" else ""
|
|
1245
|
+
g.scalar_func_args = [_infer_date_unit(col)]
|
|
1246
|
+
else:
|
|
1247
|
+
defaults = SCALAR_FUNC_DEFAULTS.get(func)
|
|
1248
|
+
if defaults is not None:
|
|
1249
|
+
g.scalar_func_args = list(defaults)
|
|
1250
|
+
if g.inner_scalar_func and not g.inner_scalar_func_args:
|
|
1251
|
+
func = g.inner_scalar_func.lower()
|
|
1252
|
+
if func in SCALAR_FUNCTIONS_LEADING_ARG:
|
|
1253
|
+
col = g.multiply[0] if g.multiply and g.multiply[0] != "*" else ""
|
|
1254
|
+
g.inner_scalar_func_args = [_infer_date_unit(col)]
|
|
1255
|
+
else:
|
|
1256
|
+
defaults = SCALAR_FUNC_DEFAULTS.get(func)
|
|
1257
|
+
if defaults is not None:
|
|
1258
|
+
g.inner_scalar_func_args = list(defaults)
|
|
1259
|
+
|
|
1260
|
+
|
|
1261
|
+
def _fill_expr_defaults(expr: NormalizedExpr) -> None:
|
|
1262
|
+
"""Fill default scalar_func_args on a NormalizedExpr and all of its MulGroups if they are missing.
|
|
1263
|
+
|
|
1264
|
+
Args:
|
|
1265
|
+
|
|
1266
|
+
expr: NormalizedExpr to mutate in-place.
|
|
1267
|
+
"""
|
|
1268
|
+
if expr.scalar_func and not expr.scalar_func_args:
|
|
1269
|
+
func = expr.scalar_func.lower()
|
|
1270
|
+
if func in SCALAR_FUNCTIONS_LEADING_ARG:
|
|
1271
|
+
col = expr.primary_column
|
|
1272
|
+
expr.scalar_func_args = [_infer_date_unit(col)]
|
|
1273
|
+
else:
|
|
1274
|
+
defaults = SCALAR_FUNC_DEFAULTS.get(func)
|
|
1275
|
+
if defaults is not None:
|
|
1276
|
+
expr.scalar_func_args = list(defaults)
|
|
1277
|
+
if expr.inner_scalar_func and not expr.inner_scalar_func_args:
|
|
1278
|
+
func = expr.inner_scalar_func.lower()
|
|
1279
|
+
if func in SCALAR_FUNCTIONS_LEADING_ARG:
|
|
1280
|
+
col = expr.primary_column
|
|
1281
|
+
expr.inner_scalar_func_args = [_infer_date_unit(col)]
|
|
1282
|
+
else:
|
|
1283
|
+
defaults = SCALAR_FUNC_DEFAULTS.get(func)
|
|
1284
|
+
if defaults is not None:
|
|
1285
|
+
expr.inner_scalar_func_args = list(defaults)
|
|
1286
|
+
for g in expr.add_groups + expr.sub_groups:
|
|
1287
|
+
_fill_group_defaults(g)
|
|
1288
|
+
|
|
1289
|
+
|
|
1290
|
+
def ensure_scalar_func_defaults(intent: RuntimeIntent) -> RuntimeIntent:
    """Ensure every scalar function in a RuntimeIntent carries its default argument list.

    Visits each CTE step and then the main query. In each, it fills the select,
    group-by, order-by, filter and having expressions (left side, then right
    side when present) via _fill_expr_defaults, so template signatures stay
    consistent.

    Args:
        intent: RuntimeIntent to process.

    Returns:
        The same RuntimeIntent object, mutated in place.
    """
    scopes = [*(intent.cte_steps or []), intent]
    for scope in scopes:
        for sc in scope.select_cols or []:
            _fill_expr_defaults(sc.expr)
        for group_expr in scope.group_by_cols or []:
            _fill_expr_defaults(group_expr)
        for obc in scope.order_by_cols or []:
            _fill_expr_defaults(obc.expr)
        for param in [*(scope.filters_param or []), *(scope.having_param or [])]:
            _fill_expr_defaults(param.left_expr)
            if param.right_expr:
                _fill_expr_defaults(param.right_expr)
    return intent
|
|
1333
|
+
|
|
1334
|
+
|
|
1335
|
+
def extract_structural_params(intent: RuntimeIntent) -> RuntimeIntent:
    """Assign structural param keys (s1, s2, ...) to limits, coefficients, function arguments, and value offsets.

    Keys are numbered from s1. Each CTE step is processed first (its limit,
    then select, group-by, order-by, filter and having expressions), followed
    by the main query in the same order. Existing intent.param_values entries
    are kept and the structural values are merged on top.

    The expressions and CTE steps reachable from ``intent`` are mutated in
    place (their param key fields and cte.limit_param_key). Only param_values
    and limit_param_key are applied through dataclasses.replace on the
    returned intent.

    Args:
        intent: RuntimeIntent with fully tagged NormalizedExprs.

    Returns:
        A copy of intent with param_values and limit_param_key populated.
    """
    pv: dict[str, Any] = dict(intent.param_values or {})
    next_idx = 1

    def _key_clauses(scope, start):
        cursor = start
        for sc in scope.select_cols or []:
            cursor = _assign_structural_expr(sc.expr, cursor, pv)
        for group_expr in scope.group_by_cols or []:
            cursor = _assign_structural_expr(group_expr, cursor, pv)
        for obc in scope.order_by_cols or []:
            cursor = _assign_structural_expr(obc.expr, cursor, pv)
        for params in (scope.filters_param or [], scope.having_param or []):
            for param in params:
                cursor = _assign_structural_expr(param.left_expr, cursor, pv)
                if param.right_expr:
                    cursor = _assign_structural_expr(param.right_expr, cursor, pv)
        return cursor

    for cte in intent.cte_steps or []:
        if cte.limit is not None:
            cte_key = f"s{next_idx}"
            pv[cte_key] = cte.limit
            cte.limit_param_key = cte_key
            next_idx += 1
        next_idx = _key_clauses(cte, next_idx)

    limit_param_key = ""
    if intent.limit is not None:
        limit_param_key = f"s{next_idx}"
        pv[limit_param_key] = intent.limit
        next_idx += 1
    next_idx = _key_clauses(intent, next_idx)

    debug(f"[intent_process.extract_structural_params] assigned {next_idx - 1} structural params")
    return replace(intent, param_values=pv, limit_param_key=limit_param_key)
|
|
1392
|
+
|
|
1393
|
+
|
|
1394
|
+
def has_non_default_structural(template: Template) -> bool:
    """Report whether a template has a structural default outside the identity set.

    Identity values are those in STRUCTURAL_IDENTITY_VALUES (0, 0.0, 1, and
    1.0 per the original contract). A template whose structural defaults are
    all identity values is treated like one with no structural params at all.

    Args:
        template: Template to inspect.

    Returns:
        True when at least one structural_default value is not an identity value.
    """
    defaults = template.structural_defaults
    if not defaults:
        return False
    for value in defaults.values():
        if value not in STRUCTURAL_IDENTITY_VALUES:
            return True
    return False
|
|
1410
|
+
|
|
1411
|
+
|
|
1412
|
+
def collect_raw_param_values(intent: RuntimeIntent) -> dict[str, Any]:
    """Harvest raw_value from every keyed filter/having param, then clear it.

    Visits each CTE step's filters and having params, followed by the main
    query's filters and having params. A param is harvested only when it has a
    param_key and a non-None raw_value. Its raw_value is reset to None in
    place, so a second call does not collect it again.

    Args:
        intent: RuntimeIntent whose params will be harvested.

    Returns:
        Dictionary mapping param_key to its raw extracted value.
    """

    def _all_params():
        for cte in intent.cte_steps or []:
            yield from cte.filters_param or []
            yield from cte.having_param or []
        yield from intent.filters_param or []
        yield from intent.having_param or []

    collected: dict[str, Any] = {}
    for param in _all_params():
        if not param.param_key or param.raw_value is None:
            continue
        collected[param.param_key] = param.raw_value
        param.raw_value = None
    return collected
|
|
1444
|
+
|
|
1445
|
+
|
|
1446
|
+
def assign_param_keys(
    filters_param: list[FilterParam],
    having_param: list[HavingParam],
    cte_steps: list[RuntimeCteStep] | None = None,
) -> tuple[list[FilterParam], list[HavingParam], list[RuntimeCteStep], int]:
    """Assign sequential param_key values such as p1 and p2 to filter and having parameters.

    Keys are numbered in one shared sequence: CTE steps first (each step's
    filters, then its having params), then the main query's filters, then its
    having params. Parameters that do not bind a literal value are passed
    through unchanged without a key:

    * filters whose operator is 'is null' or 'is not null' (compared
      case-insensitively, as the other op checks in this module do),
    * filters whose value_type is 'date_window' or 'date_diff',
    * filters and having params with a non-None right_expr.

    Args:
        filters_param: Main query filter parameters to key.
        having_param: Main query having parameters to key.
        cte_steps: Optional list of CTE steps whose filters and having are keyed first.

    Returns:
        Tuple of (updated_filters_param, updated_having_param, updated_cte_steps, next_idx),
        where next_idx is the number the next assigned key would use.
    """
    next_idx = 1

    def _with_key(param):
        # Consume the next number of the shared p1, p2, ... sequence.
        nonlocal next_idx
        keyed = replace(param, param_key=f"p{next_idx}")
        next_idx += 1
        return keyed

    def _filter_is_keyless(fp) -> bool:
        # Operators may arrive in any case, so normalise before comparing.
        op = (fp.op or "").strip().lower()
        return (
            op in ("is null", "is not null")
            or fp.value_type in ("date_window", "date_diff")
            or fp.right_expr is not None
        )

    def _key_filters(params):
        # Key each filter in order unless it binds no literal value.
        return [fp if _filter_is_keyless(fp) else _with_key(fp) for fp in params or []]

    def _key_havings(params):
        # Having params compared against another expression need no key.
        return [hp if hp.right_expr is not None else _with_key(hp) for hp in params or []]

    updated_cte_steps: list[RuntimeCteStep] = []
    for cte in cte_steps or []:
        cte_fp = _key_filters(cte.filters_param)
        cte_hp = _key_havings(cte.having_param)
        updated_cte_steps.append(replace(cte, filters_param=cte_fp, having_param=cte_hp))
    new_filters_param = _key_filters(filters_param)
    new_having_param = _key_havings(having_param)
    return new_filters_param, new_having_param, updated_cte_steps, next_idx
|
|
1498
|
+
|
|
1499
|
+
|
|
1500
|
+
def _parse_between_raw_value(raw_value: Any) -> tuple[Any, Any] | None:
|
|
1501
|
+
"""Try to extract two boundary values from a BETWEEN raw_value.
|
|
1502
|
+
|
|
1503
|
+
Handles list-of-two and string representations separated by common delimiters such as ' AND ', ' and ', ',', or ' - '.
|
|
1504
|
+
|
|
1505
|
+
Args:
|
|
1506
|
+
|
|
1507
|
+
raw_value: The raw_value from a FilterParam or HavingParam with op equal to 'between'.
|
|
1508
|
+
|
|
1509
|
+
Returns:
|
|
1510
|
+
|
|
1511
|
+
Tuple of (low, high) on success, or None when the value cannot be split.
|
|
1512
|
+
"""
|
|
1513
|
+
if isinstance(raw_value, list) and len(raw_value) == 2:
|
|
1514
|
+
return raw_value[0], raw_value[1]
|
|
1515
|
+
if isinstance(raw_value, str):
|
|
1516
|
+
for sep in (" AND ", " and ", ",", " - "):
|
|
1517
|
+
parts = raw_value.split(sep, 1)
|
|
1518
|
+
if len(parts) == 2:
|
|
1519
|
+
lo = parts[0].strip()
|
|
1520
|
+
hi = parts[1].strip()
|
|
1521
|
+
if lo and hi:
|
|
1522
|
+
return lo, hi
|
|
1523
|
+
return None
|
|
1524
|
+
|
|
1525
|
+
|
|
1526
|
+
def _decompose_between_param_list(
    params: list[FilterParam] | list[HavingParam],
) -> list[FilterParam] | list[HavingParam]:
    """Split every BETWEEN param in *params* into a '>=' / '<=' pair.

    A param counts as BETWEEN when its op, lower-cased, equals 'between'; all
    other params are copied through in their original order. For a BETWEEN
    param, _parse_between_raw_value supplies the (low, high) boundaries: the
    '>=' copy receives low and the '<=' copy receives high. When the raw_value
    cannot be split, both copies keep the original raw_value as a best-effort
    fallback.

    Args:
        params: List of FilterParam or HavingParam objects.

    Returns:
        New list with BETWEEN params replaced by >= and <= pairs.
    """
    expanded: list = []
    for param in params:
        if param.op.lower() == "between":
            bounds = _parse_between_raw_value(param.raw_value)
            if bounds is None:
                # Unsplittable value: keep it on both sides rather than dropping the filter.
                expanded.extend((replace(param, op=">="), replace(param, op="<=")))
            else:
                low, high = bounds
                expanded.extend(
                    (
                        replace(param, op=">=", raw_value=low),
                        replace(param, op="<=", raw_value=high),
                    )
                )
        else:
            expanded.append(param)
    return expanded
|
|
1554
|
+
|
|
1555
|
+
|
|
1556
|
+
def decompose_between_params(intent: RuntimeIntent) -> RuntimeIntent:
    """Rewrite BETWEEN filter/having operators as paired >= and <= conditions.

    The main query's filters_param and having_param, and those of every CTE
    step, are each run through _decompose_between_param_list. Missing (None)
    lists are treated as empty. Updated copies are built with replace().

    Args:
        intent: RuntimeIntent containing filters_param, having_param, and cte_steps to process.

    Returns:
        New RuntimeIntent with all BETWEEN operators replaced by >= and <= pairs.
    """
    main_filters = _decompose_between_param_list(intent.filters_param or [])
    main_havings = _decompose_between_param_list(intent.having_param or [])
    split_steps = [
        replace(
            step,
            filters_param=_decompose_between_param_list(step.filters_param or []),
            having_param=_decompose_between_param_list(step.having_param or []),
        )
        for step in intent.cte_steps or []
    ]
    return replace(intent, filters_param=main_filters, having_param=main_havings, cte_steps=split_steps)
|
|
1577
|
+
|
|
1578
|
+
|
|
1579
|
+
def _parse_in_string_to_list(raw_value: str) -> list[str]:
    """Split a string-encoded IN-list into its individual, unquoted elements.

    Accepts shapes such as "R, PG-13", "'R','PG-13'" and "1, 2, 3". The string
    is split with IN_STRING_SEPARATORS. Each piece has surrounding whitespace,
    then surrounding single/double quotes, removed, and pieces that end up
    empty are dropped.

    Args:
        raw_value: String representation of an IN-list value.

    Returns:
        List of individual value strings with surrounding quotes removed.
    """
    cleaned = (piece.strip().strip("'\"") for piece in IN_STRING_SEPARATORS.split(raw_value))
    return [item for item in cleaned if item]
|
|
1594
|
+
|
|
1595
|
+
|
|
1596
|
+
def _normalize_in_param_list(
    params: list[FilterParam] | list[HavingParam],
) -> list[FilterParam] | list[HavingParam]:
    """Turn string raw_values into Python lists for IN / NOT IN operators.

    Only params whose lower-cased op is in IN_OPS and whose raw_value is a str
    are considered. The string is split with _parse_in_string_to_list. The
    param is replaced by a copy holding that list only when more than one
    element results; otherwise the original param is kept unchanged.

    Args:
        params: Filter or having params to normalise.

    Returns:
        New list with string IN-values parsed into Python lists.
    """
    normalized: list = []
    for param in params:
        candidate = param
        if param.op.lower() in IN_OPS and isinstance(param.raw_value, str):
            elements = _parse_in_string_to_list(param.raw_value)
            if len(elements) > 1:
                candidate = replace(param, raw_value=elements)
        normalized.append(candidate)
    return normalized
|
|
1620
|
+
|
|
1621
|
+
|
|
1622
|
+
def normalize_in_raw_values(intent: RuntimeIntent) -> RuntimeIntent:
    """Normalise IN and NOT IN raw_values from strings to lists across the intent.

    When IN values arrive as one flat string (e.g. "R, PG-13"), they are parsed
    into Python lists so later steps can iterate over individual elements. The
    main query's filters/havings and those of every CTE step are processed.

    Args:
        intent: RuntimeIntent whose filters and havings may contain string IN-values.

    Returns:
        Updated RuntimeIntent with string IN-values converted to lists.
    """
    main_filters = _normalize_in_param_list(intent.filters_param or [])
    main_havings = _normalize_in_param_list(intent.having_param or [])
    rebuilt_steps = [
        replace(
            step,
            filters_param=_normalize_in_param_list(step.filters_param or []),
            having_param=_normalize_in_param_list(step.having_param or []),
        )
        for step in intent.cte_steps or []
    ]
    return replace(intent, filters_param=main_filters, having_param=main_havings, cte_steps=rebuilt_steps)
|
|
1643
|
+
|
|
1644
|
+
|
|
1645
|
+
_PLURAL_TO_SINGULAR_UNIT = {
|
|
1646
|
+
"days": "day",
|
|
1647
|
+
"weeks": "week",
|
|
1648
|
+
"months": "month",
|
|
1649
|
+
"years": "year",
|
|
1650
|
+
"hours": "hour",
|
|
1651
|
+
"minutes": "minute",
|
|
1652
|
+
"seconds": "second",
|
|
1653
|
+
}
|
|
1654
|
+
|
|
1655
|
+
|
|
1656
|
+
def _normalize_date_unit_in_raw_value(raw_value: Any) -> Any:
    """Normalise the unit in a date_window/date_diff raw_value dict.

    Plural units are mapped to their singular form (e.g. 'Days' -> 'day').
    Recognised singular units (members of VALID_DATE_WINDOW_UNITS or
    VALID_DATE_DIFF_UNITS after lower-casing and stripping) are rewritten in
    that canonical lower-case form, so 'Day' or ' day ' become 'day'.
    Non-dict values, non-string units, and unrecognised units are returned
    unchanged. The input dict is never mutated; a new dict is returned when
    the unit changes.

    Args:
        raw_value: Candidate raw_value of a date_window or date_diff param.

    Returns:
        The raw_value, with its unit normalised where possible.
    """
    if not isinstance(raw_value, dict):
        return raw_value
    unit = raw_value.get("unit")
    if not isinstance(unit, str):
        return raw_value
    unit_lower = unit.lower().strip()
    singular = _PLURAL_TO_SINGULAR_UNIT.get(unit_lower)
    if singular:
        return {**raw_value, "unit": singular}
    # Canonicalise case/whitespace of already-singular known units so downstream
    # comparisons see one spelling; leave anything unrecognised untouched.
    if unit_lower != unit and unit_lower in VALID_DATE_WINDOW_UNITS | VALID_DATE_DIFF_UNITS:
        return {**raw_value, "unit": unit_lower}
    return raw_value
|
|
1669
|
+
|
|
1670
|
+
|
|
1671
|
+
def normalize_date_diff_raw_values(intent: RuntimeIntent) -> RuntimeIntent:
    """Normalise date_window and date_diff raw_value units to singular form.

    Converts 'days' -> 'day', 'weeks' -> 'week', etc. so downstream
    rendering uses consistent unit names. Only params whose value_type is
    'date_window' or 'date_diff' and whose raw_value is not None are touched.
    This covers filters_param and having_param of the main query and of every
    CTE step.

    Args:
        intent: RuntimeIntent whose date params should be normalised.

    Returns:
        New RuntimeIntent carrying the normalised params.
    """

    def _normalise(params: list) -> list:
        return [
            replace(p, raw_value=_normalize_date_unit_in_raw_value(p.raw_value))
            if p.value_type in ("date_window", "date_diff") and p.raw_value is not None
            else p
            for p in params or []
        ]

    main_filters = _normalise(intent.filters_param or [])
    main_havings = _normalise(intent.having_param or [])
    rebuilt_steps = [
        replace(
            step,
            filters_param=_normalise(step.filters_param or []),
            having_param=_normalise(step.having_param or []),
        )
        for step in intent.cte_steps or []
    ]
    return replace(intent, filters_param=main_filters, having_param=main_havings, cte_steps=rebuilt_steps)
|
|
1694
|
+
|
|
1695
|
+
|
|
1696
|
+
def _is_plain_column_expr(expr: NormalizedExpr) -> bool:
|
|
1697
|
+
"""Return True when the expression is a bare column reference.
|
|
1698
|
+
|
|
1699
|
+
A plain column has exactly one add_group with one term, no
|
|
1700
|
+
sub_groups, no add_values, and no arithmetic operators,
|
|
1701
|
+
indicating it is not a computed date-subtraction expression.
|
|
1702
|
+
"""
|
|
1703
|
+
if expr.sub_groups:
|
|
1704
|
+
return False
|
|
1705
|
+
if expr.add_values:
|
|
1706
|
+
return False
|
|
1707
|
+
if len(expr.add_groups) != 1:
|
|
1708
|
+
return False
|
|
1709
|
+
group = expr.add_groups[0]
|
|
1710
|
+
if len(group.multiply) != 1:
|
|
1711
|
+
return False
|
|
1712
|
+
term = group.multiply[0]
|
|
1713
|
+
return not any(ch in term for ch in "+-*/")
|
|
1714
|
+
|
|
1715
|
+
|
|
1716
|
+
def _reclassify_date_diff_param(
    fp: FilterParam,
) -> FilterParam:
    """Convert a misclassified date_diff filter to date_window.

    When the LLM returns value_type ``date_diff`` with a plain date
    column as left_expr (instead of a date subtraction expression),
    it actually means a relative time window such as "last 90 days".
    Reclassify as ``date_window`` with the ``amount`` mapped to
    ``offset``.

    The filter is returned unchanged in any of these cases:
    - it is not a date_diff;
    - its raw_value is not a dict;
    - its left_expr is not a plain column;
    - ``amount`` is missing or cannot be converted to an int (e.g. "ninety").

    A missing or empty ``unit`` defaults to "day".
    """
    if fp.value_type != "date_diff":
        return fp
    rv = fp.raw_value
    if not isinstance(rv, dict):
        return fp
    if not _is_plain_column_expr(fp.left_expr):
        return fp
    amount = rv.get("amount")
    if amount is None:
        return fp
    try:
        offset = int(amount)
    except (TypeError, ValueError, OverflowError):
        # raw_value comes from model output; keep the filter as-is rather than
        # raising on an amount that is not an integer.
        return fp
    new_rv = {"unit": rv.get("unit") or "day", "offset": offset}
    return replace(fp, value_type="date_window", raw_value=new_rv)
|
|
1739
|
+
|
|
1740
|
+
|
|
1741
|
+
def repair_misclassified_date_diff(
    intent: RuntimeIntent,
) -> RuntimeIntent:
    """Reclassify date_diff filters that target a plain column.

    Every filter in the main query's filters_param and in each CTE step's
    filters_param is passed through _reclassify_date_diff_param. That helper
    turns a date_diff filter whose left_expr is a simple column reference into
    a date_window with the amount mapped to offset. having_param lists are not
    touched.

    Args:
        intent: RuntimeIntent to repair.

    Returns:
        New RuntimeIntent with the repaired filters.
    """
    main_filters = [_reclassify_date_diff_param(fp) for fp in intent.filters_param or []]
    repaired_steps = []
    for step in intent.cte_steps or []:
        step_filters = [_reclassify_date_diff_param(fp) for fp in step.filters_param or []]
        repaired_steps.append(replace(step, filters_param=step_filters))
    return replace(intent, filters_param=main_filters, cte_steps=repaired_steps)
|