aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/qsim_sample.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
1
|
+
"""Value sampling and instantiation for the question-generation simulator.
|
|
2
|
+
|
|
3
|
+
Implements operator-aware value selection across categorical, numeric, boolean, and temporal domains, with coordinated range sampling for decomposed BETWEEN pairs (lower from [0.15, 0.35], upper from [0.65, 0.85]) and deterministic HAVING value generation keyed by variant index.
|
|
4
|
+
Skips instantiation for IS NULL filters and column-to-column comparisons.
|
|
5
|
+
Populates QSimIntent param_values via index-based keys, computes per-intent variance scores, and performs single-pass proportional variant allocation across the full intent set.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import random
|
|
11
|
+
from dataclasses import replace
|
|
12
|
+
from datetime import datetime, timedelta
|
|
13
|
+
|
|
14
|
+
from .config import (
|
|
15
|
+
AGG_PATTERN,
|
|
16
|
+
HAVING_COUNT_VALUES,
|
|
17
|
+
HAVING_MIN_MAX_VALUES,
|
|
18
|
+
HAVING_SUM_AVG_VALUES,
|
|
19
|
+
QSimConfig,
|
|
20
|
+
)
|
|
21
|
+
from .contracts_base import SchemaGraph, ValueDomain
|
|
22
|
+
from .contracts_core import QSimFilter, QSimHaving, QSimIntent
|
|
23
|
+
from .core_utils import debug
|
|
24
|
+
from .qsim_struct import decompose_between_filter
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _is_integer_type(data_type: str | None) -> bool:
|
|
28
|
+
"""Determine if data type represents an integer."""
|
|
29
|
+
if not data_type:
|
|
30
|
+
return False
|
|
31
|
+
dtype_lower = data_type.lower()
|
|
32
|
+
if dtype_lower in (
|
|
33
|
+
"integer",
|
|
34
|
+
"int",
|
|
35
|
+
"bigint",
|
|
36
|
+
"smallint",
|
|
37
|
+
"tinyint",
|
|
38
|
+
"long",
|
|
39
|
+
"short",
|
|
40
|
+
):
|
|
41
|
+
return True
|
|
42
|
+
if "int" in dtype_lower or dtype_lower in ("long", "short"):
|
|
43
|
+
if "interval" not in dtype_lower:
|
|
44
|
+
return True
|
|
45
|
+
return False
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _parse_date(val: str) -> datetime | None:
|
|
49
|
+
"""Parse date string to datetime object."""
|
|
50
|
+
if "T" in val:
|
|
51
|
+
val = val.split("T")[0]
|
|
52
|
+
elif " " in val:
|
|
53
|
+
val = val.split(" ")[0]
|
|
54
|
+
|
|
55
|
+
for fmt in ("%Y-%m-%d", "%Y/%m/%d", "%d-%m-%Y", "%d/%m/%Y"):
|
|
56
|
+
try:
|
|
57
|
+
return datetime.strptime(val, fmt)
|
|
58
|
+
except ValueError:
|
|
59
|
+
continue
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _format_date(dt: datetime) -> str:
|
|
64
|
+
"""Format datetime to date-only string."""
|
|
65
|
+
return dt.strftime("%Y-%m-%d")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _extract_date_part(val: str) -> str:
|
|
69
|
+
"""Extract date part from datetime string."""
|
|
70
|
+
if "T" in val:
|
|
71
|
+
return val.split("T")[0]
|
|
72
|
+
if " " in val:
|
|
73
|
+
return val.split(" ")[0]
|
|
74
|
+
return val
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _sample_categorical(domain: ValueDomain, variant_idx: int) -> str | None:
|
|
78
|
+
"""Sample categorical value."""
|
|
79
|
+
values_list = domain.values
|
|
80
|
+
if values_list:
|
|
81
|
+
idx = variant_idx % len(values_list)
|
|
82
|
+
return values_list[idx]
|
|
83
|
+
if domain.min_val is not None and domain.max_val is not None:
|
|
84
|
+
try:
|
|
85
|
+
min_v = int(float(domain.min_val))
|
|
86
|
+
max_v = int(float(domain.max_val))
|
|
87
|
+
range_size = max(1, max_v - min_v + 1)
|
|
88
|
+
value = min_v + (variant_idx % range_size)
|
|
89
|
+
return str(value)
|
|
90
|
+
except (ValueError, TypeError):
|
|
91
|
+
return str(domain.min_val)
|
|
92
|
+
return None
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _sample_boolean(domain: ValueDomain, variant_idx: int) -> str | None:
|
|
96
|
+
"""Sample boolean value."""
|
|
97
|
+
values_list = domain.values
|
|
98
|
+
if values_list:
|
|
99
|
+
normalized = []
|
|
100
|
+
for v in values_list:
|
|
101
|
+
if isinstance(v, bool):
|
|
102
|
+
normalized.append("true" if v else "false")
|
|
103
|
+
elif isinstance(v, str):
|
|
104
|
+
normalized.append(v.lower() if v.lower() in ("true", "false") else v)
|
|
105
|
+
else:
|
|
106
|
+
normalized.append(str(v))
|
|
107
|
+
idx = variant_idx % len(normalized)
|
|
108
|
+
return normalized[idx]
|
|
109
|
+
default_bools = ["true", "false"]
|
|
110
|
+
idx = variant_idx % len(default_bools)
|
|
111
|
+
return default_bools[idx]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _sample_numeric_categorical(domain: ValueDomain, variant_idx: int) -> str | None:
|
|
115
|
+
"""Sample numeric categorical value from discrete set."""
|
|
116
|
+
values_list = domain.values
|
|
117
|
+
if values_list:
|
|
118
|
+
idx = variant_idx % len(values_list)
|
|
119
|
+
val = values_list[idx]
|
|
120
|
+
return str(int(float(val))) if isinstance(val, int | float) else str(val)
|
|
121
|
+
if domain.min_val is not None and domain.max_val is not None:
|
|
122
|
+
try:
|
|
123
|
+
min_v = int(float(domain.min_val))
|
|
124
|
+
max_v = int(float(domain.max_val))
|
|
125
|
+
range_size = max(1, max_v - min_v + 1)
|
|
126
|
+
value = min_v + (variant_idx % range_size)
|
|
127
|
+
return str(value)
|
|
128
|
+
except (ValueError, TypeError):
|
|
129
|
+
return str(int(float(domain.min_val)))
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _sample_numeric(domain: ValueDomain, op: str, variant_idx: int) -> str | None:
|
|
134
|
+
"""Sample numeric value with operator awareness."""
|
|
135
|
+
if domain.min_val is not None and domain.max_val is not None:
|
|
136
|
+
try:
|
|
137
|
+
min_v = float(domain.min_val)
|
|
138
|
+
max_v = float(domain.max_val)
|
|
139
|
+
range_size = max_v - min_v
|
|
140
|
+
is_integer = _is_integer_type(domain.data_type)
|
|
141
|
+
|
|
142
|
+
if op == "=":
|
|
143
|
+
if is_integer:
|
|
144
|
+
int_range = max(1, int(range_size + 1))
|
|
145
|
+
value = int(min_v + (variant_idx % int_range))
|
|
146
|
+
else:
|
|
147
|
+
segment = (variant_idx % 10) / 10.0
|
|
148
|
+
value = min_v + segment * range_size
|
|
149
|
+
value = round(value, 2) if abs(value) >= 1 else round(value, 4)
|
|
150
|
+
elif op in (">", ">="):
|
|
151
|
+
lower_bound = min_v + range_size * 0.2
|
|
152
|
+
upper_bound = min_v + range_size * 0.5
|
|
153
|
+
value = lower_bound + (variant_idx % 5) * (upper_bound - lower_bound) / 5
|
|
154
|
+
value = int(round(value)) if is_integer else (round(value, 2) if abs(value) >= 1 else round(value, 4))
|
|
155
|
+
elif op in ("<", "<="):
|
|
156
|
+
lower_bound = min_v + range_size * 0.5
|
|
157
|
+
upper_bound = min_v + range_size * 0.8
|
|
158
|
+
value = lower_bound + (variant_idx % 5) * (upper_bound - lower_bound) / 5
|
|
159
|
+
value = int(round(value)) if is_integer else (round(value, 2) if abs(value) >= 1 else round(value, 4))
|
|
160
|
+
else:
|
|
161
|
+
if range_size > 0:
|
|
162
|
+
segment = (variant_idx % 10) / 10.0
|
|
163
|
+
value = min_v + segment * range_size
|
|
164
|
+
else:
|
|
165
|
+
value = min_v
|
|
166
|
+
value = int(round(value)) if is_integer else (round(value, 2) if abs(value) >= 1 else round(value, 4))
|
|
167
|
+
|
|
168
|
+
return str(value)
|
|
169
|
+
except (ValueError, TypeError):
|
|
170
|
+
pass
|
|
171
|
+
|
|
172
|
+
values_list = domain.values
|
|
173
|
+
if values_list:
|
|
174
|
+
idx = variant_idx % len(values_list)
|
|
175
|
+
return values_list[idx]
|
|
176
|
+
|
|
177
|
+
return None
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _sample_temporal(domain: ValueDomain, op: str, variant_idx: int) -> str | None:
|
|
181
|
+
"""Sample temporal value with date interpolation."""
|
|
182
|
+
if domain.min_val is not None and domain.max_val is not None:
|
|
183
|
+
try:
|
|
184
|
+
min_dt = _parse_date(str(domain.min_val))
|
|
185
|
+
max_dt = _parse_date(str(domain.max_val))
|
|
186
|
+
|
|
187
|
+
if min_dt is None or max_dt is None:
|
|
188
|
+
return _extract_date_part(str(domain.min_val))
|
|
189
|
+
|
|
190
|
+
total_days = (max_dt - min_dt).days
|
|
191
|
+
if total_days <= 0:
|
|
192
|
+
return _format_date(min_dt)
|
|
193
|
+
|
|
194
|
+
if op in (">", ">="):
|
|
195
|
+
segment = 0.2 + ((variant_idx % 5) / 5.0) * 0.15
|
|
196
|
+
elif op in ("<", "<="):
|
|
197
|
+
segment = 0.65 + ((variant_idx % 5) / 5.0) * 0.15
|
|
198
|
+
else:
|
|
199
|
+
segment = (variant_idx % 10) / 10.0
|
|
200
|
+
|
|
201
|
+
offset_days = int(total_days * segment)
|
|
202
|
+
result_dt = min_dt + timedelta(days=offset_days)
|
|
203
|
+
return _format_date(result_dt)
|
|
204
|
+
except (ValueError, TypeError):
|
|
205
|
+
pass
|
|
206
|
+
|
|
207
|
+
values_list = domain.values
|
|
208
|
+
if values_list:
|
|
209
|
+
idx = variant_idx % len(values_list)
|
|
210
|
+
return _extract_date_part(values_list[idx])
|
|
211
|
+
|
|
212
|
+
return None
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _sample_in_values(domain: ValueDomain, value_type: str, variant_idx: int) -> str | None:
|
|
216
|
+
"""Sample multiple values for IN/NOT IN operators."""
|
|
217
|
+
if value_type == "categorical":
|
|
218
|
+
values_list = domain.values
|
|
219
|
+
if values_list:
|
|
220
|
+
n_values = min(3 + (variant_idx % 3), len(values_list))
|
|
221
|
+
start_idx = variant_idx % max(1, len(values_list) - n_values + 1)
|
|
222
|
+
values = values_list[start_idx : start_idx + n_values]
|
|
223
|
+
return "'" + "','".join(values) + "'"
|
|
224
|
+
|
|
225
|
+
elif value_type == "numeric_categorical":
|
|
226
|
+
values_list = domain.values
|
|
227
|
+
if values_list:
|
|
228
|
+
n_values = min(3 + (variant_idx % 3), len(values_list))
|
|
229
|
+
start_idx = variant_idx % max(1, len(values_list) - n_values + 1)
|
|
230
|
+
values = values_list[start_idx : start_idx + n_values]
|
|
231
|
+
int_values = [str(int(float(v))) if isinstance(v, int | float) else str(v) for v in values]
|
|
232
|
+
return ",".join(int_values)
|
|
233
|
+
if domain.min_val is not None and domain.max_val is not None:
|
|
234
|
+
try:
|
|
235
|
+
min_v = int(float(domain.min_val))
|
|
236
|
+
max_v = int(float(domain.max_val))
|
|
237
|
+
range_size = max(1, max_v - min_v + 1)
|
|
238
|
+
n_values = min(3 + (variant_idx % 3), range_size)
|
|
239
|
+
values = []
|
|
240
|
+
for i in range(n_values):
|
|
241
|
+
value = min_v + ((variant_idx + i) % range_size)
|
|
242
|
+
values.append(str(value))
|
|
243
|
+
return ",".join(values)
|
|
244
|
+
except (ValueError, TypeError):
|
|
245
|
+
pass
|
|
246
|
+
|
|
247
|
+
elif value_type == "boolean":
|
|
248
|
+
values_list = domain.values
|
|
249
|
+
if values_list:
|
|
250
|
+
normalized = []
|
|
251
|
+
for v in values_list:
|
|
252
|
+
if isinstance(v, bool):
|
|
253
|
+
normalized.append("true" if v else "false")
|
|
254
|
+
elif isinstance(v, str):
|
|
255
|
+
normalized.append(v.lower() if v.lower() in ("true", "false") else v)
|
|
256
|
+
else:
|
|
257
|
+
normalized.append(str(v))
|
|
258
|
+
return ",".join(normalized)
|
|
259
|
+
return "true,false"
|
|
260
|
+
|
|
261
|
+
elif value_type in ("numeric", "temporal"):
|
|
262
|
+
if domain.min_val is not None and domain.max_val is not None:
|
|
263
|
+
try:
|
|
264
|
+
min_v = float(domain.min_val)
|
|
265
|
+
max_v = float(domain.max_val)
|
|
266
|
+
range_size = max_v - min_v
|
|
267
|
+
is_integer = _is_integer_type(domain.data_type)
|
|
268
|
+
n_values = 2 + (variant_idx % 3)
|
|
269
|
+
values = []
|
|
270
|
+
for i in range(n_values):
|
|
271
|
+
segment = ((variant_idx + i) % 10) / 10.0
|
|
272
|
+
val = min_v + segment * range_size
|
|
273
|
+
val = int(round(val)) if is_integer else (round(val, 2) if abs(val) >= 1 else round(val, 4))
|
|
274
|
+
values.append(str(val))
|
|
275
|
+
return ",".join(values)
|
|
276
|
+
except (ValueError, TypeError):
|
|
277
|
+
pass
|
|
278
|
+
|
|
279
|
+
return None
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def sample_value_from_domain(domain: ValueDomain, value_type: str, op: str = "=", variant_idx: int = 0) -> str | None:
    """Sample a concrete literal for a filter on a column with the given domain.

    NULL-typed filters and ``is null`` / ``is not null`` operators need no
    literal and yield None. ``in`` / ``not in`` build a value list. Otherwise
    the sampler matching *value_type* is used: categorical,
    numeric_categorical, numeric, temporal or boolean. Unknown types yield None.
    """
    if value_type == "null" or op in ("is null", "is not null"):
        return None
    if op in ("in", "not in"):
        return _sample_in_values(domain, value_type, variant_idx)

    # Deferred calls: only the selected sampler is looked up and invoked.
    dispatch = {
        "categorical": lambda: _sample_categorical(domain, variant_idx),
        "numeric_categorical": lambda: _sample_numeric_categorical(domain, variant_idx),
        "numeric": lambda: _sample_numeric(domain, op, variant_idx),
        "temporal": lambda: _sample_temporal(domain, op, variant_idx),
        "boolean": lambda: _sample_boolean(domain, variant_idx),
    }
    sampler = dispatch.get(value_type)
    return None if sampler is None else sampler()
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _identify_range_pairs(filters: list[QSimFilter]) -> dict[str, dict[str, int]]:
|
|
309
|
+
"""Identify columns with paired range filters (>= and </<= on same column)."""
|
|
310
|
+
column_ops: dict[str, dict[str, int]] = {}
|
|
311
|
+
for idx, f in enumerate(filters):
|
|
312
|
+
if f.is_expr_comparison:
|
|
313
|
+
continue
|
|
314
|
+
if f.op in (">", ">="):
|
|
315
|
+
column_ops.setdefault(f.column, {})["lower_idx"] = idx
|
|
316
|
+
elif f.op in ("<", "<="):
|
|
317
|
+
column_ops.setdefault(f.column, {})["upper_idx"] = idx
|
|
318
|
+
return {col: ops for col, ops in column_ops.items() if "lower_idx" in ops and "upper_idx" in ops}
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _sample_numeric_range(domain: ValueDomain, variant_idx: int) -> tuple[str | None, str | None]:
|
|
322
|
+
"""Sample coordinated numeric range values."""
|
|
323
|
+
if domain.min_val is None or domain.max_val is None:
|
|
324
|
+
return None, None
|
|
325
|
+
|
|
326
|
+
try:
|
|
327
|
+
min_v = float(domain.min_val)
|
|
328
|
+
max_v = float(domain.max_val)
|
|
329
|
+
range_size = max_v - min_v
|
|
330
|
+
if range_size <= 0:
|
|
331
|
+
return None, None
|
|
332
|
+
|
|
333
|
+
is_integer = _is_integer_type(domain.data_type)
|
|
334
|
+
|
|
335
|
+
lower_segment = 0.15 + ((variant_idx % 5) / 5.0) * 0.2
|
|
336
|
+
upper_segment = 0.65 + ((variant_idx % 5) / 5.0) * 0.2
|
|
337
|
+
|
|
338
|
+
lower_val = min_v + lower_segment * range_size
|
|
339
|
+
upper_val = min_v + upper_segment * range_size
|
|
340
|
+
|
|
341
|
+
if is_integer:
|
|
342
|
+
lower_val = int(round(lower_val))
|
|
343
|
+
upper_val = int(round(upper_val))
|
|
344
|
+
if lower_val >= upper_val:
|
|
345
|
+
upper_val = min(lower_val + 1, int(max_v))
|
|
346
|
+
else:
|
|
347
|
+
lower_val = round(lower_val, 2) if abs(lower_val) >= 1 else round(lower_val, 4)
|
|
348
|
+
upper_val = round(upper_val, 2) if abs(upper_val) >= 1 else round(upper_val, 4)
|
|
349
|
+
|
|
350
|
+
return str(lower_val), str(upper_val)
|
|
351
|
+
except (ValueError, TypeError):
|
|
352
|
+
return None, None
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _sample_temporal_range(domain: ValueDomain, variant_idx: int) -> tuple[str | None, str | None]:
|
|
356
|
+
"""Sample coordinated temporal range values with date interpolation."""
|
|
357
|
+
if domain.min_val is None or domain.max_val is None:
|
|
358
|
+
return None, None
|
|
359
|
+
|
|
360
|
+
try:
|
|
361
|
+
min_dt = _parse_date(str(domain.min_val))
|
|
362
|
+
max_dt = _parse_date(str(domain.max_val))
|
|
363
|
+
|
|
364
|
+
if min_dt is None or max_dt is None:
|
|
365
|
+
lower_val = _extract_date_part(str(domain.min_val))
|
|
366
|
+
upper_val = _extract_date_part(str(domain.max_val))
|
|
367
|
+
return lower_val, upper_val
|
|
368
|
+
|
|
369
|
+
total_days = (max_dt - min_dt).days
|
|
370
|
+
if total_days <= 0:
|
|
371
|
+
return _format_date(min_dt), _format_date(max_dt)
|
|
372
|
+
|
|
373
|
+
lower_segment = 0.15 + ((variant_idx % 5) / 5.0) * 0.2
|
|
374
|
+
upper_segment = 0.65 + ((variant_idx % 5) / 5.0) * 0.2
|
|
375
|
+
|
|
376
|
+
lower_days = int(total_days * lower_segment)
|
|
377
|
+
upper_days = int(total_days * upper_segment)
|
|
378
|
+
|
|
379
|
+
lower_dt = min_dt + timedelta(days=lower_days)
|
|
380
|
+
upper_dt = min_dt + timedelta(days=upper_days)
|
|
381
|
+
|
|
382
|
+
return _format_date(lower_dt), _format_date(upper_dt)
|
|
383
|
+
except (ValueError, TypeError):
|
|
384
|
+
return None, None
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def sample_coordinated_range(domain: ValueDomain, value_type: str, variant_idx: int) -> tuple[str | None, str | None]:
    """Sample a coordinated ``(lower, upper)`` pair for a column with a paired range filter.

    Numeric and temporal columns go to their dedicated range samplers; any
    other *value_type* yields ``(None, None)``.
    """
    if value_type == "numeric":
        return _sample_numeric_range(domain, variant_idx)
    if value_type == "temporal":
        return _sample_temporal_range(domain, variant_idx)
    return None, None
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def deterministic_having_value(agg_func: str, variant_idx: int, having_idx: int = 0) -> str:
    """Pick a HAVING threshold for *agg_func* from the configured value pools.

    Pools by aggregate:

    - ``count``: ``HAVING_COUNT_VALUES``
    - ``sum`` / ``avg``: ``HAVING_SUM_AVG_VALUES``
    - ``min`` / ``max``: ``HAVING_MIN_MAX_VALUES``
    - any other aggregate: falls back to the COUNT pool.

    The pool index is ``variant_idx * 3 + having_idx``, wrapped to the pool
    length, so the result depends only on the arguments.
    """
    offset = variant_idx * 3 + having_idx
    if agg_func in {"sum", "avg"}:
        pool = HAVING_SUM_AVG_VALUES
    elif agg_func in {"min", "max"}:
        pool = HAVING_MIN_MAX_VALUES
    else:
        pool = HAVING_COUNT_VALUES
    return str(pool[offset % len(pool)])
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def _compute_intent_variance(intent: QSimIntent, value_domains: dict[str, ValueDomain]) -> int:
|
|
422
|
+
"""Compute variance score for intent instantiation potential."""
|
|
423
|
+
variance_score = 0
|
|
424
|
+
|
|
425
|
+
for f in intent.filters_param:
|
|
426
|
+
if f.is_expr_comparison:
|
|
427
|
+
continue
|
|
428
|
+
col_key = f.column
|
|
429
|
+
domain = value_domains.get(col_key)
|
|
430
|
+
if domain:
|
|
431
|
+
if domain.values:
|
|
432
|
+
variance_score += len(domain.values)
|
|
433
|
+
elif domain.min_val is not None and domain.max_val is not None:
|
|
434
|
+
variance_score += 10
|
|
435
|
+
|
|
436
|
+
if intent.filters_param:
|
|
437
|
+
variance_score += 10 * len(intent.having_param)
|
|
438
|
+
else:
|
|
439
|
+
variance_score += 5 * len(intent.having_param)
|
|
440
|
+
|
|
441
|
+
return variance_score
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def _instantiate_intent(
    intent: QSimIntent, value_domains: dict[str, ValueDomain], variant_idx: int = 0
) -> QSimIntent | None:
    """Instantiate one variant of *intent* by sampling its parameter values.

    Each filter is first expanded with ``decompose_between_filter``. Columns
    left with both a lower (``>``/``>=``) and an upper (``<``/``<=``)
    comparison get one pair from ``sample_coordinated_range``. Every other
    filter is sampled independently from its column's domain.

    Values are stored in ``param_values`` under two kinds of key:

    - ``f{i}``, where *i* indexes the decomposed filter list;
    - ``h{i}``, where *i* indexes ``having_param``.

    Expression comparisons, IS [NOT] NULL filters (re-tagged with
    ``value_type="null"``), filters whose column has no domain, and filters
    whose sampler returns None are all kept but get no ``f{i}`` value.

    The returned intent has an empty ``question``; this implementation never
    returns None.
    """
    decomposed_filters: list[QSimFilter] = []
    for f in intent.filters_param:
        decomposed_filters.extend(decompose_between_filter(f))
    # variant_idx * filter_count + filter_idx gives each (variant, filter) pair its own index.
    filter_count = len(decomposed_filters)

    range_values: dict[str, tuple[str, str]] = {}
    for col_key, pair_indices in _identify_range_pairs(decomposed_filters).items():
        domain = value_domains.get(col_key)
        if domain is None:
            continue
        value_type = decomposed_filters[pair_indices["lower_idx"]].value_type
        lower_val, upper_val = sample_coordinated_range(domain, value_type, variant_idx)
        if lower_val is not None and upper_val is not None:
            range_values[col_key] = (lower_val, upper_val)

    new_filters: list[QSimFilter] = []
    new_param_values: dict[str, str] = {}

    for filter_idx, f in enumerate(decomposed_filters):
        if f.is_expr_comparison:
            new_filters.append(f)
            debug(f"[qsim_sample.instantiate_intent] expr_comparison: {f.column} {f.op} {f.right_column}")
            continue

        col_key = f.column
        value_type = f.value_type
        op = f.op

        if value_type == "null" or op in ("is null", "is not null"):
            new_filters.append(replace(f, value_type="null"))
            debug(f"[qsim_sample.instantiate_intent] null_filter: {col_key} {op}")
            continue

        domain = value_domains.get(col_key)
        if domain is None:
            debug(f"[qsim_sample.instantiate_intent] no_domain: {col_key}")
            new_filters.append(f)
            continue

        coordinated = range_values.get(col_key)
        if coordinated is not None and op in (">", ">="):
            value = coordinated[0]
        elif coordinated is not None and op in ("<", "<="):
            value = coordinated[1]
        else:
            value = sample_value_from_domain(domain, value_type, op, variant_idx * filter_count + filter_idx)

        if value is not None:
            new_param_values[f"f{filter_idx}"] = value
        new_filters.append(f)

    new_having: list[QSimHaving] = []
    for having_idx, h in enumerate(intent.having_param):
        agg_match = AGG_PATTERN.match(h.expression)
        # Expressions without a recognizable aggregate are treated as COUNT.
        agg_func = agg_match.group(1).lower() if agg_match else "count"
        new_param_values[f"h{having_idx}"] = deterministic_having_value(agg_func, variant_idx, having_idx)
        new_having.append(h)

    return QSimIntent(
        intent_id=intent.intent_id,
        tables=intent.tables,
        grain=intent.grain,
        select_cols=intent.select_cols,
        group_by_cols=intent.group_by_cols,
        order_by_cols=intent.order_by_cols,
        filters_param=new_filters,
        having_param=new_having,
        param_values=new_param_values,
        question="",
        variant_idx=variant_idx,
        limit=intent.limit,
        distinct=intent.distinct,
    )
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def _build_value_domains(schema: SchemaGraph) -> dict[str, ValueDomain]:
    """Build a ValueDomain for every ``table.column`` from the schema's column metadata."""
    value_domains: dict[str, ValueDomain] = {}
    for table_name, table_meta in schema.tables.items():
        for col_name, col_meta in table_meta.columns.items():
            value_domains[f"{table_name}.{col_name}"] = ValueDomain(
                values=col_meta.top_k_values or [],
                min_val=col_meta.min_val,
                max_val=col_meta.max_val,
                data_type=col_meta.data_type,
            )
    return value_domains


def _allocate_variants(
    intents: list[QSimIntent], value_domains: dict[str, ValueDomain], num_questions: int
) -> dict[str, int]:
    """Split *num_questions* across intents in proportion to their variance scores.

    Every intent gets at least one variant. Zero-variance intents get exactly
    one, and so does every intent when the total variance is zero. Because of
    rounding and that floor, the allocations need not sum to *num_questions*.
    """
    variances: dict[str, int] = {}
    for intent in intents:
        variances[intent.intent_id] = _compute_intent_variance(intent, value_domains)

    total_variance = sum(v for v in variances.values() if v > 0)

    allocations: dict[str, int] = {}
    for intent in intents:
        v = variances[intent.intent_id]
        if total_variance == 0 or v == 0:
            allocations[intent.intent_id] = 1
        else:
            share = v / total_variance
            allocations[intent.intent_id] = max(1, round(num_questions * share))

    debug(
        f"[qsim_sample.instantiate_all] total_variance={total_variance:.2f}, allocations_sum={sum(allocations.values())}"
    )
    return allocations


def instantiate_all(intents: list[QSimIntent], schema: SchemaGraph, num_questions: int | None = None) -> list[QSimIntent]:
    """Instantiate parameter values for *intents*, producing up to *num_questions* variants.

    Args:
        intents: Parameterized intents to instantiate.
        schema: Source of per-column metadata (``top_k_values``, ``min_val``,
            ``max_val``, ``data_type``), used as the value domains.
        num_questions: Target number of instantiated intents; defaults to
            ``QSimConfig.QUESTIONS_COUNT``.

    Seeds the global ``random`` module with ``QSimConfig.RANDOM_SEED`` and
    emits a debug warning when the average variants per intent is below
    ``QSimConfig.MIN_AVG_VARIANTS_PER_INTENT``. Each intent gets a variant
    count proportional to its variance score, with a minimum of one. If more
    than *num_questions* results are produced, they are shuffled and
    truncated.

    Raises:
        ValueError: If the average variants per intent exceeds
            ``QSimConfig.MAX_AVG_VARIANTS_PER_INTENT``.
    """
    if num_questions is None:
        num_questions = QSimConfig.QUESTIONS_COUNT

    # Seeds the module-global RNG; the truncation shuffle below draws from it.
    random.seed(QSimConfig.RANDOM_SEED)

    avg_variants = num_questions / len(intents) if intents else 0
    if avg_variants < QSimConfig.MIN_AVG_VARIANTS_PER_INTENT:
        debug(
            f"[qsim_sample.instantiate_all] WARNING: avg_variants={avg_variants:.2f} below MIN={QSimConfig.MIN_AVG_VARIANTS_PER_INTENT}"
        )
    if avg_variants > QSimConfig.MAX_AVG_VARIANTS_PER_INTENT:
        raise ValueError(
            f"Intent/variant ratio unrealistic: {len(intents)} intents cannot generate {num_questions} diverse questions (avg={avg_variants:.1f} > max={QSimConfig.MAX_AVG_VARIANTS_PER_INTENT})"
        )

    value_domains = _build_value_domains(schema)
    debug(f"[qsim_sample.instantiate_all] value_domains: {len(value_domains)} columns")

    allocations = _allocate_variants(intents, value_domains, num_questions)

    instantiated: list[QSimIntent] = []
    for intent in intents:
        for variant_idx in range(allocations[intent.intent_id]):
            result = _instantiate_intent(intent, value_domains, variant_idx)
            if result is not None:
                instantiated.append(result)

    if len(instantiated) > num_questions:
        random.shuffle(instantiated)
        instantiated = instantiated[:num_questions]
        debug(f"[qsim_sample.instantiate_all] truncated: {len(instantiated)}/{num_questions}")
    elif len(instantiated) < num_questions:
        debug(f"[qsim_sample.instantiate_all] limit_reached: {len(instantiated)}/{num_questions}")
    else:
        debug(f"[qsim_sample.instantiate_all] created: {len(instantiated)} intents")

    return instantiated
|