aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
|
@@ -0,0 +1,1890 @@
|
|
|
1
|
+
"""Intent and template contracts for the text-to-SQL pipeline.
|
|
2
|
+
|
|
3
|
+
Defines the expression model (ExprValue, MulGroup, NormalizedExpr) used throughout the pipeline for canonical sum-of-products representations, along with all intent containers (RuntimeIntent, ConcreteIntent, SimulatorIntent, QSimIntent), filter and having conditions, select and order-by columns, CTE step representations, and the Template/RejectedTemplate structures used for query reuse.
|
|
4
|
+
|
|
5
|
+
Also provides conversion helpers between runtime and concrete forms used at different pipeline stages.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from .config import AGG_PREFIXES, OP_FLIP
|
|
14
|
+
from .contracts_base import (
|
|
15
|
+
CteOutputColumnMeta,
|
|
16
|
+
ExpansionMetadata,
|
|
17
|
+
IntentIssue,
|
|
18
|
+
SQLShape,
|
|
19
|
+
TemplateStats,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
ScalarArg = str | int | float
|
|
23
|
+
ParamValue = str | int | float | bool | list[str | int | float]
|
|
24
|
+
RawValue = str | int | float | bool | list[str | int | float] | dict[str, str | int] | None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
|
|
28
|
+
class ExprValue:
|
|
29
|
+
"""Parameterized literal value for expression arithmetic with param_key for template reuse."""
|
|
30
|
+
|
|
31
|
+
value: float = 0.0
|
|
32
|
+
param_key: str = ""
|
|
33
|
+
|
|
34
|
+
@staticmethod
|
|
35
|
+
def from_dict(d: dict[str, Any]) -> ExprValue:
|
|
36
|
+
"""Create ExprValue from dictionary.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
|
|
40
|
+
d: Dictionary with 'value' and 'param_key' keys, or a bare numeric value.
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
|
|
44
|
+
Populated ExprValue instance.
|
|
45
|
+
"""
|
|
46
|
+
if isinstance(d, int | float):
|
|
47
|
+
return ExprValue(value=float(d))
|
|
48
|
+
return ExprValue(value=d.get("value", 0.0), param_key=d.get("param_key", ""))
|
|
49
|
+
|
|
50
|
+
def to_dict(self) -> dict[str, Any]:
|
|
51
|
+
"""Serialize to a plain dictionary.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
|
|
55
|
+
Dictionary with 'value' and 'param_key' keys.
|
|
56
|
+
"""
|
|
57
|
+
return {"value": self.value, "param_key": self.param_key}
|
|
58
|
+
|
|
59
|
+
@property
|
|
60
|
+
def signature_key(self) -> str:
|
|
61
|
+
"""Return structural signature for template matching (value-agnostic)."""
|
|
62
|
+
return "val"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class MulGroup:
    """Single multiplicative term of a sum-of-products expression.

    Represents ``scalar_func(agg_func(inner_scalar_func(coefficient * multiply[0] * ...
    / divide[0] / ...)))``, with ``scalar_func_args`` and ``inner_scalar_func_args`` as
    the extra arguments of the two scalar layers.
    """

    # Numeric multiplier applied to the product of terms.
    coefficient: float = 1.0
    # Numerator terms (sorted on construction).
    multiply: list[str] = field(default_factory=list)
    # Denominator terms (sorted on construction).
    divide: list[str] = field(default_factory=list)
    # Optional aggregation function name (lowercased on construction).
    agg_func: str | None = None
    # Optional outer scalar function name (lowercased on construction).
    scalar_func: str | None = None
    # Optional inner scalar function name (lowercased on construction).
    inner_scalar_func: str | None = None
    # Extra arguments of scalar_func.
    scalar_func_args: list[ScalarArg] = field(default_factory=list)
    # Extra arguments of inner_scalar_func.
    inner_scalar_func_args: list[ScalarArg] = field(default_factory=list)
    # Template parameter key for the coefficient.
    coeff_param_key: str = ""
    # Template parameter keys for scalar_func_args (travel with them on a layer swap).
    sarg_param_keys: list[str] = field(default_factory=list)
    # Template parameter keys for inner_scalar_func_args (travel with them on a layer swap).
    isarg_param_keys: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Canonicalize the group.

        Sorts ``multiply``/``divide``, lowercases function names and, when both scalar
        layers are present, orders them: ``extract`` is always kept as the outer layer;
        otherwise the alphabetically smaller name becomes the outer layer.
        """
        self.multiply = sorted(self.multiply)
        self.divide = sorted(self.divide)
        if self.agg_func:
            self.agg_func = self.agg_func.lower()
        if self.scalar_func:
            self.scalar_func = self.scalar_func.lower()
        if self.inner_scalar_func:
            self.inner_scalar_func = self.inner_scalar_func.lower()
        if self.scalar_func and self.inner_scalar_func:
            if self.scalar_func == "extract":
                pass
            elif self.inner_scalar_func == "extract" or self.scalar_func > self.inner_scalar_func:
                self._swap_scalar_layers()

    def _swap_scalar_layers(self) -> None:
        """Swap outer and inner scalar layers, moving each function's args and param keys with it."""
        self.scalar_func, self.inner_scalar_func = self.inner_scalar_func, self.scalar_func
        self.scalar_func_args, self.inner_scalar_func_args = (
            self.inner_scalar_func_args,
            self.scalar_func_args,
        )
        # Param keys belong to their argument lists; leaving them in place would bind
        # each template parameter to the other function's arguments.
        self.sarg_param_keys, self.isarg_param_keys = self.isarg_param_keys, self.sarg_param_keys

    @staticmethod
    def from_dict(d: dict[str, Any]) -> MulGroup:
        """Create MulGroup from dictionary.

        List-valued entries are copied (so the result does not alias the caller's
        data), and a missing or null list entry becomes an empty list.

        Args:
            d: Dictionary with keys matching MulGroup fields.

        Returns:
            Populated MulGroup instance.
        """
        return MulGroup(
            coefficient=d.get("coefficient", 1.0),
            multiply=list(d.get("multiply") or []),
            divide=list(d.get("divide") or []),
            agg_func=d.get("agg_func"),
            scalar_func=d.get("scalar_func"),
            inner_scalar_func=d.get("inner_scalar_func"),
            scalar_func_args=list(d.get("scalar_func_args") or []),
            inner_scalar_func_args=list(d.get("inner_scalar_func_args") or []),
            coeff_param_key=d.get("coeff_param_key", ""),
            sarg_param_keys=list(d.get("sarg_param_keys") or []),
            isarg_param_keys=list(d.get("isarg_param_keys") or []),
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary with one entry per field."""
        return {
            "coefficient": self.coefficient,
            "multiply": self.multiply,
            "divide": self.divide,
            "agg_func": self.agg_func,
            "scalar_func": self.scalar_func,
            "inner_scalar_func": self.inner_scalar_func,
            "scalar_func_args": self.scalar_func_args,
            "inner_scalar_func_args": self.inner_scalar_func_args,
            "coeff_param_key": self.coeff_param_key,
            "sarg_param_keys": self.sarg_param_keys,
            "isarg_param_keys": self.isarg_param_keys,
        }

    def _structure_parts(self) -> list[str]:
        """Return the coefficient-agnostic signature fragments shared by both keys."""
        parts: list[str] = []
        if self.agg_func:
            parts.append(f"agg={self.agg_func}")
        if self.scalar_func:
            parts.append(f"scalar={self.scalar_func}")
        if self.scalar_func_args:
            parts.append(f"sargs={len(self.scalar_func_args)}")
        if self.inner_scalar_func:
            parts.append(f"inner={self.inner_scalar_func}")
        if self.inner_scalar_func_args:
            parts.append(f"iargs={len(self.inner_scalar_func_args)}")
        parts.extend(f"*{m}" for m in self.multiply)
        parts.extend(f"/{d}" for d in self.divide)
        return parts

    @property
    def signature_key(self) -> str:
        """Return structural signature for template matching (prefixed with 'coeff')."""
        return "|".join(["coeff", *self._structure_parts()])

    @property
    def structural_key(self) -> str:
        """Return coefficient-agnostic structural key for like-term combining."""
        return "|".join(self._structure_parts())
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@dataclass
class NormalizedExpr:
    """Canonical sum-of-products expression.

    Represents ``scalar_func(agg_func(inner_scalar_func(sum(add_groups) - sum(sub_groups)
    + sum(add_values) - sum(sub_values))))``, with ``scalar_func_args`` and
    ``inner_scalar_func_args`` as the extra arguments of the two scalar layers.
    """

    # Terms added to the expression (sorted by signature_key on construction).
    add_groups: list[MulGroup] = field(default_factory=list)
    # Terms subtracted from the expression (sorted by signature_key on construction).
    sub_groups: list[MulGroup] = field(default_factory=list)
    # Literal values added (sorted by value on construction).
    add_values: list[ExprValue] = field(default_factory=list)
    # Literal values subtracted (sorted by value on construction).
    sub_values: list[ExprValue] = field(default_factory=list)
    # Optional aggregation wrapping the whole sum (lowercased on construction).
    agg_func: str | None = None
    # Optional outer scalar function (lowercased on construction).
    scalar_func: str | None = None
    # Optional inner scalar function (lowercased on construction).
    inner_scalar_func: str | None = None
    # Extra arguments of scalar_func.
    scalar_func_args: list[ScalarArg] = field(default_factory=list)
    # Extra arguments of inner_scalar_func.
    inner_scalar_func_args: list[ScalarArg] = field(default_factory=list)
    # Template parameter keys for scalar_func_args (travel with them on a layer swap).
    sarg_param_keys: list[str] = field(default_factory=list)
    # Template parameter keys for inner_scalar_func_args (travel with them on a layer swap).
    isarg_param_keys: list[str] = field(default_factory=list)
    # Whether the expression is numeric; defaults to True.
    is_numeric: bool = True

    def __post_init__(self) -> None:
        """Canonicalize the expression.

        Sorts groups by signature_key and values by value, lowercases function names and,
        when both scalar layers are present, orders them: ``extract`` is always kept as the
        outer layer; otherwise the alphabetically smaller name becomes the outer layer.
        """
        self.add_groups = sorted(self.add_groups, key=lambda g: g.signature_key)
        self.sub_groups = sorted(self.sub_groups, key=lambda g: g.signature_key)
        self.add_values = sorted(self.add_values, key=lambda v: v.value)
        self.sub_values = sorted(self.sub_values, key=lambda v: v.value)
        if self.agg_func:
            self.agg_func = self.agg_func.lower()
        if self.scalar_func:
            self.scalar_func = self.scalar_func.lower()
        if self.inner_scalar_func:
            self.inner_scalar_func = self.inner_scalar_func.lower()
        if self.scalar_func and self.inner_scalar_func:
            if self.scalar_func == "extract":
                pass
            elif self.inner_scalar_func == "extract" or self.scalar_func > self.inner_scalar_func:
                self._swap_scalar_layers()

    def _swap_scalar_layers(self) -> None:
        """Swap outer and inner scalar layers, moving each function's args and param keys with it."""
        self.scalar_func, self.inner_scalar_func = self.inner_scalar_func, self.scalar_func
        self.scalar_func_args, self.inner_scalar_func_args = (
            self.inner_scalar_func_args,
            self.scalar_func_args,
        )
        # Param keys belong to their argument lists; leaving them in place would bind
        # each template parameter to the other function's arguments.
        self.sarg_param_keys, self.isarg_param_keys = self.isarg_param_keys, self.sarg_param_keys

    @staticmethod
    def from_dict(d: dict[str, Any]) -> NormalizedExpr:
        """Create NormalizedExpr from dictionary.

        A missing or null list entry becomes an empty list; plain lists are copied so
        the result does not alias the caller's data.

        Args:
            d: Dictionary with keys matching NormalizedExpr fields.

        Returns:
            Populated NormalizedExpr instance with nested MulGroup and ExprValue objects.
        """
        return NormalizedExpr(
            add_groups=[MulGroup.from_dict(g) for g in d.get("add_groups") or []],
            sub_groups=[MulGroup.from_dict(g) for g in d.get("sub_groups") or []],
            add_values=[ExprValue.from_dict(v) for v in d.get("add_values") or []],
            sub_values=[ExprValue.from_dict(v) for v in d.get("sub_values") or []],
            agg_func=d.get("agg_func"),
            scalar_func=d.get("scalar_func"),
            inner_scalar_func=d.get("inner_scalar_func"),
            scalar_func_args=list(d.get("scalar_func_args") or []),
            inner_scalar_func_args=list(d.get("inner_scalar_func_args") or []),
            sarg_param_keys=list(d.get("sarg_param_keys") or []),
            isarg_param_keys=list(d.get("isarg_param_keys") or []),
            is_numeric=d.get("is_numeric", True),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary.

        Returns:
            Dictionary with all NormalizedExpr fields, with nested groups and values serialized.
        """
        return {
            "add_groups": [g.to_dict() for g in self.add_groups],
            "sub_groups": [g.to_dict() for g in self.sub_groups],
            "add_values": [v.to_dict() for v in self.add_values],
            "sub_values": [v.to_dict() for v in self.sub_values],
            "agg_func": self.agg_func,
            "scalar_func": self.scalar_func,
            "inner_scalar_func": self.inner_scalar_func,
            "scalar_func_args": self.scalar_func_args,
            "inner_scalar_func_args": self.inner_scalar_func_args,
            "sarg_param_keys": self.sarg_param_keys,
            "isarg_param_keys": self.isarg_param_keys,
            "is_numeric": self.is_numeric,
        }

    @staticmethod
    def from_column(col: str) -> NormalizedExpr:
        """Create expression for a bare column reference."""
        return NormalizedExpr(add_groups=[MulGroup(multiply=[col])])

    @staticmethod
    def from_agg(agg_func: str, col: str) -> NormalizedExpr:
        """Create expression for a single aggregation call over one column."""
        return NormalizedExpr(add_groups=[MulGroup(multiply=[col], agg_func=agg_func.lower())])

    @property
    def signature_key(self) -> str:
        """Return structural signature for template matching (literal values contribute only 'val')."""
        parts = []
        if self.agg_func:
            parts.append(f"expr_agg={self.agg_func}")
        if self.scalar_func:
            parts.append(f"expr_scalar={self.scalar_func}")
        if self.scalar_func_args:
            parts.append(f"expr_sargs={len(self.scalar_func_args)}")
        if self.inner_scalar_func:
            parts.append(f"expr_inner={self.inner_scalar_func}")
        if self.inner_scalar_func_args:
            parts.append(f"expr_iargs={len(self.inner_scalar_func_args)}")
        parts.extend(f"+{g.signature_key}" for g in self.add_groups)
        parts.extend(f"-{g.signature_key}" for g in self.sub_groups)
        parts.extend(f"+{v.signature_key}" for v in self.add_values)
        parts.extend(f"-{v.signature_key}" for v in self.sub_values)
        return "|".join(parts)

    @property
    def has_aggregation(self) -> bool:
        """Return True if the expression or any group has agg_func, or any term mentions an AGG_PREFIXES entry.

        The term check is a case-insensitive substring test of each uppercased term
        against AGG_PREFIXES (imported from config).
        """
        if self.agg_func:
            return True
        for group in self.add_groups + self.sub_groups:
            if group.agg_func:
                return True
            for term in group.multiply + group.divide:
                upper = term.upper()
                if any(p in upper for p in AGG_PREFIXES):
                    return True
        return False

    @property
    def primary_column(self) -> str:
        """Return innermost column reference from the first term, stripping function wrappers and DISTINCT prefix.

        Wrappers are peeled from the first "(" to the last ")". A malformed term with no
        ")" after its "(" keeps everything after the "(" instead of raising ValueError.
        Surrounding whitespace is stripped before the DISTINCT check, so
        ``COUNT( DISTINCT x )`` yields ``x``.
        """
        if not self.add_groups or not self.add_groups[0].multiply:
            return ""
        term = self.add_groups[0].multiply[0]
        while "(" in term:
            start = term.index("(")
            end = term.rfind(")")
            # Each pass strictly shortens the term, so the loop always terminates.
            term = term[start + 1 : end] if end > start else term[start + 1 :]
        term = term.strip()
        if term.upper().startswith("DISTINCT "):
            term = term[9:].strip()
        return term

    @property
    def primary_term(self) -> str:
        """Return the first atomic term of the first add group as-is ("" when absent)."""
        if self.add_groups and self.add_groups[0].multiply:
            return self.add_groups[0].multiply[0]
        return ""
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
@dataclass
class FilterParam:
    """Filter condition with left expression, operator, and optional right expression for expr-vs-expr comparisons."""

    # Left-hand expression of the comparison.
    left_expr: NormalizedExpr = field(default_factory=NormalizedExpr)
    # Comparison operator (stripped and lowercased on construction).
    op: str = "="
    # Right-hand expression for expr-vs-expr comparisons; None for expr-vs-value.
    right_expr: NormalizedExpr | None = None
    # Declared type of the compared value (stripped and lowercased on construction).
    value_type: str = "string"
    # Template parameter key for the compared value.
    param_key: str = ""
    # Raw literal value; not serialized by to_dict.
    raw_value: RawValue = None
    # Connective to the neighbouring condition: "AND" or "OR" (anything else becomes "AND").
    bool_op: str = "AND"
    # Optional grouping index for condition groups.
    filter_group: int | None = None

    def __post_init__(self) -> None:
        """Normalize op, value_type, and bool_op, canonicalize ordering, and consolidate ExprValues to the value side.

        With a right expression: the sides are swapped (and ``op`` flipped via OP_FLIP)
        when the left signature sorts after the right one, unless only the left side's
        signature contains "." (presumably a table.column reference). Literal values are
        then moved from the left side to the right with their sign inverted.
        Without a right expression and with a numeric (non-bool) raw_value, the left
        side's literals are folded into raw_value instead.
        """
        self.op = self.op.strip().lower()
        self.value_type = self.value_type.strip().lower()
        self.bool_op = self.bool_op.strip().upper() if self.bool_op else "AND"
        if self.bool_op not in ("AND", "OR"):
            self.bool_op = "AND"
        if self.right_expr is not None:
            if self.left_expr.signature_key > self.right_expr.signature_key:
                left_has_col = "." in self.left_expr.signature_key
                right_has_col = "." in self.right_expr.signature_key
                if not (left_has_col and not right_has_col):
                    self.left_expr, self.right_expr = self.right_expr, self.left_expr
                    self.op = OP_FLIP.get(self.op, self.op)
            # left + a  op  right   ==>   left  op  right - a
            for ev in self.left_expr.add_values:
                self.right_expr.sub_values.append(ExprValue(value=ev.value, param_key=ev.param_key))
            for ev in self.left_expr.sub_values:
                self.right_expr.add_values.append(ExprValue(value=ev.value, param_key=ev.param_key))
            self.left_expr.add_values = []
            self.left_expr.sub_values = []
        elif (
            self.raw_value is not None
            and isinstance(self.raw_value, int | float)
            and not isinstance(self.raw_value, bool)
        ):
            # left + a  op  v   ==>   left  op  v - a
            offset = sum(ev.value for ev in self.left_expr.add_values) - sum(
                ev.value for ev in self.left_expr.sub_values
            )
            self.raw_value = self.raw_value - offset
            self.left_expr.add_values = []
            self.left_expr.sub_values = []

    @staticmethod
    def from_dict(d: dict[str, Any]) -> FilterParam:
        """Create FilterParam from dictionary.

        The literal is read from 'value', falling back to 'raw_value' only when 'value'
        is absent or None, so falsy literals such as 0 or False are preserved.

        Args:
            d: Dictionary with 'left_expr', 'op', optional 'right_expr', 'value_type', and 'param_key'.
                Expression entries may be dicts or already-built NormalizedExpr objects.

        Returns:
            Populated FilterParam instance.
        """
        left_raw = d.get("left_expr", {})
        right_raw = d.get("right_expr")
        fg_raw = d.get("filter_group")
        raw_value = d.get("value")
        if raw_value is None:
            raw_value = d.get("raw_value")
        if isinstance(right_raw, dict):
            right_expr = NormalizedExpr.from_dict(right_raw) if right_raw else None
        elif isinstance(right_raw, NormalizedExpr):
            # Mirror left_expr handling: accept an already-built expression.
            right_expr = right_raw
        else:
            right_expr = None
        return FilterParam(
            left_expr=(NormalizedExpr.from_dict(left_raw) if isinstance(left_raw, dict) else left_raw),
            op=d.get("op", "="),
            right_expr=right_expr,
            value_type=d.get("value_type", "string"),
            param_key=d.get("param_key", ""),
            raw_value=raw_value,
            bool_op=d.get("bool_op", "AND"),
            filter_group=int(fg_raw) if fg_raw is not None else None,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary.

        Returns:
            Dictionary with filter fields; raw_value is intentionally excluded, and
            bool_op / filter_group are only emitted when non-default.
        """
        d: dict[str, Any] = {
            "left_expr": self.left_expr.to_dict(),
            "op": self.op,
            "right_expr": self.right_expr.to_dict() if self.right_expr else None,
            "value_type": self.value_type,
            "param_key": self.param_key,
        }
        if self.bool_op != "AND":
            d["bool_op"] = self.bool_op
        if self.filter_group is not None:
            d["filter_group"] = self.filter_group
        return d

    @property
    def signature_key(self) -> str:
        """Return structural signature for template matching."""
        parts = [self.left_expr.signature_key, self.op, self.value_type]
        if self.right_expr:
            parts.append(f"r:{self.right_expr.signature_key}")
        return "|".join(parts)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
@dataclass
class HavingParam:
    """Having condition with left expression, operator, and optional right expression for expr-vs-expr comparisons."""

    # Left-hand expression of the comparison.
    left_expr: NormalizedExpr = field(default_factory=NormalizedExpr)
    # Comparison operator (stripped and lowercased on construction).
    op: str = "="
    # Right-hand expression for expr-vs-expr comparisons; None for expr-vs-value.
    right_expr: NormalizedExpr | None = None
    # Declared type of the compared value (stripped and lowercased on construction).
    value_type: str = "number"
    # Template parameter key for the compared value.
    param_key: str = ""
    # Raw literal value; not serialized by to_dict.
    raw_value: RawValue = None
    # Connective to the neighbouring condition: "AND" or "OR" (anything else becomes "AND").
    bool_op: str = "AND"
    # Optional grouping index for condition groups.
    filter_group: int | None = None

    def __post_init__(self) -> None:
        """Normalize op, value_type, and bool_op, canonicalize ordering, and consolidate ExprValues to the value side.

        With a right expression: the sides are swapped (and ``op`` flipped via OP_FLIP)
        when the left signature sorts after the right one, unless only the left side's
        signature contains "." (presumably a table.column reference). Literal values are
        then moved from the left side to the right with their sign inverted.
        Without a right expression and with a numeric (non-bool) raw_value, the left
        side's literals are folded into raw_value instead.
        """
        self.op = self.op.strip().lower()
        self.value_type = self.value_type.strip().lower()
        self.bool_op = self.bool_op.strip().upper() if self.bool_op else "AND"
        if self.bool_op not in ("AND", "OR"):
            self.bool_op = "AND"
        if self.right_expr is not None:
            if self.left_expr.signature_key > self.right_expr.signature_key:
                left_has_col = "." in self.left_expr.signature_key
                right_has_col = "." in self.right_expr.signature_key
                if not (left_has_col and not right_has_col):
                    self.left_expr, self.right_expr = self.right_expr, self.left_expr
                    self.op = OP_FLIP.get(self.op, self.op)
            # left + a  op  right   ==>   left  op  right - a
            for ev in self.left_expr.add_values:
                self.right_expr.sub_values.append(ExprValue(value=ev.value, param_key=ev.param_key))
            for ev in self.left_expr.sub_values:
                self.right_expr.add_values.append(ExprValue(value=ev.value, param_key=ev.param_key))
            self.left_expr.add_values = []
            self.left_expr.sub_values = []
        elif (
            self.raw_value is not None
            and isinstance(self.raw_value, int | float)
            and not isinstance(self.raw_value, bool)
        ):
            # left + a  op  v   ==>   left  op  v - a
            offset = sum(ev.value for ev in self.left_expr.add_values) - sum(
                ev.value for ev in self.left_expr.sub_values
            )
            self.raw_value = self.raw_value - offset
            self.left_expr.add_values = []
            self.left_expr.sub_values = []

    @staticmethod
    def from_dict(d: dict[str, Any]) -> HavingParam:
        """Create HavingParam from dictionary.

        The literal is read from 'value', falling back to 'raw_value' only when 'value'
        is absent or None, so falsy literals such as 0 or False are preserved.

        Args:
            d: Dictionary with 'left_expr', 'op', optional 'right_expr', 'value_type', and 'param_key'.
                Expression entries may be dicts or already-built NormalizedExpr objects.

        Returns:
            Populated HavingParam instance.
        """
        left_raw = d.get("left_expr", {})
        right_raw = d.get("right_expr")
        fg_raw = d.get("filter_group")
        raw_value = d.get("value")
        if raw_value is None:
            raw_value = d.get("raw_value")
        if isinstance(right_raw, dict):
            right_expr = NormalizedExpr.from_dict(right_raw) if right_raw else None
        elif isinstance(right_raw, NormalizedExpr):
            # Mirror left_expr handling: accept an already-built expression.
            right_expr = right_raw
        else:
            right_expr = None
        return HavingParam(
            left_expr=(NormalizedExpr.from_dict(left_raw) if isinstance(left_raw, dict) else left_raw),
            op=d.get("op", "="),
            right_expr=right_expr,
            value_type=d.get("value_type", "number"),
            param_key=d.get("param_key", ""),
            raw_value=raw_value,
            bool_op=d.get("bool_op", "AND"),
            filter_group=int(fg_raw) if fg_raw is not None else None,
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary.

        raw_value is not serialized; bool_op / filter_group are only emitted when non-default.
        """
        d: dict[str, Any] = {
            "left_expr": self.left_expr.to_dict(),
            "op": self.op,
            "right_expr": self.right_expr.to_dict() if self.right_expr else None,
            "value_type": self.value_type,
            "param_key": self.param_key,
        }
        if self.bool_op != "AND":
            d["bool_op"] = self.bool_op
        if self.filter_group is not None:
            d["filter_group"] = self.filter_group
        return d

    @property
    def signature_key(self) -> str:
        """Return structural signature for template matching."""
        parts = [self.left_expr.signature_key, self.op, self.value_type]
        if self.right_expr:
            parts.append(f"r:{self.right_expr.signature_key}")
        return "|".join(parts)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
@dataclass
class SelectCol:
    """One projected column of a SELECT list, wrapping a canonical NormalizedExpr."""

    # The projected expression.
    expr: NormalizedExpr = field(default_factory=NormalizedExpr)

    @staticmethod
    def from_dict(d: dict[str, Any]) -> SelectCol:
        """Build a SelectCol from its serialized form.

        Args:
            d: Mapping whose optional 'expr' entry is a NormalizedExpr dict; a non-dict
                entry is used as the expression unchanged.

        Returns:
            The reconstructed SelectCol.
        """
        raw = d.get("expr", {})
        if not isinstance(raw, dict):
            return SelectCol(expr=raw)
        return SelectCol(expr=NormalizedExpr.from_dict(raw))

    def to_dict(self) -> dict[str, Any]:
        """Return the serialized mapping: a single 'expr' entry holding the expression's dict."""
        serialized_expr = self.expr.to_dict()
        return {"expr": serialized_expr}

    @property
    def is_aggregated(self) -> bool:
        """True when the wrapped expression aggregates (delegates to ``has_aggregation``)."""
        return self.expr.has_aggregation

    @property
    def signature_key(self) -> str:
        """Template-matching signature; identical to the wrapped expression's signature."""
        return self.expr.signature_key
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
@dataclass
class OrderByCol:
    """Order by column with expression and sort direction."""

    # The expression being ordered on.
    expr: NormalizedExpr = field(default_factory=NormalizedExpr)
    # Sort direction, stripped and uppercased on construction; None/blank becomes "ASC".
    direction: str = "ASC"

    def __post_init__(self) -> None:
        """Uppercase direction for consistent comparison.

        A None or blank direction defaults to "ASC" (the field default) rather than
        crashing on ``None.strip()`` or producing an empty direction in signature_key.
        """
        self.direction = (self.direction or "").strip().upper() or "ASC"

    @staticmethod
    def from_dict(d: dict[str, Any]) -> OrderByCol:
        """Create OrderByCol from dictionary.

        Args:
            d: Dictionary with 'expr' and 'direction' keys; a non-dict 'expr' is used as-is.

        Returns:
            Populated OrderByCol instance.
        """
        expr_raw = d.get("expr", {})
        return OrderByCol(
            expr=(NormalizedExpr.from_dict(expr_raw) if isinstance(expr_raw, dict) else expr_raw),
            direction=d.get("direction", "ASC"),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary.

        Returns:
            Dictionary with the serialized expr and direction string.
        """
        return {"expr": self.expr.to_dict(), "direction": self.direction}

    @property
    def is_aggregated(self) -> bool:
        """Return True if expression contains an aggregation function."""
        return self.expr.has_aggregation

    @property
    def signature_key(self) -> str:
        """Return structural signature for template matching (expression signature plus direction)."""
        return "|".join([self.expr.signature_key, self.direction])
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
@dataclass
|
|
646
|
+
class ConcreteCteStep:
|
|
647
|
+
"""CTE step structural signature for template storage."""
|
|
648
|
+
|
|
649
|
+
cte_name: str
|
|
650
|
+
tables: list[str] = field(default_factory=list)
|
|
651
|
+
select_cols: list[SelectCol] = field(default_factory=list)
|
|
652
|
+
group_by_cols: list[NormalizedExpr] = field(default_factory=list)
|
|
653
|
+
order_by_cols: list[OrderByCol] = field(default_factory=list)
|
|
654
|
+
filters_param: list[FilterParam] = field(default_factory=list)
|
|
655
|
+
having_param: list[HavingParam] = field(default_factory=list)
|
|
656
|
+
output_columns: list[str] = field(default_factory=list)
|
|
657
|
+
grain: str = "row_level"
|
|
658
|
+
limit: int | None = None
|
|
659
|
+
column_map: dict[str, str] = field(default_factory=dict)
|
|
660
|
+
output_column_metadata: dict[str, CteOutputColumnMeta] = field(default_factory=dict)
|
|
661
|
+
chosen_join_candidate_id: str = ""
|
|
662
|
+
chosen_join_path_signature: list[str] = field(default_factory=list)
|
|
663
|
+
|
|
664
|
+
@staticmethod
|
|
665
|
+
def from_dict(d: dict[str, Any]) -> ConcreteCteStep:
|
|
666
|
+
"""Create ConcreteCteStep from dictionary.
|
|
667
|
+
|
|
668
|
+
Args:
|
|
669
|
+
|
|
670
|
+
d: Dictionary with keys matching ConcreteCteStep fields.
|
|
671
|
+
|
|
672
|
+
Returns:
|
|
673
|
+
|
|
674
|
+
Populated ConcreteCteStep with nested expression objects.
|
|
675
|
+
"""
|
|
676
|
+
sc_raw = d.get("select_cols", [])
|
|
677
|
+
gbc_raw = d.get("group_by_cols", [])
|
|
678
|
+
obc_raw = d.get("order_by_cols", [])
|
|
679
|
+
fp_raw = d.get("filters_param", [])
|
|
680
|
+
hp_raw = d.get("having_param", [])
|
|
681
|
+
ocm_raw = d.get("output_column_metadata", {})
|
|
682
|
+
select_cols = [
|
|
683
|
+
(
|
|
684
|
+
SelectCol.from_dict(s)
|
|
685
|
+
if isinstance(s, dict)
|
|
686
|
+
else (SelectCol(expr=NormalizedExpr.from_column(s)) if isinstance(s, str) else s)
|
|
687
|
+
)
|
|
688
|
+
for s in sc_raw
|
|
689
|
+
]
|
|
690
|
+
group_by_cols = [
|
|
691
|
+
(
|
|
692
|
+
NormalizedExpr.from_dict(g)
|
|
693
|
+
if isinstance(g, dict)
|
|
694
|
+
else (NormalizedExpr.from_column(g) if isinstance(g, str) else g)
|
|
695
|
+
)
|
|
696
|
+
for g in gbc_raw
|
|
697
|
+
]
|
|
698
|
+
order_by_cols = [
|
|
699
|
+
(
|
|
700
|
+
OrderByCol.from_dict(o)
|
|
701
|
+
if isinstance(o, dict)
|
|
702
|
+
else (OrderByCol(expr=NormalizedExpr.from_column(o)) if isinstance(o, str) else o)
|
|
703
|
+
)
|
|
704
|
+
for o in obc_raw
|
|
705
|
+
]
|
|
706
|
+
return ConcreteCteStep(
|
|
707
|
+
cte_name=d.get("cte_name", ""),
|
|
708
|
+
tables=d.get("tables", []),
|
|
709
|
+
select_cols=select_cols,
|
|
710
|
+
group_by_cols=group_by_cols,
|
|
711
|
+
order_by_cols=order_by_cols,
|
|
712
|
+
filters_param=[FilterParam.from_dict(f) if isinstance(f, dict) else f for f in fp_raw],
|
|
713
|
+
having_param=[HavingParam.from_dict(h) if isinstance(h, dict) else h for h in hp_raw],
|
|
714
|
+
output_columns=d.get("output_columns", []),
|
|
715
|
+
grain=d.get("grain", "row_level"),
|
|
716
|
+
limit=d.get("limit"),
|
|
717
|
+
column_map=d.get("column_map", {}),
|
|
718
|
+
output_column_metadata={
|
|
719
|
+
k: CteOutputColumnMeta.from_dict(v) if isinstance(v, dict) else v for k, v in ocm_raw.items()
|
|
720
|
+
},
|
|
721
|
+
chosen_join_candidate_id=d.get("chosen_join_candidate_id", ""),
|
|
722
|
+
chosen_join_path_signature=d.get("chosen_join_path_signature", []),
|
|
723
|
+
)
|
|
724
|
+
|
|
725
|
+
def to_dict(self) -> dict[str, Any]:
    """Return a plain-dictionary view of this concrete CTE step.

    Select/group-by/order-by entries, filter and having parameters, and each
    output-column metadata value are converted through their own ``to_dict``.
    All other fields are emitted exactly as stored, with no copying.

    Returns:
        Mapping of every ConcreteCteStep field name to its serialized value,
        in field-declaration order.
    """

    def dump_each(items: list[Any]) -> list[Any]:
        # Every nested contract object serializes itself.
        return [item.to_dict() for item in items]

    return dict(
        cte_name=self.cte_name,
        tables=self.tables,
        select_cols=dump_each(self.select_cols),
        group_by_cols=dump_each(self.group_by_cols),
        order_by_cols=dump_each(self.order_by_cols),
        filters_param=dump_each(self.filters_param),
        having_param=dump_each(self.having_param),
        output_columns=self.output_columns,
        grain=self.grain,
        limit=self.limit,
        column_map=self.column_map,
        output_column_metadata={name: meta.to_dict() for name, meta in self.output_column_metadata.items()},
        chosen_join_candidate_id=self.chosen_join_candidate_id,
        chosen_join_path_signature=self.chosen_join_path_signature,
    )
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
@dataclass
class RuntimeCteStep:
    """CTE step specification for WITH clause queries with runtime values.

    Carries the same structural fields as ConcreteCteStep plus run-specific
    data: a free-text ``description``, ``param_values`` and ``limit_param_key``.
    """

    cte_name: str
    description: str = ""
    tables: list[str] = field(default_factory=list)
    select_cols: list[SelectCol] = field(default_factory=list)
    group_by_cols: list[NormalizedExpr] = field(default_factory=list)
    order_by_cols: list[OrderByCol] = field(default_factory=list)
    filters_param: list[FilterParam] = field(default_factory=list)
    having_param: list[HavingParam] = field(default_factory=list)
    # Runtime-only: dropped when converting to ConcreteCteStep.
    param_values: dict[str, ParamValue] = field(default_factory=dict)
    output_columns: list[str] = field(default_factory=list)
    # "scalar" makes expected_rows report "one".
    grain: str = "row_level"
    limit: int | None = None
    # Runtime-only: dropped when converting to ConcreteCteStep.
    limit_param_key: str = ""
    column_map: dict[str, str] = field(default_factory=dict)
    output_column_metadata: dict[str, CteOutputColumnMeta] = field(default_factory=dict)
    chosen_join_candidate_id: str = ""
    chosen_join_path_signature: list[str] = field(default_factory=list)

    @staticmethod
    def from_dict(d: dict[str, Any]) -> RuntimeCteStep:
        """Create RuntimeCteStep from dictionary.

        Dict entries are deserialized with the nested ``from_dict`` helpers.
        Bare strings in the select/group-by/order-by lists become single-column
        expressions. Already-built objects are passed through unchanged. A
        ``chosen_join_path_signature`` stored as a single string is wrapped in
        a list (an empty string becomes an empty list), matching
        RuntimeIntent.from_dict.

        Args:
            d: Dictionary with keys matching RuntimeCteStep fields.

        Returns:
            Populated RuntimeCteStep with nested expression objects and runtime param_values.
        """
        sc_raw = d.get("select_cols", [])
        gbc_raw = d.get("group_by_cols", [])
        obc_raw = d.get("order_by_cols", [])
        fp_raw = d.get("filters_param", [])
        hp_raw = d.get("having_param", [])
        ocm_raw = d.get("output_column_metadata", {})
        join_sig_raw = d.get("chosen_join_path_signature", [])
        if isinstance(join_sig_raw, str):
            # A bare string is not a list[str]; normalize it so the field type holds.
            join_sig_raw = [join_sig_raw] if join_sig_raw else []
        select_cols = [
            (
                SelectCol.from_dict(s)
                if isinstance(s, dict)
                else (SelectCol(expr=NormalizedExpr.from_column(s)) if isinstance(s, str) else s)
            )
            for s in sc_raw
        ]
        group_by_cols = [
            (
                NormalizedExpr.from_dict(g)
                if isinstance(g, dict)
                else (NormalizedExpr.from_column(g) if isinstance(g, str) else g)
            )
            for g in gbc_raw
        ]
        order_by_cols = [
            (
                OrderByCol.from_dict(o)
                if isinstance(o, dict)
                else (OrderByCol(expr=NormalizedExpr.from_column(o)) if isinstance(o, str) else o)
            )
            for o in obc_raw
        ]
        return RuntimeCteStep(
            cte_name=d.get("cte_name", ""),
            description=d.get("description", ""),
            tables=d.get("tables", []),
            select_cols=select_cols,
            group_by_cols=group_by_cols,
            order_by_cols=order_by_cols,
            filters_param=[FilterParam.from_dict(f) if isinstance(f, dict) else f for f in fp_raw],
            having_param=[HavingParam.from_dict(h) if isinstance(h, dict) else h for h in hp_raw],
            param_values=d.get("param_values", {}),
            output_columns=d.get("output_columns", []),
            grain=d.get("grain", "row_level"),
            limit=d.get("limit"),
            limit_param_key=d.get("limit_param_key", ""),
            column_map=d.get("column_map", {}),
            output_column_metadata={
                k: CteOutputColumnMeta.from_dict(v) if isinstance(v, dict) else v for k, v in ocm_raw.items()
            },
            chosen_join_candidate_id=d.get("chosen_join_candidate_id", ""),
            chosen_join_path_signature=join_sig_raw,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary.

        Returns:
            Dictionary with all RuntimeCteStep fields including param_values and limit_param_key.
        """
        return {
            "cte_name": self.cte_name,
            "description": self.description,
            "tables": self.tables,
            "select_cols": [s.to_dict() for s in self.select_cols],
            "group_by_cols": [g.to_dict() for g in self.group_by_cols],
            "order_by_cols": [o.to_dict() for o in self.order_by_cols],
            "filters_param": [f.to_dict() for f in self.filters_param],
            "having_param": [h.to_dict() for h in self.having_param],
            "param_values": self.param_values,
            "output_columns": self.output_columns,
            "grain": self.grain,
            "limit": self.limit,
            "limit_param_key": self.limit_param_key,
            "column_map": self.column_map,
            "output_column_metadata": {k: v.to_dict() for k, v in self.output_column_metadata.items()},
            "chosen_join_candidate_id": self.chosen_join_candidate_id,
            "chosen_join_path_signature": self.chosen_join_path_signature,
        }

    @property
    def expected_rows(self) -> str:
        """Derive expected_rows from grain and limit.

        Returns "one" for scalar grain, "few" when limit is truthy, otherwise "many".
        """
        if self.grain == "scalar":
            return "one"
        return "few" if self.limit else "many"
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
@dataclass
class CteValidationResult:
    """Outcome of validating a chain of CTE steps.

    Bundles the overall verdict, the per-CTE output columns, and any issues
    recorded during validation.
    """

    # Overall verdict for the CTE chain.
    is_valid: bool
    # Presumably maps each CTE name to its output column names; confirm with the validator.
    cte_outputs: dict[str, list[str]]
    # Issues found; defaults to a fresh empty list per instance.
    issues: list[IntentIssue] = field(default_factory=list)
|
|
878
|
+
|
|
879
|
+
|
|
880
|
+
def _runtime_cte_to_concrete(runtime: RuntimeCteStep) -> ConcreteCteStep:
    """Project a RuntimeCteStep onto the ConcreteCteStep shape used for templates.

    Structural fields are carried over by reference (no copies are made).
    ``description``, ``param_values`` and ``limit_param_key`` are not transferred.

    Args:
        runtime: CTE step produced during a pipeline run.

    Returns:
        ConcreteCteStep sharing the structural fields of ``runtime``.
    """
    # Field names common to both CTE step types.
    shared_fields = (
        "cte_name",
        "tables",
        "select_cols",
        "group_by_cols",
        "order_by_cols",
        "filters_param",
        "having_param",
        "output_columns",
        "grain",
        "limit",
        "column_map",
        "output_column_metadata",
        "chosen_join_candidate_id",
        "chosen_join_path_signature",
    )
    return ConcreteCteStep(**{name: getattr(runtime, name) for name in shared_fields})
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
def concrete_cte_to_runtime(concrete: ConcreteCteStep) -> RuntimeCteStep:
    """Lift a stored ConcreteCteStep into a RuntimeCteStep for execution.

    Structural fields are shared by reference. ``description`` is blank and
    ``param_values`` is a new empty dict. ``limit_param_key`` keeps the
    RuntimeCteStep default.

    Args:
        concrete: CTE step taken from a stored template.

    Returns:
        RuntimeCteStep ready to receive runtime values.
    """
    carried_fields = (
        "cte_name",
        "tables",
        "select_cols",
        "group_by_cols",
        "order_by_cols",
        "filters_param",
        "having_param",
        "output_columns",
        "grain",
        "limit",
        "column_map",
        "output_column_metadata",
        "chosen_join_candidate_id",
        "chosen_join_path_signature",
    )
    carried = {name: getattr(concrete, name) for name in carried_fields}
    return RuntimeCteStep(description="", param_values={}, **carried)
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
@dataclass
class RuntimeIntent:
    """Intent consumed by the pipeline at execution time.

    Combines the structural query description (tables, columns, filters,
    CTE steps, join choice) with per-run data such as parameter values, the
    natural-language text and the generated SQL strings.
    """

    tables: list[str]
    grain: str
    select_cols: list[SelectCol]
    group_by_cols: list[NormalizedExpr]
    order_by_cols: list[OrderByCol]
    filters_param: list[FilterParam]
    having_param: list[HavingParam] = field(default_factory=list)
    param_values: dict[str, ParamValue] = field(default_factory=dict)
    cte_steps: list[RuntimeCteStep] = field(default_factory=list)
    natural_language: str = ""
    limit: int | None = None
    limit_param_key: str = ""
    column_map: dict[str, str] = field(default_factory=dict)
    chosen_join_candidate_id: str = ""
    chosen_join_path_signature: list[str] = field(default_factory=list)
    # Stored as a set; to_dict emits it as a sorted list.
    extra_tables: set[str] = field(default_factory=set)
    sql_param: str = ""
    sql_display_param: str = ""
    sql_substituted: str = ""
    deterministic_sql: str = ""
    sql_shape: SQLShape | None = None
    schema_invalid: bool = False

    @property
    def expected_rows(self) -> str:
        """Classify result size: "one" for scalar grain, else "few" with a truthy limit, else "many"."""
        if self.grain != "scalar":
            return "few" if self.limit else "many"
        return "one"

    @staticmethod
    def from_dict(d: dict[str, Any]) -> RuntimeIntent:
        """Rebuild a RuntimeIntent from its dictionary form (e.g. loaded JSON).

        Dict list entries go through the matching ``from_dict``. Bare strings in
        the select/group-by/order-by lists become single-column expressions.
        Anything else is kept as-is. A string ``chosen_join_path_signature`` is
        wrapped in a list (an empty string becomes an empty list).

        Args:
            d: Mapping whose keys mirror the RuntimeIntent field names.

        Returns:
            RuntimeIntent with nested objects deserialized. Missing keys fall back
            to the defaults used below.
        """

        def select_item(raw: Any) -> Any:
            if isinstance(raw, dict):
                return SelectCol.from_dict(raw)
            if isinstance(raw, str):
                return SelectCol(expr=NormalizedExpr.from_column(raw))
            return raw

        def group_item(raw: Any) -> Any:
            if isinstance(raw, dict):
                return NormalizedExpr.from_dict(raw)
            if isinstance(raw, str):
                return NormalizedExpr.from_column(raw)
            return raw

        def order_item(raw: Any) -> Any:
            if isinstance(raw, dict):
                return OrderByCol.from_dict(raw)
            if isinstance(raw, str):
                return OrderByCol(expr=NormalizedExpr.from_column(raw))
            return raw

        signature = d.get("chosen_join_path_signature", [])
        if isinstance(signature, str):
            signature = [signature] if signature else []

        shape_raw = d.get("sql_shape")
        return RuntimeIntent(
            tables=d.get("tables", []),
            grain=d.get("grain", "row_level"),
            select_cols=list(map(select_item, d.get("select_cols", []))),
            group_by_cols=list(map(group_item, d.get("group_by_cols", []))),
            order_by_cols=list(map(order_item, d.get("order_by_cols", []))),
            filters_param=[
                FilterParam.from_dict(item) if isinstance(item, dict) else item
                for item in d.get("filters_param", [])
            ],
            having_param=[
                HavingParam.from_dict(item) if isinstance(item, dict) else item
                for item in d.get("having_param", [])
            ],
            param_values=d.get("param_values", {}),
            cte_steps=[
                RuntimeCteStep.from_dict(step) if isinstance(step, dict) else step
                for step in d.get("cte_steps", [])
            ],
            natural_language=d.get("natural_language", ""),
            limit=d.get("limit"),
            limit_param_key=d.get("limit_param_key", ""),
            column_map=d.get("column_map", {}),
            chosen_join_candidate_id=d.get("chosen_join_candidate_id", ""),
            chosen_join_path_signature=signature,
            extra_tables=set(d.get("extra_tables", [])),
            sql_param=d.get("sql_param", ""),
            sql_display_param=d.get("sql_display_param", ""),
            sql_substituted=d.get("sql_substituted", ""),
            deterministic_sql=d.get("deterministic_sql", ""),
            sql_shape=SQLShape.from_dict(shape_raw) if shape_raw else None,
            schema_invalid=d.get("schema_invalid", False),
        )

    def to_dict(self) -> dict[str, Any]:
        """Produce a plain dictionary for storage or transport.

        The output includes the derived ``expected_rows``. Nested objects are
        serialized via their ``to_dict``, and ``extra_tables`` is emitted as a
        sorted list.

        Returns:
            Dictionary keyed by RuntimeIntent field names (plus ``expected_rows``).
        """

        def dump(items: list[Any]) -> list[Any]:
            return [item.to_dict() for item in items]

        return dict(
            tables=self.tables,
            grain=self.grain,
            expected_rows=self.expected_rows,
            select_cols=dump(self.select_cols),
            group_by_cols=dump(self.group_by_cols),
            order_by_cols=dump(self.order_by_cols),
            filters_param=dump(self.filters_param),
            having_param=dump(self.having_param),
            param_values=self.param_values,
            cte_steps=dump(self.cte_steps),
            natural_language=self.natural_language,
            limit=self.limit,
            limit_param_key=self.limit_param_key,
            column_map=self.column_map,
            chosen_join_candidate_id=self.chosen_join_candidate_id,
            chosen_join_path_signature=self.chosen_join_path_signature,
            extra_tables=sorted(self.extra_tables),
            sql_param=self.sql_param,
            sql_display_param=self.sql_display_param,
            sql_substituted=self.sql_substituted,
            deterministic_sql=self.deterministic_sql,
            sql_shape=None if not self.sql_shape else self.sql_shape.to_dict(),
            schema_invalid=self.schema_invalid,
        )

    @property
    def has_aggregation(self) -> bool:
        """True when at least one select column reports ``is_aggregated``."""
        for column in self.select_cols:
            if column.is_aggregated:
                return True
        return False
|
|
1081
|
+
|
|
1082
|
+
|
|
1083
|
+
@dataclass
class ConcreteIntent:
    """Structural intent for template storage without values or natural language.

    Holds only the query structure: tables, columns, filters, CTE steps and
    join choice.
    """

    intent_id: str
    tables: list[str]
    grain: str
    select_cols: list[SelectCol]
    group_by_cols: list[NormalizedExpr]
    order_by_cols: list[OrderByCol]
    filters_param: list[FilterParam]
    having_param: list[HavingParam] = field(default_factory=list)
    cte_steps: list[ConcreteCteStep] = field(default_factory=list)
    limit: int | None = None
    column_map: dict[str, str] = field(default_factory=dict)
    chosen_join_candidate_id: str = ""
    chosen_join_path_signature: list[str] = field(default_factory=list)

    @staticmethod
    def from_dict(d: dict[str, Any]) -> ConcreteIntent:
        """Create ConcreteIntent from dictionary.

        Dict entries are deserialized with the nested ``from_dict`` helpers.
        Bare strings in the select/group-by/order-by lists become single-column
        expressions. Other entries are passed through unchanged. A
        ``chosen_join_path_signature`` stored as a single string is wrapped in
        a list (an empty string becomes an empty list), matching
        RuntimeIntent.from_dict.

        Args:
            d: Dictionary with keys matching ConcreteIntent fields.

        Returns:
            Populated ConcreteIntent with nested expression objects.
        """
        sc_raw = d.get("select_cols", [])
        gbc_raw = d.get("group_by_cols", [])
        obc_raw = d.get("order_by_cols", [])
        fp_raw = d.get("filters_param", [])
        hp_raw = d.get("having_param", [])
        cte_raw = d.get("cte_steps", [])
        join_sig_raw = d.get("chosen_join_path_signature", [])
        if isinstance(join_sig_raw, str):
            # A bare string is not a list[str]; normalize it so the field type holds.
            join_sig_raw = [join_sig_raw] if join_sig_raw else []
        select_cols = [
            (
                SelectCol.from_dict(s)
                if isinstance(s, dict)
                else (SelectCol(expr=NormalizedExpr.from_column(s)) if isinstance(s, str) else s)
            )
            for s in sc_raw
        ]
        group_by_cols = [
            (
                NormalizedExpr.from_dict(g)
                if isinstance(g, dict)
                else (NormalizedExpr.from_column(g) if isinstance(g, str) else g)
            )
            for g in gbc_raw
        ]
        order_by_cols = [
            (
                OrderByCol.from_dict(o)
                if isinstance(o, dict)
                else (OrderByCol(expr=NormalizedExpr.from_column(o)) if isinstance(o, str) else o)
            )
            for o in obc_raw
        ]
        return ConcreteIntent(
            intent_id=d.get("intent_id", ""),
            tables=d.get("tables", []),
            grain=d.get("grain", "row_level"),
            select_cols=select_cols,
            group_by_cols=group_by_cols,
            order_by_cols=order_by_cols,
            filters_param=[FilterParam.from_dict(fp) if isinstance(fp, dict) else fp for fp in fp_raw],
            having_param=[HavingParam.from_dict(hp) if isinstance(hp, dict) else hp for hp in hp_raw],
            cte_steps=[ConcreteCteStep.from_dict(cte) if isinstance(cte, dict) else cte for cte in cte_raw],
            limit=d.get("limit"),
            column_map=d.get("column_map", {}),
            chosen_join_candidate_id=d.get("chosen_join_candidate_id", ""),
            chosen_join_path_signature=join_sig_raw,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary.

        Returns:
            Dictionary with all ConcreteIntent fields, with nested expressions serialized.
        """
        return {
            "intent_id": self.intent_id,
            "tables": self.tables,
            "grain": self.grain,
            "select_cols": [s.to_dict() for s in self.select_cols],
            "group_by_cols": [g.to_dict() for g in self.group_by_cols],
            "order_by_cols": [o.to_dict() for o in self.order_by_cols],
            "filters_param": [fp.to_dict() for fp in self.filters_param],
            "having_param": [hp.to_dict() for hp in self.having_param],
            "cte_steps": [cte.to_dict() for cte in self.cte_steps],
            "limit": self.limit,
            "column_map": self.column_map,
            "chosen_join_candidate_id": self.chosen_join_candidate_id,
            "chosen_join_path_signature": self.chosen_join_path_signature,
        }
|
|
1181
|
+
|
|
1182
|
+
|
|
1183
|
+
@dataclass
class SimulatorIntent:
    """Intent used by the Simulator, carrying values inline plus expansion metadata.

    ``to_runtime_intent`` converts it into a RuntimeIntent for pipeline execution.
    """

    intent_id: str
    tables: list[str]
    grain: str
    select_cols: list[SelectCol]
    group_by_cols: list[NormalizedExpr]
    order_by_cols: list[OrderByCol]
    filters_param: list[FilterParam]
    having_param: list[HavingParam]
    param_values: dict[str, ParamValue] = field(default_factory=dict)
    cte_steps: list[RuntimeCteStep] = field(default_factory=list)
    question: str = ""
    # Omitted from to_dict output when falsy.
    expansion_metadata: ExpansionMetadata | None = None
    limit: int | None = None

    @staticmethod
    def from_dict(d: dict[str, Any]) -> SimulatorIntent:
        """Rebuild a SimulatorIntent from a dictionary.

        ``filters`` and ``having`` are used as fallbacks when ``filters_param`` and
        ``having_param`` are absent. Dict entries are deserialized, bare
        strings in the column lists become single-column expressions, and other
        entries are kept as-is.

        Args:
            d: Mapping whose keys mirror the SimulatorIntent field names.

        Returns:
            SimulatorIntent with nested expressions and metadata deserialized.
        """

        def select_item(raw: Any) -> Any:
            if isinstance(raw, dict):
                return SelectCol.from_dict(raw)
            if isinstance(raw, str):
                return SelectCol(expr=NormalizedExpr.from_column(raw))
            return raw

        def group_item(raw: Any) -> Any:
            if isinstance(raw, dict):
                return NormalizedExpr.from_dict(raw)
            if isinstance(raw, str):
                return NormalizedExpr.from_column(raw)
            return raw

        def order_item(raw: Any) -> Any:
            if isinstance(raw, dict):
                return OrderByCol.from_dict(raw)
            if isinstance(raw, str):
                return OrderByCol(expr=NormalizedExpr.from_column(raw))
            return raw

        filters_raw = d.get("filters_param", d.get("filters", []))
        having_raw = d.get("having_param", d.get("having", []))
        metadata_raw = d.get("expansion_metadata")
        return SimulatorIntent(
            intent_id=d.get("intent_id", ""),
            tables=d.get("tables", []),
            grain=d.get("grain", "row_level"),
            select_cols=list(map(select_item, d.get("select_cols", []))),
            group_by_cols=list(map(group_item, d.get("group_by_cols", []))),
            order_by_cols=list(map(order_item, d.get("order_by_cols", []))),
            filters_param=[
                FilterParam.from_dict(item) if isinstance(item, dict) else item for item in filters_raw
            ],
            having_param=[
                HavingParam.from_dict(item) if isinstance(item, dict) else item for item in having_raw
            ],
            param_values=d.get("param_values", {}),
            cte_steps=[
                RuntimeCteStep.from_dict(step) if isinstance(step, dict) else step
                for step in d.get("cte_steps", [])
            ],
            question=d.get("question", ""),
            expansion_metadata=None if not metadata_raw else ExpansionMetadata.from_dict(metadata_raw),
            limit=d.get("limit"),
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary; ``expansion_metadata`` is appended only when truthy."""

        def dump(items: list[Any]) -> list[Any]:
            return [item.to_dict() for item in items]

        serialized: dict[str, Any] = dict(
            intent_id=self.intent_id,
            tables=self.tables,
            grain=self.grain,
            select_cols=dump(self.select_cols),
            group_by_cols=dump(self.group_by_cols),
            order_by_cols=dump(self.order_by_cols),
            filters_param=dump(self.filters_param),
            having_param=dump(self.having_param),
            param_values=self.param_values,
            cte_steps=dump(self.cte_steps),
            question=self.question,
            limit=self.limit,
        )
        metadata = self.expansion_metadata
        if metadata:
            serialized["expansion_metadata"] = metadata.to_dict()
        return serialized

    def to_runtime_intent(self) -> RuntimeIntent:
        """Build a RuntimeIntent for pipeline execution.

        ``column_map`` is inferred from table-qualified ("table.column") primary
        columns of the select expressions and then the filter left-hand
        expressions, keeping only tables listed in ``self.tables``. Because filter
        entries are processed last, they win on a shared bare column name.
        ``natural_language`` is left empty.

        Returns:
            RuntimeIntent built from this SimulatorIntent with an inferred column_map.
        """
        column_map: dict[str, str] = {}
        candidate_exprs = [sc.expr for sc in self.select_cols] + [fp.left_expr for fp in self.filters_param]
        for expr in candidate_exprs:
            qualified = expr.primary_column
            if "." not in qualified:
                continue
            owner, bare_name = qualified.split(".", 1)
            if owner in self.tables:
                column_map[bare_name] = owner

        return RuntimeIntent(
            tables=self.tables,
            grain=self.grain,
            select_cols=self.select_cols,
            group_by_cols=self.group_by_cols,
            order_by_cols=self.order_by_cols,
            filters_param=self.filters_param,
            having_param=self.having_param,
            param_values=self.param_values,
            cte_steps=self.cte_steps,
            natural_language="",
            limit=self.limit,
            column_map=column_map,
        )
|
|
1315
|
+
|
|
1316
|
+
|
|
1317
|
+
@dataclass
class QSimFilter:
    """Filter description used by QSim: a column reference, an operator and a value type.

    A non-empty ``right_column`` marks a comparison against another expression
    rather than a value.
    """

    column: str
    op: str
    value_type: str
    right_column: str = ""

    @property
    def is_expr_comparison(self) -> bool:
        """Whether ``right_column`` is populated (expression-vs-expression filter)."""
        return True if self.right_column else False

    @staticmethod
    def from_dict(d: dict[str, Any]) -> QSimFilter:
        """Build a QSimFilter from a mapping.

        Missing keys fall back to: column "", op "=", value_type "categorical",
        right_column "".

        Args:
            d: Mapping with optional 'column', 'op', 'value_type' and 'right_column' keys.

        Returns:
            The corresponding QSimFilter.
        """
        fallbacks = (
            ("column", ""),
            ("op", "="),
            ("value_type", "categorical"),
            ("right_column", ""),
        )
        return QSimFilter(**{key: d.get(key, default) for key, default in fallbacks})

    def to_dict(self) -> dict[str, Any]:
        """Return the filter as a dictionary, adding 'right_column' only when it is truthy."""
        base = {"column": self.column, "op": self.op, "value_type": self.value_type}
        optional = {"right_column": self.right_column} if self.right_column else {}
        return {**base, **optional}
|
|
1361
|
+
|
|
1362
|
+
|
|
1363
|
+
@dataclass
class QSimHaving:
    """HAVING condition used by QSim: an aggregate expression, an operator and a value type.

    A non-empty ``right_expression`` marks an expression-vs-expression comparison.
    """

    expression: str
    op: str
    value_type: str
    right_expression: str = ""

    @property
    def is_expression_comparison(self) -> bool:
        """Whether ``right_expression`` is populated."""
        return True if self.right_expression else False

    @staticmethod
    def from_dict(d: dict[str, Any]) -> QSimHaving:
        """Build a QSimHaving from a mapping.

        Missing keys fall back to: expression "", op ">", value_type "number",
        right_expression "".

        Args:
            d: Mapping with optional 'expression', 'op', 'value_type' and 'right_expression' keys.

        Returns:
            The corresponding QSimHaving.
        """
        fallbacks = (
            ("expression", ""),
            ("op", ">"),
            ("value_type", "number"),
            ("right_expression", ""),
        )
        return QSimHaving(**{key: d.get(key, default) for key, default in fallbacks})

    def to_dict(self) -> dict[str, Any]:
        """Return the condition as a dictionary, adding 'right_expression' only when it is truthy."""
        base = {
            "expression": self.expression,
            "op": self.op,
            "value_type": self.value_type,
        }
        optional = {"right_expression": self.right_expression} if self.right_expression else {}
        return {**base, **optional}
|
|
1411
|
+
|
|
1412
|
+
|
|
1413
|
+
@dataclass
class QSimIntent:
    """Intent used for QSim question generation; values and question text are optional.

    Column lists hold plain strings. Filters and having conditions use the
    lightweight QSimFilter and QSimHaving records.
    """

    intent_id: str
    tables: list[str]
    grain: str
    select_cols: list[str]
    group_by_cols: list[str]
    order_by_cols: list[str]
    filters_param: list[QSimFilter]
    having_param: list[QSimHaving]
    param_values: dict[str, ParamValue] = field(default_factory=dict)
    question: str = ""
    variant_idx: int = 0
    limit: int | None = None
    distinct: bool = False

    @staticmethod
    def from_dict(d: dict[str, Any]) -> QSimIntent:
        """Rebuild a QSimIntent from a mapping.

        ``filters`` and ``having`` are used as fallbacks when ``filters_param`` and
        ``having_param`` are absent. Dict entries are deserialized and other
        entries are kept as-is.

        Args:
            d: Mapping whose keys mirror the QSimIntent field names.

        Returns:
            The corresponding QSimIntent.
        """

        def as_filter(entry: Any) -> Any:
            return QSimFilter.from_dict(entry) if isinstance(entry, dict) else entry

        def as_having(entry: Any) -> Any:
            return QSimHaving.from_dict(entry) if isinstance(entry, dict) else entry

        filters_raw = d.get("filters_param", d.get("filters", []))
        having_raw = d.get("having_param", d.get("having", []))
        return QSimIntent(
            intent_id=d.get("intent_id", ""),
            tables=d.get("tables", []),
            grain=d.get("grain", "row_level"),
            select_cols=d.get("select_cols", []),
            group_by_cols=d.get("group_by_cols", []),
            order_by_cols=d.get("order_by_cols", []),
            filters_param=list(map(as_filter, filters_raw)),
            having_param=list(map(as_having, having_raw)),
            param_values=d.get("param_values", {}),
            question=d.get("question", ""),
            variant_idx=d.get("variant_idx", 0),
            limit=d.get("limit"),
            distinct=d.get("distinct", False),
        )

    def to_dict(self) -> dict[str, Any]:
        """Return every QSimIntent field as a dictionary entry, with filters and having serialized."""
        return dict(
            intent_id=self.intent_id,
            tables=self.tables,
            grain=self.grain,
            select_cols=self.select_cols,
            group_by_cols=self.group_by_cols,
            order_by_cols=self.order_by_cols,
            filters_param=[entry.to_dict() for entry in self.filters_param],
            having_param=[entry.to_dict() for entry in self.having_param],
            param_values=self.param_values,
            question=self.question,
            variant_idx=self.variant_idx,
            limit=self.limit,
            distinct=self.distinct,
        )
|
|
1483
|
+
|
|
1484
|
+
|
|
1485
|
+
def runtime_intent_to_concrete(runtime: RuntimeIntent, intent_id: str) -> ConcreteIntent:
    """Build the ConcreteIntent stored in a template from a RuntimeIntent.

    Only the structural fields listed below are carried over. Per this
    function's original contract, the runtime-only ``param_values`` and
    ``natural_language`` data are not copied. CTE steps are converted one by
    one with ``_runtime_cte_to_concrete``. The remaining fields are passed
    through by reference, not copied.

    Args:
        runtime: The RuntimeIntent to convert.
        intent_id: Identifier assigned to the resulting ConcreteIntent.

    Returns:
        A new ConcreteIntent describing the same query structure.
    """
    concrete_ctes = [_runtime_cte_to_concrete(step) for step in runtime.cte_steps]
    return ConcreteIntent(
        # Identity.
        intent_id=intent_id,
        # Query shape.
        tables=runtime.tables,
        grain=runtime.grain,
        select_cols=runtime.select_cols,
        group_by_cols=runtime.group_by_cols,
        order_by_cols=runtime.order_by_cols,
        limit=runtime.limit,
        cte_steps=concrete_ctes,
        # Parameterized predicates.
        filters_param=runtime.filters_param,
        having_param=runtime.having_param,
        # Column mapping and join selection.
        column_map=runtime.column_map,
        chosen_join_candidate_id=runtime.chosen_join_candidate_id,
        chosen_join_path_signature=runtime.chosen_join_path_signature,
    )
|
|
1512
|
+
|
|
1513
|
+
|
|
1514
|
+
@dataclass
class ValueHistory:
    """Parallel history lists recorded for an accepted Template.

    ``add`` appends one entry to each of ``param_values``, ``questions`` and
    ``natural_language``, so index ``i`` across the three lists describes one
    recorded execution.
    """

    param_values: list[dict[str, ParamValue]]
    questions: list[str]
    natural_language: list[str]

    @staticmethod
    def from_dict(d: dict[str, Any]) -> ValueHistory:
        """Create ValueHistory from dictionary.

        The lists are shallow-copied so that later ``add`` calls never mutate
        the caller's dictionary. Missing or ``None`` entries (e.g. JSON
        ``null``) become empty lists, so ``add`` can always append.

        Args:
            d: Dictionary with 'param_values', 'questions', and 'natural_language' list keys.

        Returns:
            Populated ValueHistory instance.
        """
        return ValueHistory(
            param_values=list(d.get("param_values") or []),
            questions=list(d.get("questions") or []),
            natural_language=list(d.get("natural_language") or []),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary.

        The lists are shallow-copied so that mutating the returned dictionary
        cannot alter this history.

        Returns:
            Dictionary with the three parallel history lists.
        """
        return {
            "param_values": list(self.param_values),
            "questions": list(self.questions),
            "natural_language": list(self.natural_language),
        }

    def add(self, param_values: dict[str, ParamValue], question: str, natural_language: str) -> None:
        """Append a new entry to all parallel history lists.

        Args:
            param_values: Flat dict of parameter keys to resolved values.
            question: The natural language question associated with this execution.
            natural_language: The intent's natural language description; a
                falsy value is stored as an empty string.
        """
        self.param_values.append(param_values)
        self.questions.append(question)
        self.natural_language.append(natural_language or "")

    def __len__(self) -> int:
        """Return number of entries in history (counted by ``questions``)."""
        return len(self.questions)
|
|
1569
|
+
|
|
1570
|
+
|
|
1571
|
+
@dataclass
class RejectedValueHistory:
    """Parallel history lists recorded for a RejectedTemplate.

    ``add`` appends one entry to each of the five lists, so index ``i`` across
    them describes one recorded rejection.
    """

    param_values: list[dict[str, ParamValue]]
    questions: list[str]
    natural_language: list[str]
    rejection_reasons: list[str]
    rejection_categories: list[str]

    @staticmethod
    def from_dict(d: dict[str, Any]) -> RejectedValueHistory:
        """Create RejectedValueHistory from dictionary.

        The lists are shallow-copied so that later ``add`` calls never mutate
        the caller's dictionary. Missing or ``None`` entries (e.g. JSON
        ``null``) become empty lists, so ``add`` can always append.

        Args:
            d: Dictionary with all RejectedValueHistory list keys.

        Returns:
            Populated RejectedValueHistory instance.
        """
        return RejectedValueHistory(
            param_values=list(d.get("param_values") or []),
            questions=list(d.get("questions") or []),
            natural_language=list(d.get("natural_language") or []),
            rejection_reasons=list(d.get("rejection_reasons") or []),
            rejection_categories=list(d.get("rejection_categories") or []),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary.

        The lists are shallow-copied so that mutating the returned dictionary
        cannot alter this history.

        Returns:
            Dictionary with all five parallel history lists.
        """
        return {
            "param_values": list(self.param_values),
            "questions": list(self.questions),
            "natural_language": list(self.natural_language),
            "rejection_reasons": list(self.rejection_reasons),
            "rejection_categories": list(self.rejection_categories),
        }

    def add(
        self,
        param_values: dict[str, ParamValue],
        question: str,
        natural_language: str,
        reason: str,
        category: str,
    ) -> None:
        """Append a new entry to all parallel history lists.

        Args:
            param_values: Flat dict of parameter keys to resolved values.
            question: The natural language question associated with this rejection.
            natural_language: The intent's natural language description; a
                falsy value is stored as an empty string.
            reason: Full rejection reason message.
            category: Coarse rejection category string.
        """
        self.param_values.append(param_values)
        self.questions.append(question)
        self.natural_language.append(natural_language or "")
        self.rejection_reasons.append(reason)
        self.rejection_categories.append(category)

    def __len__(self) -> int:
        """Return number of entries in history (counted by ``questions``)."""
        return len(self.questions)
|
|
1643
|
+
|
|
1644
|
+
|
|
1645
|
+
@dataclass
class Template:
    """Validated and accepted query template.

    The join-selection properties delegate to ``intent_signature``.
    ``from_dict``/``to_dict`` round-trip the template, including its nested
    intent signature, value history, stats and SQL shape.
    """

    id: str
    schema_hash: str
    intent_signature: ConcreteIntent
    intent_key: str
    tables_used: list[str]
    sql_param: str
    sql_display_param: str
    sql_fp: str
    shape: SQLShape
    colmap_sig: str
    value_history: ValueHistory
    stats: TemplateStats
    ux_summary: str = ""
    source: str = "human"
    trust_level: int = 1
    structural_defaults: dict[str, str | int | float] = field(default_factory=dict)
    deterministic_sql: str = ""
    aliased_sql: str = ""
    spark_sql_param: str = ""

    @property
    def chosen_join_candidate_id(self) -> str:
        """Delegate to intent_signature."""
        return self.intent_signature.chosen_join_candidate_id

    @property
    def chosen_join_path_signature(self) -> list[str]:
        """Delegate to intent_signature."""
        return self.intent_signature.chosen_join_path_signature

    @staticmethod
    def from_dict(d: dict[str, Any]) -> Template:
        """Create Template from dictionary with nested dataclass reconstruction.

        Nested values that are already objects rather than dicts are used
        as-is. ``sql_display_param`` falls back to ``sql_param`` when the key
        is absent (legacy records), the same fallback RejectedTemplate uses.
        ``tables_used`` and ``structural_defaults`` are shallow-copied so the
        template does not alias the input mapping.

        Args:
            d: Dictionary with all Template fields, including nested intent_signature, value_history, stats, and shape dicts.

        Returns:
            Fully populated Template instance.
        """
        intent_sig = d.get("intent_signature", {})
        if isinstance(intent_sig, dict):
            intent_sig = ConcreteIntent.from_dict(intent_sig)
        vh_data = d.get("value_history", {})
        value_history = ValueHistory.from_dict(vh_data) if isinstance(vh_data, dict) else vh_data
        stats_data = d.get("stats", {})
        stats = TemplateStats.from_dict(stats_data) if isinstance(stats_data, dict) else stats_data
        shape_data = d.get("shape", {})
        shape = SQLShape.from_dict(shape_data) if isinstance(shape_data, dict) else shape_data
        return Template(
            id=d.get("id", ""),
            schema_hash=d.get("schema_hash", ""),
            intent_signature=intent_sig,
            intent_key=d.get("intent_key", ""),
            tables_used=list(d.get("tables_used") or []),
            sql_param=d.get("sql_param", ""),
            sql_display_param=d.get("sql_display_param", d.get("sql_param", "")),
            sql_fp=d.get("sql_fp", ""),
            shape=shape,
            colmap_sig=d.get("colmap_sig", ""),
            value_history=value_history,
            stats=stats,
            ux_summary=d.get("ux_summary", ""),
            source=d.get("source", "human"),
            trust_level=d.get("trust_level", 1),
            structural_defaults=dict(d.get("structural_defaults") or {}),
            deterministic_sql=d.get("deterministic_sql", ""),
            aliased_sql=d.get("aliased_sql", ""),
            spark_sql_param=d.get("spark_sql_param", ""),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary with nested dataclass conversion.

        Nested members that expose ``to_dict`` (intent_signature,
        value_history, stats, shape) are serialized recursively. Anything
        else, such as a ``None`` shape passed through by ``from_dict``, is
        emitted unchanged instead of raising.

        Returns:
            Dictionary with all Template fields, with intent_signature, value_history, stats, and shape serialized recursively.
        """
        intent_sig = (
            self.intent_signature.to_dict() if hasattr(self.intent_signature, "to_dict") else self.intent_signature
        )
        vh_dict = self.value_history.to_dict() if hasattr(self.value_history, "to_dict") else self.value_history
        stats_dict = self.stats.to_dict() if hasattr(self.stats, "to_dict") else self.stats
        # Guarded like the other nested members: an unguarded call crashed on non-SQLShape values.
        shape_dict = self.shape.to_dict() if hasattr(self.shape, "to_dict") else self.shape
        return {
            "id": self.id,
            "schema_hash": self.schema_hash,
            "intent_signature": intent_sig,
            "intent_key": self.intent_key,
            "tables_used": self.tables_used,
            "sql_param": self.sql_param,
            "sql_display_param": self.sql_display_param,
            "sql_fp": self.sql_fp,
            "shape": shape_dict,
            "colmap_sig": self.colmap_sig,
            "value_history": vh_dict,
            "stats": stats_dict,
            "ux_summary": self.ux_summary,
            "source": self.source,
            "trust_level": self.trust_level,
            "structural_defaults": self.structural_defaults,
            "deterministic_sql": self.deterministic_sql,
            "aliased_sql": self.aliased_sql,
            "spark_sql_param": self.spark_sql_param,
        }
|
|
1755
|
+
|
|
1756
|
+
|
|
1757
|
+
@dataclass
class RejectedTemplate:
    """Rejected query template for negative learning.

    The join-selection properties delegate to ``intent_signature``.
    ``from_dict``/``to_dict`` round-trip the template, including its nested
    intent signature, rejection value history and SQL shape.
    """

    id: str
    schema_hash: str
    intent_signature: ConcreteIntent
    intent_key: str
    tables_used: list[str]
    sql_param: str
    sql_display_param: str
    sql_fp: str
    shape: SQLShape
    colmap_sig: str
    value_history: RejectedValueHistory
    aliased_sql: str = ""
    spark_sql_param: str = ""

    @property
    def chosen_join_candidate_id(self) -> str:
        """Delegate to intent_signature."""
        return self.intent_signature.chosen_join_candidate_id

    @property
    def chosen_join_path_signature(self) -> list[str]:
        """Delegate to intent_signature."""
        return self.intent_signature.chosen_join_path_signature

    @staticmethod
    def from_dict(d: dict[str, Any]) -> RejectedTemplate:
        """Create RejectedTemplate from dictionary with nested dataclass reconstruction.

        Nested values that are already objects rather than dicts are used
        as-is. ``sql_display_param`` falls back to ``sql_param`` when the key
        is absent. ``tables_used`` is shallow-copied so the template does not
        alias the input mapping.

        Args:
            d: Dictionary with all RejectedTemplate fields, including nested intent_signature, value_history, and shape dicts.

        Returns:
            Fully populated RejectedTemplate instance.
        """
        intent_sig = d.get("intent_signature", {})
        if isinstance(intent_sig, dict):
            intent_sig = ConcreteIntent.from_dict(intent_sig)
        vh_data = d.get("value_history", {})
        value_history = RejectedValueHistory.from_dict(vh_data) if isinstance(vh_data, dict) else vh_data
        shape_data = d.get("shape", {})
        shape = SQLShape.from_dict(shape_data) if isinstance(shape_data, dict) else shape_data
        return RejectedTemplate(
            id=d.get("id", ""),
            schema_hash=d.get("schema_hash", ""),
            intent_signature=intent_sig,
            intent_key=d.get("intent_key", ""),
            tables_used=list(d.get("tables_used") or []),
            sql_param=d.get("sql_param", ""),
            sql_display_param=d.get("sql_display_param", d.get("sql_param", "")),
            sql_fp=d.get("sql_fp", ""),
            shape=shape,
            colmap_sig=d.get("colmap_sig", ""),
            value_history=value_history,
            aliased_sql=d.get("aliased_sql", ""),
            spark_sql_param=d.get("spark_sql_param", ""),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary with nested dataclass conversion.

        Nested members that expose ``to_dict`` (intent_signature,
        value_history, shape) are serialized recursively. Anything else, such
        as a ``None`` shape passed through by ``from_dict``, is emitted
        unchanged instead of raising.

        Returns:
            Dictionary with all RejectedTemplate fields.
        """
        intent_sig = (
            self.intent_signature.to_dict() if hasattr(self.intent_signature, "to_dict") else self.intent_signature
        )
        vh_dict = self.value_history.to_dict() if hasattr(self.value_history, "to_dict") else self.value_history
        # Guarded like the other nested members: an unguarded call crashed on non-SQLShape values.
        shape_dict = self.shape.to_dict() if hasattr(self.shape, "to_dict") else self.shape
        return {
            "id": self.id,
            "schema_hash": self.schema_hash,
            "intent_signature": intent_sig,
            "intent_key": self.intent_key,
            "tables_used": self.tables_used,
            "sql_param": self.sql_param,
            "sql_display_param": self.sql_display_param,
            "sql_fp": self.sql_fp,
            "shape": shape_dict,
            "colmap_sig": self.colmap_sig,
            "value_history": vh_dict,
            "aliased_sql": self.aliased_sql,
            "spark_sql_param": self.spark_sql_param,
        }
|
|
1841
|
+
|
|
1842
|
+
|
|
1843
|
+
@dataclass
class SimulatorResult:
    """Outcome of one simulator run.

    Bundles the intent and question that were run with the generated SQL,
    returned rows, success/error state and diagnostic counters.
    """

    intent: RuntimeIntent
    question: str
    sql: str | None = None
    rows: list | None = None
    success: bool = False
    error: str | None = None
    validation_issues: list[str] = field(default_factory=list)
    confidence: float = 0.0
    llm_response: str | None = None
    sql_generation_attempts: int = 0
    repair_loop_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert this result into a plain dictionary.

        The intent is serialized through its own ``to_dict`` when it is
        truthy; otherwise ``None`` is stored. All other fields are copied by
        reference, in declaration order.

        Returns:
            Mapping of every SimulatorResult field to its value.
        """
        intent_payload = None
        if self.intent:
            intent_payload = self.intent.to_dict()
        passthrough_fields = (
            "question",
            "sql",
            "rows",
            "success",
            "error",
            "validation_issues",
            "confidence",
            "llm_response",
            "sql_generation_attempts",
            "repair_loop_count",
        )
        result: dict[str, Any] = {"intent": intent_payload}
        result.update({name: getattr(self, name) for name in passthrough_fields})
        return result
|
|
1879
|
+
|
|
1880
|
+
|
|
1881
|
+
@dataclass
class TemplateMatch:
    """Result of matching an intent against stored templates.

    The defaults describe an empty result: no intent, no template, zero
    similarity, reuse_type "none", no warnings and zero LLM calls.
    """

    # Intent the match was computed for, if any.
    intent: RuntimeIntent | None = field(default=None)
    # Best candidate template chosen by the matcher, if any.
    best_template: Template | None = field(default=None)
    # Similarity of best_template to the intent; 0.0 when nothing matched.
    similarity_score: float = field(default=0.0)
    # Kind of reuse decided by the matcher; "none" by default.
    reuse_type: str = field(default="none")
    # Semantic warnings collected during matching (fresh list per instance).
    semantic_warnings: list[str] = field(default_factory=list)
    # Presumably the number of LLM calls spent on matching — confirm with callers.
    llm_calls: int = field(default=0)
|