aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/qsim_struct.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
1
|
+
"""Structural utilities for question-generation simulator intent enumeration.
|
|
2
|
+
|
|
3
|
+
Enumerates FK-connected table sets, generates and caches structural query skeletons, provides column capability helpers (filterable, groupable, aggregatable, comparable pairs), and builds schema context strings for LLM prompts.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from dataclasses import asdict, replace
|
|
11
|
+
from itertools import combinations
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from .config import PolicyConfig, QSimConfig, SimulatorConfig
|
|
15
|
+
from .contracts_base import ColumnRole, QSimSkeleton, SchemaGraph, SkeletonLimits
|
|
16
|
+
from .contracts_core import QSimFilter
|
|
17
|
+
from .core_utils import debug, intent_id
|
|
18
|
+
|
|
19
|
+
# In-process cache of generated skeletons, keyed by the unordered set of table
# names (frozenset), so lookups ignore table order. Filled by
# generate_all_skeletons and by load_or_create_skeletons when reading the JSON
# cache file. Keys carry only table names; nothing here records which schema
# produced an entry.
_SKELETON_CACHE: dict[frozenset[str], list[QSimSkeleton]] = {}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def build_fk_adjacency(schema: SchemaGraph) -> dict[str, set[str]]:
    """Build an undirected FK adjacency map of tables in the schema.

    Every table in ``schema.tables`` gets an entry, even when it has no foreign keys.
    A foreign key whose source or destination table is not a key of
    ``schema.tables`` is skipped rather than raising ``KeyError``. Such a table
    cannot take part in any table set enumerated from this schema anyway.

    Args:
        schema: Schema graph whose foreign-key definitions are traversed.

    Returns:
        Dict mapping each table name to the set of table names it shares a
        foreign-key relationship with (bidirectional edges).
    """
    adj: dict[str, set[str]] = {t: set() for t in schema.tables}

    for table in schema.tables.values():
        for fk in table.foreign_keys:
            src, dst = fk.src_table, fk.dst_table
            # A dangling FK (endpoint missing from the schema graph) would
            # otherwise crash with KeyError; ignore that edge instead.
            if src not in adj or dst not in adj:
                continue
            adj[src].add(dst)
            adj[dst].add(src)

    return adj
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def is_connected(tables: list[str], adj: dict[str, set[str]]) -> bool:
    """Return whether all tables in *tables* are mutually reachable via FK edges.

    The search starts from the first table. It only follows edges whose other
    endpoint is also in *tables*, so connectivity is judged on the sub-graph
    induced by the provided table list.

    Args:
        tables: List of table names to test for connectivity.
        adj: Undirected FK adjacency map as returned by ``build_fk_adjacency``.

    Returns:
        ``True`` if all tables are reachable from the first table. Always
        ``True`` when *tables* contains zero or one entry.
    """
    if len(tables) <= 1:
        return True

    table_set = set(tables)
    start = tables[0]
    visited = {start}
    stack = [start]

    # Iterative DFS: list.pop() from the end is O(1), unlike list.pop(0) used
    # as a BFS queue. Nodes are marked when pushed so each is pushed once.
    while stack:
        current = stack.pop()
        for neighbor in adj.get(current, ()):
            if neighbor in table_set and neighbor not in visited:
                visited.add(neighbor)
                stack.append(neighbor)

    return visited == table_set
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def enumerate_table_sets(schema: SchemaGraph, max_tables: int | None = None) -> list[list[str]]:
    """Enumerate all valid FK-connected table combinations up to *max_tables* in size.

    Includes every single-table set and every multi-table combination that
    forms a connected sub-graph in the FK adjacency graph.

    Connected sets are grown one table at a time by adding an FK neighbour that
    belongs to the schema. The old approach tested every ``C(n, k)``
    combination of schema tables. The cost now scales with the number of
    connected sets instead. Every connected set of size ``k + 1`` contains a
    table whose removal leaves a connected set of size ``k`` (a leaf of a
    spanning tree), so growing from all size-``k`` sets reaches all of them.

    Args:
        schema: Schema graph to derive tables and FK adjacency from.
        max_tables: Maximum number of tables per combination. Defaults to
            ``QSimConfig.MAX_TABLES_PER_INTENT``. Values below 2 yield only
            single-table sets.

    Returns:
        List of table-name lists, each sorted by name. The list is ordered by
        set size, then lexicographically (single tables first, in name order).
    """
    if max_tables is None:
        max_tables = QSimConfig.MAX_TABLES_PER_INTENT

    adj = build_fk_adjacency(schema)
    known_tables = set(schema.tables)

    frontier: set[frozenset[str]] = {frozenset((t,)) for t in known_tables}
    found: list[frozenset[str]] = list(frontier)

    for _size in range(2, max_tables + 1):
        grown: set[frozenset[str]] = set()
        for members in frontier:
            for table in members:
                for neighbor in adj.get(table, ()):
                    # Only schema tables may join a set; neighbours already in
                    # the set would not grow it.
                    if neighbor in known_tables and neighbor not in members:
                        grown.add(members | {neighbor})
        if not grown:
            # No connected sets of this size means none of any larger size.
            break
        found.extend(grown)
        frontier = grown

    # Sorting by (size, sorted names) reproduces the order produced by
    # itertools.combinations over the sorted table names.
    valid_sets = sorted((sorted(s) for s in found), key=lambda names: (len(names), names))

    debug(f"[qsim_struct.enumerate_table_sets] found {len(valid_sets)} valid table combinations")
    return valid_sets
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _is_excluded_filter_column(col_name: str) -> bool:
|
|
114
|
+
"""Return whether a column name matches any excluded filter pattern.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
|
|
118
|
+
col_name: Column name string to test.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
|
|
122
|
+
``True`` if any pattern from ``QSimConfig.EXCLUDED_FILTER_PATTERNS`` is found as a substring of *col_name* (case-insensitive).
|
|
123
|
+
"""
|
|
124
|
+
for pattern in QSimConfig.EXCLUDED_FILTER_PATTERNS:
|
|
125
|
+
if pattern in col_name.lower():
|
|
126
|
+
return True
|
|
127
|
+
return False
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def get_filterable_columns(table: str, schema: SchemaGraph, column_roles: dict[str, str]) -> list[tuple[str, str]]:
    """List the columns of *table* that may appear in a WHERE filter.

    A column qualifies when its metadata marks it ``is_filterable`` and its
    name does not match an excluded (audit/system) pattern. The role reported
    for each column comes from *column_roles*. It falls back to the column's
    own ``role`` attribute, or ``"unknown"`` when that is empty.

    Args:
        table: Table name to inspect.
        schema: Schema graph containing column metadata.
        column_roles: Map of ``table.column`` key to role string.

    Returns:
        ``(column_key, role)`` tuples in the table's column order. The list is
        empty when *table* is not in the schema.
    """
    table_ir = schema.tables.get(table)
    if not table_ir:
        return []

    filterable: list[tuple[str, str]] = []
    for name, meta in table_ir.columns.items():
        if meta.is_filterable and not _is_excluded_filter_column(name):
            key = f"{table}.{name}"
            filterable.append((key, column_roles.get(key, meta.role or "unknown")))
    return filterable
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def get_aggregatable_columns(table: str, schema: SchemaGraph, column_roles: dict[str, str]) -> list[str]:
    """List ``table.column`` keys of *table* that can be aggregated with SUM/AVG/MIN/MAX.

    A column's effective role is its entry in *column_roles*. It falls back to
    the column's own ``role`` attribute, or ``"unknown"``. Only columns whose
    effective role equals ``ColumnRole.NUMERIC_MEASURE.value`` are returned.

    Args:
        table: Table name to inspect.
        schema: Schema graph containing column metadata.
        column_roles: Map of ``table.column`` key to role string.

    Returns:
        Matching column keys in the table's column order. The list is empty
        when *table* is not in the schema.
    """
    table_ir = schema.tables.get(table)
    if not table_ir:
        return []

    measure_role = ColumnRole.NUMERIC_MEASURE.value
    measures: list[str] = []
    for name, meta in table_ir.columns.items():
        key = f"{table}.{name}"
        if column_roles.get(key, meta.role or "unknown") == measure_role:
            measures.append(key)
    return measures
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def get_groupable_columns(table: str, schema: SchemaGraph, column_roles: dict[str, str]) -> list[str]:
    """List ``table.column`` keys of *table* that are suitable for GROUP BY.

    A column qualifies when its effective role is ``CATEGORICAL``, ``TEMPORAL``
    or ``NUMERIC_CATEGORICAL``. The effective role is its *column_roles* entry,
    falling back to the column's own ``role`` attribute, or ``"unknown"``.

    Args:
        table: Table name to inspect.
        schema: Schema graph containing column metadata.
        column_roles: Map of ``table.column`` key to role string.

    Returns:
        Matching column keys in the table's column order. The list is empty
        when *table* is not in the schema.
    """
    table_ir = schema.tables.get(table)
    if not table_ir:
        return []

    groupable_roles = (
        ColumnRole.CATEGORICAL.value,
        ColumnRole.TEMPORAL.value,
        ColumnRole.NUMERIC_CATEGORICAL.value,
    )
    keys: list[str] = []
    for name, meta in table_ir.columns.items():
        key = f"{table}.{name}"
        if column_roles.get(key, meta.role or "unknown") in groupable_roles:
            keys.append(key)
    return keys
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def get_comparable_column_pairs(
    table_set: list[str], schema: SchemaGraph, column_roles: dict[str, str]
) -> list[tuple[str, str, str, str, str]]:
    """Find column pairs from different tables that can be compared against each other.

    Columns are split into two pools by effective role (the *column_roles*
    entry, else the column's ``role`` attribute, else ``"unknown"``):

    * numeric: ``NUMERIC_MEASURE`` or ``NUMERIC_CATEGORICAL``. A pair must
      share the exact same role.
    * temporal: ``TEMPORAL``.

    Within each pool, every pair of columns from different tables is reported.
    Numeric pairs come first, then temporal pairs. Each pool is visited in
    table and column order.

    Args:
        table_set: List of table names to consider. Names missing from the
            schema are ignored.
        schema: Schema graph containing column metadata.
        column_roles: Map of ``table.column`` key to role string.

    Returns:
        5-tuples ``(table1, col1, table2, col2, role)``, one per comparable pair.
    """
    numeric_roles = {
        ColumnRole.NUMERIC_MEASURE.value,
        ColumnRole.NUMERIC_CATEGORICAL.value,
    }
    temporal_roles = {ColumnRole.TEMPORAL.value}

    numeric_cols: list[tuple[str, str, str]] = []
    temporal_cols: list[tuple[str, str, str]] = []

    for table in table_set:
        table_ir = schema.tables.get(table)
        if not table_ir:
            continue
        for name, meta in table_ir.columns.items():
            role = column_roles.get(f"{table}.{name}", meta.role or "unknown")
            if role in numeric_roles:
                numeric_cols.append((table, name, role))
            elif role in temporal_roles:
                temporal_cols.append((table, name, role))

    # combinations(..., 2) yields (i, j) with i < j in list order.
    pairs = [
        (t1, c1, t2, c2, r1)
        for (t1, c1, r1), (t2, c2, r2) in combinations(numeric_cols, 2)
        if t1 != t2 and r1 == r2
    ]
    pairs.extend(
        (t1, c1, t2, c2, r1)
        for (t1, c1, r1), (t2, c2, _r2) in combinations(temporal_cols, 2)
        if t1 != t2
    )
    return pairs
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def compute_skeleton_limits(tables: list[str], schema: SchemaGraph, column_roles: dict[str, str]) -> SkeletonLimits:
    """Derive the per-table-set caps used when enumerating query skeletons.

    Column capabilities of every table in *tables* are pooled. Each limit is
    then clamped by its config ceiling:

    * ``max_filters``: at most ``QSimConfig.MAX_FILTERS_PER_INTENT``, and at
      most two per distinct filterable column. The column count is itself
      capped at ``QSimConfig.MAX_FILTER_COLUMNS``.
    * ``max_groupby``: number of groupable columns, capped at
      ``QSimConfig.MAX_GROUP_BY_COLUMNS``.
    * ``max_having``: one more than the number of aggregatable columns, capped
      at ``SimulatorConfig.MAX_HAVING_CONDITIONS``.

    Args:
        tables: List of table names in the intent.
        schema: Schema graph for column metadata.
        column_roles: Map of table.column key to role string.

    Returns:
        ``SkeletonLimits`` with the derived values.
    """
    filterable_keys: set[str] = set()
    groupable_total = 0
    aggregatable_total = 0
    for table in tables:
        filterable_keys.update(key for key, _role in get_filterable_columns(table, schema, column_roles))
        groupable_total += len(get_groupable_columns(table, schema, column_roles))
        aggregatable_total += len(get_aggregatable_columns(table, schema, column_roles))

    filter_column_cap = min(QSimConfig.MAX_FILTER_COLUMNS, len(filterable_keys))
    return SkeletonLimits(
        max_filters=min(QSimConfig.MAX_FILTERS_PER_INTENT, filter_column_cap * 2),
        max_groupby=min(groupable_total, QSimConfig.MAX_GROUP_BY_COLUMNS),
        max_having=min(SimulatorConfig.MAX_HAVING_CONDITIONS, 1 + aggregatable_total),
    )
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def compute_intent_id(intent_dict: dict[str, Any]) -> str:
    """Compute a hash-based intent ID from the structural fields of an intent dict.

    Tables, select columns and group-by columns are sorted directly. Filters
    are ordered by their ``"column"`` value and HAVING conditions by their
    ``"expression"`` value. Entries that tie on that key are further ordered by
    a canonical JSON serialization (for example, two filters on the same
    column with different operators). As a result, the same collection of
    entries always produces the same ID, whatever order it arrives in.
    Entries that do not tie keep exactly the order they had before this
    tiebreak existed.

    Args:
        intent_dict: Dict with keys ``"tables"``, ``"grain"``, ``"select_cols"``,
            ``"group_by_cols"``, ``"filters_param"``, and ``"having_param"``.
            Missing keys fall back to empty lists (``"row_level"`` for grain).

    Returns:
        A short hash string suitable for use as a deduplicated intent identifier.
    """

    def _canonical(item: Any) -> str:
        # Deterministic text form used only to break sort ties.
        try:
            return json.dumps(item, sort_keys=True, default=str)
        except (TypeError, ValueError):
            # e.g. dict keys of mixed types (sort_keys) or circular references.
            return repr(item)

    def _ordered(entries: list[Any], field: str) -> list[Any]:
        return sorted(
            entries,
            key=lambda x: (str(x.get(field, "")) if isinstance(x, dict) else "", _canonical(x)),
        )

    structural = {
        "tables": sorted(intent_dict.get("tables", [])),
        "grain": intent_dict.get("grain", "row_level"),
        "select_cols": sorted(intent_dict.get("select_cols", [])),
        "group_by_cols": sorted(intent_dict.get("group_by_cols", [])),
        "filters_param": _ordered(intent_dict.get("filters_param", []), "column"),
        "having_param": _ordered(intent_dict.get("having_param", []), "expression"),
    }
    return intent_id(structural)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def generate_all_skeletons(tables: list[str], schema: SchemaGraph, column_roles: dict[str, str]) -> list[QSimSkeleton]:
    """Enumerate every structural ``QSimSkeleton`` combination for one table set.

    Results are memoized in the module-level ``_SKELETON_CACHE`` under
    ``frozenset(tables)``. A cache hit returns the stored list object itself,
    whatever order *tables* is given in.

    The enumeration is a nested product, in this order, over:

    * aggregation on/off;
    * filter count ``0..max_filters``;
    * GROUP BY count: ``1..max_groupby`` with aggregation, otherwise ``0``. So
      no aggregated skeletons exist when ``max_groupby`` is 0;
    * ORDER BY on/off;
    * HAVING on/off, only with aggregation and a GROUP BY column;
    * DISTINCT on/off, only for non-aggregated single-table skeletons;
    * expression comparison on/off, only when the table set has a comparable
      column pair and at least one filter.

    Flags that do not apply are fixed to ``False``. ``max_having`` is computed
    but only reported in the debug log.

    Args:
        tables: Ordered list of table names defining the skeleton's table set.
        schema: Schema graph used to determine available filterable, groupable,
            and aggregatable columns.
        column_roles: Map of ``table.column`` key to role string.

    Returns:
        List of all ``QSimSkeleton`` combinations valid for the given table set
        and schema capabilities.
    """
    cache_key = frozenset(tables)
    cached = _SKELETON_CACHE.get(cache_key)
    if cached is not None:
        debug(f"[qsim_struct.generate_all_skeletons] cache_hit: {len(cached)} skeletons")
        return cached

    limits = compute_skeleton_limits(tables, schema, column_roles)
    max_filters, max_groupby, max_having = limits.max_filters, limits.max_groupby, limits.max_having

    single_table = len(tables) == 1
    comparable = len(get_comparable_column_pairs(tables, schema, column_roles)) > 0

    both = (True, False)
    off = (False,)

    skeletons: list[QSimSkeleton] = []
    for aggregated in both:
        # Aggregated skeletons always group by at least one column.
        groupby_counts = range(1, max_groupby + 1) if aggregated else (0,)
        distinct_choices = both if (not aggregated and single_table) else off
        for filter_count in range(max_filters + 1):
            expr_choices = both if (comparable and filter_count > 0) else off
            for groupby_count in groupby_counts:
                having_choices = both if (aggregated and groupby_count > 0) else off
                for orderby in both:
                    for having in having_choices:
                        for distinct in distinct_choices:
                            for expr_cmp in expr_choices:
                                skeletons.append(
                                    QSimSkeleton(
                                        tables=tables,
                                        has_aggregation=aggregated,
                                        num_filters=filter_count,
                                        num_groupby=groupby_count,
                                        has_orderby=orderby,
                                        has_having=having,
                                        has_distinct=distinct,
                                        has_expr_comparison=expr_cmp,
                                    )
                                )

    _SKELETON_CACHE[cache_key] = skeletons

    debug(
        f"[qsim_struct.generate_all_skeletons] created {len(skeletons)} skeletons for tables={tables}, max_filters={max_filters}, max_groupby={max_groupby}, max_having={max_having}"
    )
    return skeletons
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _skeleton_from_dict(data: dict[str, Any]) -> QSimSkeleton:
    """Rebuild one ``QSimSkeleton`` from an entry of the serialized skeleton cache file."""
    return QSimSkeleton(
        tables=data["tables"],
        has_aggregation=data["has_aggregation"],
        num_filters=data["num_filters"],
        num_groupby=data["num_groupby"],
        has_orderby=data["has_orderby"],
        has_having=data["has_having"],
        # Older cache files may lack these fields. has_expr_comparison also
        # accepts the legacy "has_column_comparison" key.
        has_distinct=data.get("has_distinct", False),
        has_expr_comparison=data.get(
            "has_expr_comparison",
            data.get("has_column_comparison", False),
        ),
    )


def _read_skeleton_cache_file(path: str, schema_hash: str) -> dict[frozenset[str], list[QSimSkeleton]] | None:
    """Parse the skeleton cache file at *path*.

    Args:
        path: Location of the JSON cache file.
        schema_hash: Hash of the current schema; the file must have been
            written for the same hash to be usable.

    Returns:
        The decoded mapping of table sets to skeletons, or ``None`` when the
        file's stored schema hash differs from *schema_hash*.

    Raises:
        Any error from reading, JSON decoding, or entry decoding of an
        unreadable or malformed file.
    """
    with open(path, encoding="utf-8") as f:
        cache_data = json.load(f)

    cached_hash = cache_data.get("schema_hash", "")
    if cached_hash != schema_hash:
        debug(
            f"[qsim_struct.load_or_create_skeletons] schema_hash mismatch: {cached_hash} != {schema_hash}, regenerating"
        )
        return None

    return {
        frozenset(table_key_str.split("|")): [_skeleton_from_dict(s) for s in skel_list]
        for table_key_str, skel_list in cache_data.get("skeletons", {}).items()
    }


def _write_skeleton_cache_file(path: str, schema_hash: str) -> None:
    """Persist the current ``_SKELETON_CACHE`` contents to *path* as JSON.

    The parent directory is created if needed. The data is first written to a
    sibling ``.tmp`` file, which then replaces *path* through ``os.replace``,
    so *path* never holds a partially written file.
    """
    payload = {
        "schema_hash": schema_hash,
        "num_table_sets": len(_SKELETON_CACHE),
        "skeletons": {"|".join(sorted(k)): [asdict(s) for s in v] for k, v in _SKELETON_CACHE.items()},
    }

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        os.replace(tmp_path, path)
    except Exception:
        # Do not leave a half-written temp file behind; surface the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_or_create_skeletons(
    schema: SchemaGraph, column_roles: dict[str, str]
) -> dict[frozenset[str], list[QSimSkeleton]]:
    """Load the skeleton cache from disk or generate and persist it.

    The file is loaded when ``PolicyConfig.REGENERATE_SKELETON_CACHE`` is
    false, the file exists, and its stored schema hash matches. Any read or
    decode error is logged and treated as a cache miss. Otherwise skeletons
    are regenerated for every FK-connected table set and written back to
    ``QSimConfig.SKELETONS_JSON_PATH``.

    In both cases the in-memory cache is cleared first, so it holds only
    entries for *schema*. Without this, ``generate_all_skeletons`` would
    return entries produced earlier in the process for a different schema.
    The dict is mutated in place, never rebound, so existing references to it
    stay valid.

    Args:
        schema: Schema graph used for skeleton generation and hash comparison.
        column_roles: Map of ``table.column`` key to role string.

    Returns:
        Module-level skeleton cache dict mapping frozen table sets to their
        ``QSimSkeleton`` lists.

    Raises:
        OSError (or a JSON encoding error) if the regenerated cache cannot be
        written.
    """
    skeleton_path = QSimConfig.SKELETONS_JSON_PATH

    if not PolicyConfig.REGENERATE_SKELETON_CACHE and os.path.exists(skeleton_path):
        try:
            loaded = _read_skeleton_cache_file(skeleton_path, schema.schema_hash)
        except Exception as e:
            # Best-effort: a corrupt or unreadable cache just triggers regeneration.
            debug(f"[qsim_struct.load_or_create_skeletons] cache_load_failed: {e}")
        else:
            if loaded is not None:
                # Swap in the fully decoded data only after parsing succeeded.
                _SKELETON_CACHE.clear()
                _SKELETON_CACHE.update(loaded)
                debug(f"[qsim_struct.load_or_create_skeletons] loaded {len(_SKELETON_CACHE)} table sets from cache")
                return _SKELETON_CACHE

    debug("[qsim_struct.load_or_create_skeletons] generating new skeletons")
    _SKELETON_CACHE.clear()
    table_sets = enumerate_table_sets(schema, QSimConfig.MAX_TABLES_PER_INTENT)

    for table_set in table_sets:
        generate_all_skeletons(table_set, schema, column_roles)

    debug(f"[qsim_struct.load_or_create_skeletons] saving {len(_SKELETON_CACHE)} table sets to cache")
    _write_skeleton_cache_file(skeleton_path, schema.schema_hash)

    return _SKELETON_CACHE
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def decompose_between_filter(f: QSimFilter) -> list[QSimFilter]:
    """Split a ``BETWEEN`` filter into a ``>=`` / ``<=`` pair.

    Args:
        f: The filter to decompose.

    Returns:
        For ``f.op == "between"``: two copies of *f*, produced with
        ``dataclasses.replace``, with ``op`` set to ``">="`` and then ``"<="``.
        Any other filter comes back unchanged as a one-element list.
    """
    if f.op == "between":
        # NOTE(review): only ``op`` is replaced, so both halves keep every
        # other field of *f* (including any value/parameter fields). Confirm
        # the lower and upper bounds are resolved downstream.
        return [replace(f, op=bound_op) for bound_op in (">=", "<=")]
    return [f]
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def build_schema_context(tables: list[str], schema: SchemaGraph) -> str:
    """Render a textual schema description for LLM prompts.

    Each table found in *schema* becomes a ``TABLE name (description):`` header
    followed by one line per column. Each column line shows the column's data
    type (``unknown`` when empty) plus ``[PK]``, ``[FK -> table.column]``
    (``?`` when the FK target is missing) and ``[filterable]`` markers where
    they apply. Table names absent from the schema are skipped. A missing
    table description falls back to ``"<table> table"``.

    Args:
        tables: Ordered list of table names to include.
        schema: Schema graph containing table and column metadata.

    Returns:
        The table blocks joined by blank lines.
    """

    def describe(name, meta) -> str:
        col_type = meta.data_type or "unknown"
        markers = []
        if meta.is_primary_key:
            markers.append(" [PK]")
        if meta.is_foreign_key:
            target = f"{meta.fk_target[0]}.{meta.fk_target[1]}" if meta.fk_target else "?"
            markers.append(f" [FK -> {target}]")
        if meta.is_filterable:
            markers.append(" [filterable]")
        return f"{name} ({col_type})" + "".join(markers)

    blocks = []
    for table in tables:
        table_ir = schema.tables.get(table)
        if not table_ir:
            continue
        column_lines = [describe(name, meta) for name, meta in table_ir.columns.items()]
        description = table_ir.description or f"{table} table"
        blocks.append(f"TABLE {table} ({description}):\n " + "\n ".join(column_lines))

    return "\n\n".join(blocks)
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def validate_column_exists(col_ref: str, tables: list[str], schema: SchemaGraph) -> bool:
    """Check that a ``table.column`` reference names a real column of an allowed table.

    The reference is split at its first ``.``. Everything after it is the
    column name, so column names containing dots are supported.

    Args:
        col_ref: Fully-qualified column reference string (``"table.column"``).
        tables: Allowed table names; the reference's table must be in this list.
        schema: Schema graph for column existence checks.

    Returns:
        ``True`` only when *col_ref* contains a dot, its table part is in
        *tables* and in the schema, and its column part is one of that table's
        columns.
    """
    table, dot, col = col_ref.partition(".")
    if not dot or table not in tables:
        return False
    table_ir = schema.tables.get(table)
    return bool(table_ir) and col in table_ir.columns
|