aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/core_utils.py
ADDED
|
@@ -0,0 +1,834 @@
|
|
|
1
|
+
"""Core utility functions shared across all pipeline stages, with no project-level dependencies other than config.
|
|
2
|
+
|
|
3
|
+
Provides hashing and stable JSON serialization for intent and schema fingerprinting, SQL canonicalization and parameter substitution, LLM communication wrappers with retry logic, string processing helpers, and user interaction and display utilities.
|
|
4
|
+
|
|
5
|
+
This module is the lowest-level shared layer and is imported by every other module that needs LLM access, logging, or SQL normalization.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import hashlib
|
|
11
|
+
import json
|
|
12
|
+
import re
|
|
13
|
+
import time
|
|
14
|
+
from decimal import Decimal
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from openai import OpenAI
|
|
18
|
+
|
|
19
|
+
from .config import EngineConfig, PolicyConfig
|
|
20
|
+
|
|
21
|
+
_client = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _get_client() -> OpenAI:
|
|
25
|
+
"""Lazy-initialize OpenAI client on first use.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
|
|
29
|
+
Singleton OpenAI client configured from EngineConfig.
|
|
30
|
+
"""
|
|
31
|
+
global _client
|
|
32
|
+
if _client is None:
|
|
33
|
+
_client = OpenAI(api_key=EngineConfig.API_TOKEN, base_url=EngineConfig.OPENAI_BASE_URL)
|
|
34
|
+
return _client
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def log(msg: str) -> None:
|
|
38
|
+
"""Print a log message with [LOG] prefix when VERBOSE mode is enabled.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
|
|
42
|
+
msg: Message to print.
|
|
43
|
+
"""
|
|
44
|
+
if PolicyConfig.VERBOSE:
|
|
45
|
+
print(f"[LOG] {msg}")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def debug(msg: str) -> None:
|
|
49
|
+
"""Print a debug message with [DEBUG] prefix when DEBUG mode is enabled.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
|
|
53
|
+
msg: Message to print.
|
|
54
|
+
"""
|
|
55
|
+
if PolicyConfig.DEBUG:
|
|
56
|
+
print(f"[DEBUG] {msg}")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def sha256(s: str) -> str:
|
|
60
|
+
"""Compute SHA-256 hash of a string.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
|
|
64
|
+
s: Input string to hash.
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
|
|
68
|
+
Hex digest string of the SHA-256 hash.
|
|
69
|
+
"""
|
|
70
|
+
return hashlib.sha256(s.encode("utf-8")).hexdigest()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _strip_fences(s: str) -> str:
|
|
74
|
+
"""Remove markdown code fences from a string.
|
|
75
|
+
|
|
76
|
+
Args:
|
|
77
|
+
|
|
78
|
+
s: Raw string that may be wrapped in triple-backtick fences.
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
|
|
82
|
+
String with leading and trailing markdown fences removed and stripped.
|
|
83
|
+
"""
|
|
84
|
+
s = s.strip()
|
|
85
|
+
if s.startswith("```"):
|
|
86
|
+
s = re.sub(r"^```[a-zA-Z]*\s*", "", s)
|
|
87
|
+
s = re.sub(r"\s*```$", "", s)
|
|
88
|
+
return s.strip()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def canonicalize_sql(sql: str) -> str:
|
|
92
|
+
"""Normalize SQL whitespace, formatting, and join equality operand order.
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
|
|
96
|
+
sql: Raw SQL string, possibly with extra whitespace or markdown fences.
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
|
|
100
|
+
Canonicalized SQL string with consistent spacing and canonical operand order in equality conditions.
|
|
101
|
+
"""
|
|
102
|
+
s = _strip_fences(sql).strip()
|
|
103
|
+
s = s.rstrip(";").strip()
|
|
104
|
+
s = re.sub(r"^EXPLAIN\s+(?:ANALYZE\s+)?", "", s, flags=re.IGNORECASE)
|
|
105
|
+
s = re.sub(r"\s+", " ", s).strip()
|
|
106
|
+
s = re.sub(r"\(\s+", "(", s)
|
|
107
|
+
s = re.sub(r"\s+\)", ")", s)
|
|
108
|
+
s = re.sub(r"\s*,\s*", ", ", s)
|
|
109
|
+
s = re.sub(r"(?<![><!=])=(?![>=])", " = ", s)
|
|
110
|
+
s = re.sub(r"\s+", " ", s).strip()
|
|
111
|
+
|
|
112
|
+
def normalize_equality(m: re.Match) -> str:
|
|
113
|
+
left, right = m.group(1).strip(), m.group(2).strip()
|
|
114
|
+
if left > right:
|
|
115
|
+
left, right = right, left
|
|
116
|
+
return f"{left} = {right}"
|
|
117
|
+
|
|
118
|
+
s = re.sub(r"([^\s()><!=]+)\s*=\s*([^\s()><!=]+)", normalize_equality, s)
|
|
119
|
+
return s
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def stable_json(o: Any) -> str:
|
|
123
|
+
"""Serialize an object to JSON with stable key ordering.
|
|
124
|
+
|
|
125
|
+
Args:
|
|
126
|
+
|
|
127
|
+
o: Any JSON-serializable object.
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
|
|
131
|
+
Compact JSON string with sorted keys for deterministic hashing.
|
|
132
|
+
"""
|
|
133
|
+
return json.dumps(o, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def colmap_signature(column_map: dict[str, str]) -> str:
|
|
137
|
+
"""Compute a stable hash signature for a column map.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
|
|
141
|
+
column_map: Mapping from bare column names to table names.
|
|
142
|
+
|
|
143
|
+
Returns:
|
|
144
|
+
|
|
145
|
+
SHA-256 hex digest of the sorted, stable-serialized column map.
|
|
146
|
+
"""
|
|
147
|
+
return sha256(stable_json(sorted(column_map.items())))
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def intent_id(d: dict[str, Any]) -> str:
|
|
151
|
+
"""Generate a 16-character intent identifier from a dictionary.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
|
|
155
|
+
d: Dictionary representing a canonical intent structure.
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
|
|
159
|
+
First 16 characters of the SHA-256 hex digest of the stable JSON.
|
|
160
|
+
"""
|
|
161
|
+
return sha256(stable_json(d))[:16]
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def issue_sig(rule_ids: list[str]) -> str:
|
|
165
|
+
"""Generate an issue signature from a list of rule identifiers.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
|
|
169
|
+
rule_ids: List of issue or rule ID strings.
|
|
170
|
+
|
|
171
|
+
Returns:
|
|
172
|
+
|
|
173
|
+
SHA-256 hex digest of the sorted rule IDs for ABAB loop detection.
|
|
174
|
+
"""
|
|
175
|
+
return sha256(stable_json(sorted(rule_ids)))
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def schema_hash_fp(tables_dict: dict[str, Any]) -> str:
|
|
179
|
+
"""Generate a schema fingerprint from a tables dictionary.
|
|
180
|
+
|
|
181
|
+
Args:
|
|
182
|
+
|
|
183
|
+
tables_dict: Raw tables dictionary as returned by schema profiling.
|
|
184
|
+
|
|
185
|
+
Returns:
|
|
186
|
+
|
|
187
|
+
SHA-256 hex digest used as the schema_hash for a SchemaGraph.
|
|
188
|
+
"""
|
|
189
|
+
return sha256(stable_json({"tables": tables_dict}))
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def normalize_question(q: str) -> str:
|
|
193
|
+
"""Normalize a user question while preserving the case of quoted values.
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
|
|
197
|
+
q: Raw user question string.
|
|
198
|
+
|
|
199
|
+
Returns:
|
|
200
|
+
|
|
201
|
+
Lowercased, whitespace-normalized question with special characters removed, but with quoted literals restored to their original case.
|
|
202
|
+
"""
|
|
203
|
+
q = q.strip()
|
|
204
|
+
q = re.sub(r"[\u2018\u2019\u201c\u201d]", "'", q)
|
|
205
|
+
|
|
206
|
+
quoted_values = []
|
|
207
|
+
|
|
208
|
+
def preserve_quoted(m):
|
|
209
|
+
quoted_values.append(m.group(1))
|
|
210
|
+
return f"__QUOTED_{len(quoted_values) - 1}__"
|
|
211
|
+
|
|
212
|
+
q = re.sub(r"'([^']*)'", preserve_quoted, q)
|
|
213
|
+
|
|
214
|
+
q = q.lower()
|
|
215
|
+
q = re.sub(r"[^a-z0-9\s_:/\-\.,\?]", " ", q)
|
|
216
|
+
q = re.sub(r"\s+", " ", q).strip()
|
|
217
|
+
|
|
218
|
+
for i, val in enumerate(quoted_values):
|
|
219
|
+
q = q.replace(f"__quoted_{i}__", f"'{val}'")
|
|
220
|
+
|
|
221
|
+
return q
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _extract_first_json_object(s: str) -> str | None:
|
|
225
|
+
"""Extract the first JSON object from a string by scanning for balanced braces.
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
|
|
229
|
+
s: String that may contain a JSON object embedded in other text.
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
|
|
233
|
+
The first balanced {...} substring, or None if not found.
|
|
234
|
+
"""
|
|
235
|
+
s = _strip_fences(s)
|
|
236
|
+
start = s.find("{")
|
|
237
|
+
if start == -1:
|
|
238
|
+
return None
|
|
239
|
+
depth = 0
|
|
240
|
+
for i in range(start, len(s)):
|
|
241
|
+
if s[i] == "{":
|
|
242
|
+
depth += 1
|
|
243
|
+
elif s[i] == "}":
|
|
244
|
+
depth -= 1
|
|
245
|
+
if depth == 0:
|
|
246
|
+
return s[start : i + 1]
|
|
247
|
+
return None
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def safe_json_loads(s: str) -> Any | None:
|
|
251
|
+
"""Safely parse JSON with fallback to fragment extraction.
|
|
252
|
+
|
|
253
|
+
Args:
|
|
254
|
+
|
|
255
|
+
s: Raw string to parse as JSON.
|
|
256
|
+
|
|
257
|
+
Returns:
|
|
258
|
+
|
|
259
|
+
Parsed Python object, or None if parsing fails entirely.
|
|
260
|
+
"""
|
|
261
|
+
s = s.strip()
|
|
262
|
+
try:
|
|
263
|
+
return json.loads(s)
|
|
264
|
+
except Exception:
|
|
265
|
+
pass
|
|
266
|
+
frag = _extract_first_json_object(s)
|
|
267
|
+
if frag:
|
|
268
|
+
try:
|
|
269
|
+
return json.loads(frag)
|
|
270
|
+
except Exception:
|
|
271
|
+
return None
|
|
272
|
+
return None
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def parameter_abstract(sql: str) -> tuple[str, dict[str, Any]]:
|
|
276
|
+
"""Replace literal values in SQL with named parameter placeholders.
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
|
|
280
|
+
sql: SQL string with inline literal values.
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
|
|
284
|
+
Tuple of (parameterized SQL string, dict mapping placeholder keys to original values).
|
|
285
|
+
"""
|
|
286
|
+
s = sql
|
|
287
|
+
params: dict[str, Any] = {}
|
|
288
|
+
idx = 1
|
|
289
|
+
|
|
290
|
+
def put(val: str) -> str:
|
|
291
|
+
nonlocal idx
|
|
292
|
+
key = f"p{idx}"
|
|
293
|
+
idx += 1
|
|
294
|
+
params[key] = val
|
|
295
|
+
return f":{key}"
|
|
296
|
+
|
|
297
|
+
s = re.sub(r"(\b\d{4}-\d{2}-\d{2}\b)", lambda m: put(m.group(1)), s)
|
|
298
|
+
s = re.sub(r"(\b\d{2}/\d{2}/\d{4}\b)", lambda m: put(m.group(1)), s)
|
|
299
|
+
s = re.sub(r"(\b\d+(\.\d+)?\b)", lambda m: put(m.group(1)), s)
|
|
300
|
+
s = re.sub(r"('([^']|'')*')", lambda m: put(m.group(1)), s)
|
|
301
|
+
s = re.sub(r"\s+", " ", s).strip()
|
|
302
|
+
return s, params
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
_TASK_PROFILES: dict[str, dict[str, Any]] = {
|
|
306
|
+
"intent": {
|
|
307
|
+
"model": EngineConfig.OPENAI_MODEL_INTENT,
|
|
308
|
+
"reasoning": {"effort": "low", "summary": "concise"},
|
|
309
|
+
},
|
|
310
|
+
"schema": {
|
|
311
|
+
"model": EngineConfig.OPENAI_MODEL_SCHEMA,
|
|
312
|
+
"reasoning": {"effort": "low", "summary": "concise"},
|
|
313
|
+
},
|
|
314
|
+
"sql": {
|
|
315
|
+
"model": EngineConfig.OPENAI_MODEL_SQL,
|
|
316
|
+
"temperature": 0,
|
|
317
|
+
},
|
|
318
|
+
"default": {
|
|
319
|
+
"model": EngineConfig.OPENAI_MODEL,
|
|
320
|
+
"temperature": 0,
|
|
321
|
+
},
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def llm_chat(
|
|
326
|
+
system: str,
|
|
327
|
+
user: str,
|
|
328
|
+
max_retries: int = 3,
|
|
329
|
+
timeout: float = 60.0,
|
|
330
|
+
task: str = "default",
|
|
331
|
+
) -> str:
|
|
332
|
+
"""Send a chat completion request to the LLM with task-based model routing.
|
|
333
|
+
|
|
334
|
+
Selects the model and API parameters based on the task identifier. Reasoning-capable models receive reasoning parameters instead of temperature. All requests enforce structured JSON output via text.format.
|
|
335
|
+
|
|
336
|
+
Args:
|
|
337
|
+
|
|
338
|
+
system: System prompt string.
|
|
339
|
+
user: User prompt string.
|
|
340
|
+
max_retries: Maximum number of attempts before raising.
|
|
341
|
+
timeout: Per-request timeout in seconds.
|
|
342
|
+
task: Routing key selecting model and parameters. One of "intent", "sql", "schema", or "default".
|
|
343
|
+
|
|
344
|
+
Returns:
|
|
345
|
+
|
|
346
|
+
Stripped response text from the LLM.
|
|
347
|
+
"""
|
|
348
|
+
profile = _TASK_PROFILES.get(task, _TASK_PROFILES["default"])
|
|
349
|
+
model = profile["model"]
|
|
350
|
+
|
|
351
|
+
kwargs: dict[str, Any] = {
|
|
352
|
+
"model": model,
|
|
353
|
+
"input": [
|
|
354
|
+
{
|
|
355
|
+
"role": "system",
|
|
356
|
+
"content": [{"type": "input_text", "text": system}],
|
|
357
|
+
},
|
|
358
|
+
{"role": "user", "content": [{"type": "input_text", "text": user}]},
|
|
359
|
+
],
|
|
360
|
+
"timeout": timeout,
|
|
361
|
+
"text": {"format": {"type": "json_object"}},
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
if "reasoning" in profile:
|
|
365
|
+
kwargs["reasoning"] = profile["reasoning"]
|
|
366
|
+
else:
|
|
367
|
+
kwargs["temperature"] = profile.get("temperature", 0)
|
|
368
|
+
|
|
369
|
+
client = _get_client()
|
|
370
|
+
|
|
371
|
+
debug(f"[core_utils.llm_chat] task={task} system_len={len(system)} user_len={len(user)}")
|
|
372
|
+
|
|
373
|
+
for attempt in range(max_retries):
|
|
374
|
+
try:
|
|
375
|
+
start = time.time()
|
|
376
|
+
r = client.responses.create(**kwargs)
|
|
377
|
+
elapsed = time.time() - start
|
|
378
|
+
output = r.output_text.strip()
|
|
379
|
+
debug(f"[core_utils.llm_chat] RAW OUTPUT:\n{output}")
|
|
380
|
+
debug(
|
|
381
|
+
f"[core_utils.llm_chat] model={model} task={task} "
|
|
382
|
+
f"completed in {elapsed:.1f}s (attempt {attempt + 1}/{max_retries})"
|
|
383
|
+
)
|
|
384
|
+
return output
|
|
385
|
+
except Exception as e:
|
|
386
|
+
elapsed = time.time() - start
|
|
387
|
+
log(
|
|
388
|
+
f"[core_utils.llm_chat] timeout or error after {elapsed:.1f}s (attempt {attempt + 1}/{max_retries}): {str(e)[:100]}"
|
|
389
|
+
)
|
|
390
|
+
if attempt < max_retries - 1:
|
|
391
|
+
wait = 2**attempt
|
|
392
|
+
log(f"[core_utils.llm_chat] retrying in {wait}s...")
|
|
393
|
+
time.sleep(wait)
|
|
394
|
+
else:
|
|
395
|
+
raise RuntimeError(f"LLM call failed after {max_retries} attempts: {str(e)}") from e
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def normalize_sql_operator_spaces(sql: str) -> str:
|
|
399
|
+
"""Collapse spaces inside comparison operators in SQL so parsers see a single token.
|
|
400
|
+
|
|
401
|
+
Replaces ``> =`` with ``>=``, ``< =`` with ``<=``, and ``! =`` with ``!=`` so that LLM-generated SQL with spaces inside operators parses correctly.
|
|
402
|
+
|
|
403
|
+
Args:
|
|
404
|
+
|
|
405
|
+
sql: Raw SQL string that may contain spaced operators.
|
|
406
|
+
|
|
407
|
+
Returns:
|
|
408
|
+
|
|
409
|
+
SQL string with operator spaces collapsed.
|
|
410
|
+
|
|
411
|
+
"""
|
|
412
|
+
if not sql or not sql.strip():
|
|
413
|
+
return sql
|
|
414
|
+
s = sql.replace("> =", ">=").replace("< =", "<=").replace("! =", "!=")
|
|
415
|
+
return s
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def normalize_sql(sql: str) -> str:
|
|
419
|
+
"""Normalize SQL ORDER BY clauses to include explicit ASC or DESC direction.
|
|
420
|
+
|
|
421
|
+
Args:
|
|
422
|
+
|
|
423
|
+
sql: Raw SQL string.
|
|
424
|
+
|
|
425
|
+
Returns:
|
|
426
|
+
|
|
427
|
+
Canonicalized SQL with each ORDER BY column explicitly marked ASC or DESC.
|
|
428
|
+
"""
|
|
429
|
+
s = canonicalize_sql(sql)
|
|
430
|
+
if not s:
|
|
431
|
+
return s
|
|
432
|
+
|
|
433
|
+
s_upper = s.upper()
|
|
434
|
+
order_by_pos = s_upper.find("ORDER BY")
|
|
435
|
+
if order_by_pos != -1:
|
|
436
|
+
before_order = s[: order_by_pos + 8]
|
|
437
|
+
after_order = s[order_by_pos + 8 :]
|
|
438
|
+
|
|
439
|
+
limit_pos = after_order.upper().find("LIMIT")
|
|
440
|
+
if limit_pos != -1:
|
|
441
|
+
order_clause = after_order[:limit_pos].strip()
|
|
442
|
+
rest = after_order[limit_pos:]
|
|
443
|
+
else:
|
|
444
|
+
order_clause = after_order.strip()
|
|
445
|
+
rest = ""
|
|
446
|
+
|
|
447
|
+
normalized_items = []
|
|
448
|
+
for item in order_clause.split(","):
|
|
449
|
+
item = item.strip()
|
|
450
|
+
if not item:
|
|
451
|
+
continue
|
|
452
|
+
item_upper = item.upper()
|
|
453
|
+
if item_upper.endswith(" ASC") or item_upper.endswith(" DESC"):
|
|
454
|
+
normalized_items.append(item)
|
|
455
|
+
else:
|
|
456
|
+
normalized_items.append(f"{item} ASC")
|
|
457
|
+
|
|
458
|
+
if normalized_items:
|
|
459
|
+
s = f"{before_order} {', '.join(normalized_items)}"
|
|
460
|
+
if rest:
|
|
461
|
+
s = f"{s} {rest}"
|
|
462
|
+
|
|
463
|
+
return s
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def llm_json(system: str, user: str, retries: int = 1, task: str = "default") -> dict[str, Any]:
|
|
467
|
+
"""Request a JSON response from the LLM with retry on parse failure.
|
|
468
|
+
|
|
469
|
+
Args:
|
|
470
|
+
|
|
471
|
+
system: System prompt string.
|
|
472
|
+
user: User prompt string.
|
|
473
|
+
retries: Number of additional retry attempts if the initial response is not valid JSON.
|
|
474
|
+
task: Routing key for model selection passed to llm_chat.
|
|
475
|
+
|
|
476
|
+
Returns:
|
|
477
|
+
|
|
478
|
+
Parsed JSON dict, or an empty dict if all attempts fail.
|
|
479
|
+
"""
|
|
480
|
+
raw = llm_chat(system, user, task=task)
|
|
481
|
+
parsed = safe_json_loads(raw)
|
|
482
|
+
if isinstance(parsed, dict):
|
|
483
|
+
debug(f"[core_utils.llm_json] parsed keys={list(parsed.keys())}")
|
|
484
|
+
return parsed
|
|
485
|
+
|
|
486
|
+
if raw.strip().upper().startswith("SELECT"):
|
|
487
|
+
debug("[core_utils.llm_json] raw_sql_detected wrapping")
|
|
488
|
+
sql_statement = raw.strip()
|
|
489
|
+
wrapped = {"sql": sql_statement, "chosen_join_candidate_id": "J00"}
|
|
490
|
+
return wrapped
|
|
491
|
+
|
|
492
|
+
debug("[core_utils.llm_json] parse_failed: retrying")
|
|
493
|
+
for attempt in range(max(0, retries)):
|
|
494
|
+
debug(f"[core_utils.llm_json] retry: {attempt + 1}")
|
|
495
|
+
raw = llm_chat(
|
|
496
|
+
system,
|
|
497
|
+
user + "\n\nFORMAT_ERROR: Output ONLY valid JSON that matches the required schema. Do NOT output raw SQL.",
|
|
498
|
+
task=task,
|
|
499
|
+
)
|
|
500
|
+
parsed = safe_json_loads(raw)
|
|
501
|
+
if isinstance(parsed, dict):
|
|
502
|
+
debug(f"[core_utils.llm_json] retry_success: keys={list(parsed.keys())}")
|
|
503
|
+
return parsed
|
|
504
|
+
|
|
505
|
+
if raw.strip().upper().startswith("SELECT"):
|
|
506
|
+
debug("[core_utils.llm_json] retry_sql_detected: wrapping")
|
|
507
|
+
sql_statement = raw.strip()
|
|
508
|
+
wrapped = {"sql": sql_statement, "chosen_join_candidate_id": "J00"}
|
|
509
|
+
return wrapped
|
|
510
|
+
|
|
511
|
+
debug("[core_utils.llm_json] all_retries_failed")
|
|
512
|
+
return {}
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def llm_sql_with_join(system: str, user: str, task: str = "sql") -> tuple[str, str, bool, dict[str, str]]:
|
|
516
|
+
"""Request SQL generation with join candidate selection from the LLM.
|
|
517
|
+
|
|
518
|
+
Args:
|
|
519
|
+
|
|
520
|
+
system: System prompt string.
|
|
521
|
+
user: User prompt string including join candidates and expression guides.
|
|
522
|
+
task: Routing key for model selection passed to llm_json.
|
|
523
|
+
|
|
524
|
+
Returns:
|
|
525
|
+
|
|
526
|
+
Tuple of (sql, chosen_join_candidate_id, success_bool, cte_join_candidate_ids_dict).
|
|
527
|
+
"""
|
|
528
|
+
debug("[core_utils.llm_sql_with_join] requesting")
|
|
529
|
+
parsed = llm_json(system, user, retries=1, task=task)
|
|
530
|
+
if not parsed:
|
|
531
|
+
debug("[core_utils.llm_sql_with_join] llm_json returned empty")
|
|
532
|
+
return "", "", False, {}
|
|
533
|
+
|
|
534
|
+
sql = parsed.get("sql") if isinstance(parsed.get("sql"), str) else ""
|
|
535
|
+
chosen = parsed.get("chosen_join_candidate_id") if isinstance(parsed.get("chosen_join_candidate_id"), str) else ""
|
|
536
|
+
|
|
537
|
+
cte_join_ids: dict[str, str] = {}
|
|
538
|
+
raw_cte_ids = parsed.get("chosen_cte_join_candidate_ids")
|
|
539
|
+
if isinstance(raw_cte_ids, dict):
|
|
540
|
+
for cte_name, cid in raw_cte_ids.items():
|
|
541
|
+
if isinstance(cte_name, str) and isinstance(cid, str):
|
|
542
|
+
cte_join_ids[cte_name] = cid
|
|
543
|
+
|
|
544
|
+
if sql and not chosen:
|
|
545
|
+
if "join" not in sql.lower():
|
|
546
|
+
chosen = "J00"
|
|
547
|
+
debug("[core_utils.llm_sql_with_join] default_J00: no_joins")
|
|
548
|
+
|
|
549
|
+
sql = normalize_sql_operator_spaces(sql) if sql else ""
|
|
550
|
+
sql = canonicalize_sql(sql) if sql else ""
|
|
551
|
+
ok = bool(sql)
|
|
552
|
+
debug(f"[core_utils.llm_sql_with_join] result: join={chosen} ok={ok} cte_ids={cte_join_ids}")
|
|
553
|
+
debug(f"[core_utils.llm_sql_with_join] result SQL:\n{sql}")
|
|
554
|
+
return sql, chosen, ok, cte_join_ids
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
def substitute_params(sql_param: str, params: dict[str, Any]) -> str:
|
|
558
|
+
"""Substitute parameter placeholders in SQL with their actual values.
|
|
559
|
+
|
|
560
|
+
Replaces :key placeholders, reduces trivial 1.0 * coefficients and +0 offsets, and strips LIMIT None.
|
|
561
|
+
|
|
562
|
+
Args:
|
|
563
|
+
|
|
564
|
+
sql_param: Parameterized SQL string with :key placeholders.
|
|
565
|
+
params: Dict mapping placeholder keys to their resolved values.
|
|
566
|
+
|
|
567
|
+
Returns:
|
|
568
|
+
|
|
569
|
+
SQL string with all placeholders replaced and trivial arithmetic simplified.
|
|
570
|
+
"""
|
|
571
|
+
result = sql_param
|
|
572
|
+
for key in sorted(params.keys(), key=lambda k: -len(k)):
|
|
573
|
+
val = params[key]
|
|
574
|
+
if not key:
|
|
575
|
+
continue
|
|
576
|
+
placeholder = f":{key}"
|
|
577
|
+
if isinstance(val, list):
|
|
578
|
+
formatted_items = []
|
|
579
|
+
for item in val:
|
|
580
|
+
if isinstance(item, str):
|
|
581
|
+
formatted_items.append(f"'{item}'")
|
|
582
|
+
else:
|
|
583
|
+
formatted_items.append(str(item))
|
|
584
|
+
result = result.replace(placeholder, ", ".join(formatted_items))
|
|
585
|
+
elif isinstance(val, bool):
|
|
586
|
+
result = result.replace(placeholder, "TRUE" if val else "FALSE")
|
|
587
|
+
elif isinstance(val, str):
|
|
588
|
+
if val.startswith("'") and val.endswith("'") and "','" in val:
|
|
589
|
+
result = result.replace(placeholder, val)
|
|
590
|
+
elif re.match(r"^-?\d+(?:\.\d+)?(?:,\s*-?\d+(?:\.\d+)?)*$", val):
|
|
591
|
+
result = result.replace(placeholder, val)
|
|
592
|
+
else:
|
|
593
|
+
result = result.replace(placeholder, f"'{val}'")
|
|
594
|
+
else:
|
|
595
|
+
result = result.replace(placeholder, str(val))
|
|
596
|
+
result = re.sub(r"\b1(?:\.0)?\s*\*\s*", "", result)
|
|
597
|
+
result = re.sub(r"\s*\*\s*1(?:\.0)?\b", "", result)
|
|
598
|
+
result = re.sub(r"\s*[+\-]\s*0(?:\.0)?\b", "", result)
|
|
599
|
+
result = re.sub(r"\s*LIMIT\s+None\b", "", result, flags=re.IGNORECASE)
|
|
600
|
+
return result
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def sql_fp(sql_param: str) -> str:
|
|
604
|
+
"""Compute a SQL fingerprint hash from a parameterized SQL string.
|
|
605
|
+
|
|
606
|
+
Args:
|
|
607
|
+
|
|
608
|
+
sql_param: Parameterized SQL string.
|
|
609
|
+
|
|
610
|
+
Returns:
|
|
611
|
+
|
|
612
|
+
SHA-256 hex digest of the lowercased SQL for template deduplication.
|
|
613
|
+
"""
|
|
614
|
+
return sha256(sql_param.lower())
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
def _format_cell(v) -> str:
|
|
618
|
+
"""Format a single query result cell as a clean display string.
|
|
619
|
+
|
|
620
|
+
Args:
|
|
621
|
+
|
|
622
|
+
v: Cell value of any type returned by the database driver.
|
|
623
|
+
|
|
624
|
+
Returns:
|
|
625
|
+
|
|
626
|
+
Human-readable string representation of the cell value.
|
|
627
|
+
"""
|
|
628
|
+
if v is None:
|
|
629
|
+
return "NULL"
|
|
630
|
+
if isinstance(v, Decimal):
|
|
631
|
+
return f"{v:f}" if v == v.to_integral_value() else f"{v}"
|
|
632
|
+
if isinstance(v, str):
|
|
633
|
+
return v
|
|
634
|
+
return str(v)
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
def print_query_result(
|
|
638
|
+
rows: list[tuple],
|
|
639
|
+
sql: str,
|
|
640
|
+
context: str = "Query Results",
|
|
641
|
+
headers: list[str] | None = None,
|
|
642
|
+
) -> None:
|
|
643
|
+
"""Display query results in a standardized tabular format.
|
|
644
|
+
|
|
645
|
+
Args:
|
|
646
|
+
|
|
647
|
+
rows: List of row tuples returned by the database.
|
|
648
|
+
sql: The SQL string that produced the results.
|
|
649
|
+
context: Section heading to display above the results.
|
|
650
|
+
headers: Optional list of column header names.
|
|
651
|
+
"""
|
|
652
|
+
print(f"\n{context}\n")
|
|
653
|
+
print(f"SQL:\n {sql}\n")
|
|
654
|
+
|
|
655
|
+
if len(rows) == 1 and len(rows[0]) == 1:
|
|
656
|
+
val = rows[0][0]
|
|
657
|
+
print(f"Answer: {_format_cell(val)}\n")
|
|
658
|
+
else:
|
|
659
|
+
sample = rows[:5]
|
|
660
|
+
formatted = [[_format_cell(v) for v in row] for row in sample]
|
|
661
|
+
num_cols = max(len(r) for r in formatted) if formatted else 0
|
|
662
|
+
col_headers = (headers or [])[:num_cols]
|
|
663
|
+
while len(col_headers) < num_cols:
|
|
664
|
+
col_headers.append(f"col{len(col_headers) + 1}")
|
|
665
|
+
widths = [len(h) for h in col_headers]
|
|
666
|
+
for row in formatted:
|
|
667
|
+
for i, cell in enumerate(row):
|
|
668
|
+
if i < len(widths):
|
|
669
|
+
widths[i] = max(widths[i], len(cell))
|
|
670
|
+
header_line = " ".join(h.ljust(widths[i]) for i, h in enumerate(col_headers))
|
|
671
|
+
sep_line = " ".join("-" * widths[i] for i in range(num_cols))
|
|
672
|
+
print(f" {header_line}")
|
|
673
|
+
print(f" {sep_line}")
|
|
674
|
+
for row in formatted:
|
|
675
|
+
line = " ".join((row[i] if i < len(row) else "").ljust(widths[i]) for i in range(num_cols))
|
|
676
|
+
print(f" {line}")
|
|
677
|
+
if len(rows) > 5:
|
|
678
|
+
print(f" ... ({len(rows)} total rows)")
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
def ask_user_choice(prompt: str, options: list[str], silent_no: bool = False) -> str | None:
|
|
682
|
+
"""Prompt the user to select from a list of options.
|
|
683
|
+
|
|
684
|
+
Args:
|
|
685
|
+
|
|
686
|
+
prompt: Question or instruction to display to the user.
|
|
687
|
+
options: List of valid option strings shown in brackets.
|
|
688
|
+
silent_no: If True, suppresses the termination message on 'n' input.
|
|
689
|
+
|
|
690
|
+
Returns:
|
|
691
|
+
|
|
692
|
+
'y', 'n', or None if input is invalid or interrupted.
|
|
693
|
+
"""
|
|
694
|
+
options_display = "/".join(options)
|
|
695
|
+
print(f"{prompt} [{options_display}]: ", end="")
|
|
696
|
+
try:
|
|
697
|
+
user_input = input().strip()
|
|
698
|
+
except (EOFError, KeyboardInterrupt):
|
|
699
|
+
print("\nUser terminated.")
|
|
700
|
+
return None
|
|
701
|
+
|
|
702
|
+
if not user_input or not user_input.strip():
|
|
703
|
+
print("\nInvalid input.")
|
|
704
|
+
return None
|
|
705
|
+
|
|
706
|
+
normalized = user_input.lower()
|
|
707
|
+
if normalized in ("y", "yes"):
|
|
708
|
+
print("Yes")
|
|
709
|
+
return "y"
|
|
710
|
+
elif normalized in ("n", "no"):
|
|
711
|
+
print("No")
|
|
712
|
+
if not silent_no:
|
|
713
|
+
print("\nUser terminated.")
|
|
714
|
+
return "n"
|
|
715
|
+
|
|
716
|
+
print("\nInvalid input.")
|
|
717
|
+
return None
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
def print_info(title: str, items: dict[str, Any] = None, footer: str = None) -> None:
|
|
721
|
+
"""Display an informational message with optional structured key-value data.
|
|
722
|
+
|
|
723
|
+
Args:
|
|
724
|
+
|
|
725
|
+
title: Main heading to display.
|
|
726
|
+
items: Optional dict of key-value pairs to display as indented lines.
|
|
727
|
+
footer: Optional trailing message to print after the items.
|
|
728
|
+
"""
|
|
729
|
+
print(f"\n{title}")
|
|
730
|
+
if items:
|
|
731
|
+
for key, val in items.items():
|
|
732
|
+
if isinstance(val, list | tuple | set):
|
|
733
|
+
val = ", ".join(str(v) for v in val)
|
|
734
|
+
print(f" {key}: {val}")
|
|
735
|
+
if footer:
|
|
736
|
+
print(f"\n{footer}")
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
def join_sig_string(sig: list[str]) -> str:
|
|
740
|
+
"""Convert a join signature list to a pipe-delimited string for hashing.
|
|
741
|
+
|
|
742
|
+
Args:
|
|
743
|
+
|
|
744
|
+
sig: List of join path component strings.
|
|
745
|
+
|
|
746
|
+
Returns:
|
|
747
|
+
|
|
748
|
+
Pipe-delimited string of the join signature components.
|
|
749
|
+
"""
|
|
750
|
+
return "|".join(sig)
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
def explain_db_error_nl(err: str) -> str:
|
|
754
|
+
"""Convert a database error message to a human-readable explanation.
|
|
755
|
+
|
|
756
|
+
Args:
|
|
757
|
+
|
|
758
|
+
err: Raw error string returned by the database driver.
|
|
759
|
+
|
|
760
|
+
Returns:
|
|
761
|
+
|
|
762
|
+
Plain-English explanation of the likely cause and suggested fix.
|
|
763
|
+
"""
|
|
764
|
+
e = (err or "").strip()
|
|
765
|
+
low = e.lower()
|
|
766
|
+
if "syntax error at or near" in low:
|
|
767
|
+
return "SQL syntax error. The query is not valid SQL (likely a misplaced keyword, comma, or parenthesis)."
|
|
768
|
+
if "does not exist" in low and "column" in low:
|
|
769
|
+
return "A referenced column name does not exist in that table (wrong column or wrong table alias)."
|
|
770
|
+
if "does not exist" in low and "relation" in low:
|
|
771
|
+
return "A referenced table or view does not exist (wrong table name or missing schema qualification)."
|
|
772
|
+
if "ambiguous" in low and "column" in low:
|
|
773
|
+
return "A column name is ambiguous because multiple joined tables share that column name. It needs a table alias prefix."
|
|
774
|
+
if "operator does not exist" in low:
|
|
775
|
+
return "A comparison uses incompatible types (e.g., comparing text to number/date). Casting or a different operator is needed."
|
|
776
|
+
if "invalid input syntax" in low:
|
|
777
|
+
return "A literal value cannot be parsed into the required type (e.g., invalid date or non-numeric text in a numeric field)."
|
|
778
|
+
if "division by zero" in low:
|
|
779
|
+
return "The query attempted division by zero. It needs a guard (e.g., NULLIF)."
|
|
780
|
+
if "more than one row returned by a subquery" in low:
|
|
781
|
+
return "A subquery used as a scalar returned multiple rows. It needs aggregation or a LIMIT."
|
|
782
|
+
if "permission denied" in low:
|
|
783
|
+
return "Database permissions do not allow this operation on the referenced object."
|
|
784
|
+
return "Database rejected the query. It needs a structural correction (tables/columns/joins/filters)."
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
def is_repair_oscillating(fingerprints: list[str]) -> bool:
|
|
788
|
+
"""Detect AA or ABA oscillation in a sequence of repair fingerprints.
|
|
789
|
+
|
|
790
|
+
Used by both the SQL generation repair loop and the result-repair loop to detect when the LLM is producing the same output repeatedly (AA) or cycling between two outputs (ABA or ABAB).
|
|
791
|
+
|
|
792
|
+
Args:
|
|
793
|
+
|
|
794
|
+
fingerprints: Ordered list of fingerprint strings, one per repair iteration. Each fingerprint is typically a SQL fingerprint or an issue-signature string.
|
|
795
|
+
|
|
796
|
+
Returns:
|
|
797
|
+
|
|
798
|
+
True if the last two entries are identical (AA) or the last three form an A-B-A cycle.
|
|
799
|
+
"""
|
|
800
|
+
n = len(fingerprints)
|
|
801
|
+
if n >= 2 and fingerprints[-1] == fingerprints[-2]:
|
|
802
|
+
return True
|
|
803
|
+
if n >= 3 and fingerprints[-1] == fingerprints[-3]:
|
|
804
|
+
return True
|
|
805
|
+
return False
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
def normalize_op(op: str) -> str:
|
|
809
|
+
"""Normalize a comparison operator string to its canonical form.
|
|
810
|
+
|
|
811
|
+
Converts aliases such as '==', '<>', 'ne', 'gte', and others to standard SQL operators.
|
|
812
|
+
|
|
813
|
+
Args:
|
|
814
|
+
|
|
815
|
+
op: Raw operator string as provided by the LLM.
|
|
816
|
+
|
|
817
|
+
Returns:
|
|
818
|
+
|
|
819
|
+
Lowercased canonical operator string (for example '=', '!=', or '>=' ).
|
|
820
|
+
"""
|
|
821
|
+
op_lower = re.sub(r"\s+", " ", op.lower().strip())
|
|
822
|
+
mapping = {
|
|
823
|
+
"==": "=",
|
|
824
|
+
"<>": "!=",
|
|
825
|
+
"ne": "!=",
|
|
826
|
+
"eq": "=",
|
|
827
|
+
"gt": ">",
|
|
828
|
+
"lt": "<",
|
|
829
|
+
"ge": ">=",
|
|
830
|
+
"le": "<=",
|
|
831
|
+
"gte": ">=",
|
|
832
|
+
"lte": "<=",
|
|
833
|
+
}
|
|
834
|
+
return mapping.get(op_lower, op_lower)
|