sqlas 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlas/__init__.py +69 -0
- sqlas/context.py +268 -0
- sqlas/core.py +208 -0
- sqlas/correctness.py +289 -0
- sqlas/evaluate.py +218 -0
- sqlas/production.py +74 -0
- sqlas/py.typed +0 -0
- sqlas/quality.py +172 -0
- sqlas/response.py +133 -0
- sqlas/runner.py +133 -0
- sqlas/safety.py +76 -0
- sqlas-1.1.0.dist-info/METADATA +322 -0
- sqlas-1.1.0.dist-info/RECORD +16 -0
- sqlas-1.1.0.dist-info/WHEEL +5 -0
- sqlas-1.1.0.dist-info/licenses/LICENSE +21 -0
- sqlas-1.1.0.dist-info/top_level.txt +1 -0
sqlas/correctness.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core SQL Correctness Metrics.
|
|
3
|
+
- Execution Accuracy (output + structure + efficiency)
|
|
4
|
+
- Syntax Validity (sqlglot parse)
|
|
5
|
+
- Semantic Equivalence (LLM-as-Judge)
|
|
6
|
+
- Result Set Similarity (Jaccard on normalized result sets)
|
|
7
|
+
|
|
8
|
+
Author: SQLAS Contributors
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import time
|
|
13
|
+
import sqlite3
|
|
14
|
+
|
|
15
|
+
import sqlglot
|
|
16
|
+
|
|
17
|
+
from sqlas.core import LLMJudge, _parse_score
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
# Default SQL execution timeout in seconds
_DEFAULT_TIMEOUT_S = 30
_PROGRESS_INTERVAL = 1_000_000  # check timeout every N SQLite VM instructions


def _connect_readonly(db_path: str, timeout_s: int = _DEFAULT_TIMEOUT_S) -> sqlite3.Connection:
    """Open a read-only SQLite connection with a wall-clock timeout guard.

    The path is percent-encoded before it is embedded in the ``file:`` URI.
    Characters with URI meaning (``?``, ``#``, ``%``) are then treated as part
    of the filename. Without encoding, a path such as ``runs#1.db`` would be
    cut at ``#``. The ``mode=ro`` parameter would also be lost, so a different
    database would be opened in the default read-write-create mode.

    Args:
        db_path: Filesystem path to the SQLite database.
        timeout_s: Seconds allowed. Passed to ``sqlite3.connect`` as its lock
            wait timeout, and also enforced by a progress handler.

    Returns:
        An open read-only ``sqlite3.Connection``. The progress handler aborts
        any running statement once more than ``timeout_s`` seconds have passed
        since this connection was opened. The budget is per connection, not
        per statement.
    """
    from urllib.parse import quote

    # str() mirrors the original f-string formatting, so path-like objects still work.
    uri = f"file:{quote(str(db_path))}?mode=ro"
    conn = sqlite3.connect(uri, uri=True, timeout=timeout_s)
    started = time.monotonic()

    def _progress_handler():
        # A non-zero return makes SQLite abort the running statement, which
        # surfaces in Python as a sqlite3.DatabaseError.
        return 1 if time.monotonic() - started > timeout_s else 0

    conn.set_progress_handler(_progress_handler, _PROGRESS_INTERVAL)
    return conn
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# ── Helpers ─────────────────────────────────────────────────────────────────
|
|
41
|
+
|
|
42
|
+
def _extract_row_numbers(row) -> list[float]:
    """Return the numeric cells of *row* as floats rounded to 2 dp, ascending.

    Non-numeric cells (strings, ``None``, bytes, ...) are skipped, so label
    columns do not take part in numeric comparison. ``bool`` values count as
    numbers because ``bool`` is a subclass of ``int``.
    """
    numbers = [round(float(cell), 2) for cell in row if isinstance(cell, (int, float))]
    numbers.sort()
    return numbers
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _values_found_in(needle: list[float], haystack: list[float], tol: float = 0.5) -> float:
    """Return the fraction of *needle* values that have a partner in *haystack*.

    Each needle value is paired greedily with the nearest still-unused
    haystack value (the first one on ties). The pair counts as a hit when the
    gap is within *tol*, and a matched haystack value is consumed so it cannot
    be reused. An empty *needle* scores 1.0.
    """
    if not needle:
        return 1.0

    pool = list(haystack)
    hits = 0
    for target in needle:
        closest, closest_gap = None, float("inf")
        for pos, candidate in enumerate(pool):
            gap = abs(target - candidate)
            if gap < closest_gap:
                closest, closest_gap = pos, gap
        if closest is not None and closest_gap <= tol:
            del pool[closest]
            hits += 1
    return hits / len(needle)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _row_values_match(pred_nums: list[float], gold_nums: list[float], tol: float = 0.5) -> float:
    """Score how well one predicted row's numbers match one gold row's numbers.

    - Neither side has numbers: 1.0.
    - Only the prediction has numbers: 0.8.
    - Only the gold row has numbers: 0.0.
    - The prediction has fewer numbers than gold: fraction of the predicted
      values found in gold, snapped to 1.0 at >= 0.99. A prediction that
      returns a subset of the gold's numeric columns therefore gets full
      credit.
    - Otherwise: fraction of gold values found in the prediction. Extra
      predicted columns are not penalised.
    """
    if not gold_nums:
        return 0.8 if pred_nums else 1.0
    if not pred_nums:
        return 0.0
    if len(pred_nums) >= len(gold_nums):
        return _values_found_in(gold_nums, pred_nums, tol)
    coverage = _values_found_in(pred_nums, gold_nums, tol)
    return coverage if coverage < 0.99 else 1.0
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _match_result_sets(pred_rows: list, gold_rows: list) -> float:
    """Average best-match score of each gold row against the predicted rows.

    Every gold row is paired greedily with the unused predicted row that gives
    the highest ``_row_values_match`` score. A predicted row can be claimed
    only once. Only numeric cells are compared, so rows that contain no
    numbers all match each other perfectly.

    Returns 1.0 when both sides are empty and 0.5 when only gold is empty.
    """
    if not gold_rows:
        return 0.5 if pred_rows else 1.0

    pred_vectors = list(map(_extract_row_numbers, pred_rows))
    claimed: set[int] = set()
    total = 0.0
    for gold_row in gold_rows:
        gold_vec = _extract_row_numbers(gold_row)
        top_score, top_idx = 0.0, None
        for idx, pred_vec in enumerate(pred_vectors):
            if idx in claimed:
                continue
            candidate = _row_values_match(pred_vec, gold_vec)
            if candidate > top_score:
                top_score, top_idx = candidate, idx
        if top_idx is not None:
            claimed.add(top_idx)
        total += top_score
    return total / len(gold_rows)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# ── Public API ──────────────────────────────────────────────────────────────
|
|
98
|
+
|
|
99
|
+
def execution_accuracy(generated_sql: str, gold_sql: str, db_path: str) -> tuple[float, dict]:
    """
    Semantic execution accuracy.

    Formula: 60% Output Match + 20% Structure Match + 20% Efficiency

    Output Match: Row-by-row numeric comparison. Ignores label differences
                  (0 vs 'Male'), tolerates ROUND, handles extra columns.
    Structure:    Same row count (partial credit = smaller / larger count).
    Efficiency:   Generated query speed vs gold query speed, capped at 1.0.

    Each query runs on its own read-only connection. ``_connect_readonly``
    starts its timeout clock when the connection is opened. On a shared
    connection, a slow gold query would use up the generated query's budget
    and get it interrupted through no fault of its own.

    Args:
        generated_sql: SQL produced by the agent
        gold_sql: Ground-truth SQL
        db_path: Path to SQLite database (or any sqlite3-compatible path)

    Returns:
        (score, details) where score is 0.0–1.0. On failure the score is 0.0
        and details carries an ``"error"`` message. Query execution failures
        also carry ``"failed_query"`` (``"gold"`` or ``"generated"``).
    """
    measured: dict[str, tuple[list, float]] = {}
    for label, sql in (("gold", gold_sql), ("generated", generated_sql)):
        try:
            conn = _connect_readonly(db_path)
        except Exception as e:
            return 0.0, {"error": f"db_connect_failed: {e}"}
        try:
            start = time.perf_counter()
            rows = conn.execute(sql).fetchall()
            # Floor at 0.01 ms so the efficiency ratio never divides by zero.
            elapsed_ms = max((time.perf_counter() - start) * 1000, 0.01)
        except Exception as e:
            return 0.0, {"error": str(e), "failed_query": label}
        finally:
            conn.close()
        measured[label] = (rows, elapsed_ms)

    gold_result, gold_time = measured["gold"]
    pred_result, pred_time = measured["generated"]

    output_score = _match_result_sets(pred_result, gold_result)

    if len(pred_result) == len(gold_result):
        struct_score = 1.0
    elif pred_result and gold_result:
        struct_score = min(len(pred_result), len(gold_result)) / max(len(pred_result), len(gold_result))
    else:
        # One side returned rows and the other returned none.
        struct_score = 0.0

    # pred_time is floored at 0.01 ms above, so the division is always safe.
    efficiency = min(gold_time / pred_time, 1.0)

    final = round(0.6 * output_score + 0.2 * struct_score + 0.2 * efficiency, 4)

    return final, {
        "output_score": round(output_score, 4),
        "structural_score": round(struct_score, 4),
        "efficiency_score": round(efficiency, 4),
        "predicted_rows": len(pred_result),
        "gold_rows": len(gold_result),
    }
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def syntax_valid(sql: str, dialect: str = "sqlite") -> float:
    """Return 1.0 if *sql* parses with sqlglot in *dialect*, else 0.0.

    Any exception raised while parsing counts as invalid. So does an empty
    parse result, or a first statement that parses to ``None``.
    """
    try:
        statements = sqlglot.parse(sql, dialect=dialect)
    except Exception:
        return 0.0
    parsed_ok = bool(statements) and statements[0] is not None
    return 1.0 if parsed_ok else 0.0
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def semantic_equivalence(
    question: str,
    generated_sql: str,
    llm_judge: LLMJudge,
    gold_sql: str | None = None,
) -> tuple[float, dict]:
    """
    Ask an LLM judge whether the generated SQL answers the user's question.

    The judge is told to tolerate alias differences, join variations and
    CASE WHEN labels. When *gold_sql* is supplied, it is embedded in the
    prompt as a reference.

    Args:
        question: User's natural language question
        generated_sql: SQL produced by the agent
        llm_judge: Callable taking the prompt string and returning the judge's reply
        gold_sql: Optional reference SQL for comparison

    Returns:
        (score, details) where score is 0.0–1.0. If the judge call raises, the
        score is 0.0 and details holds ``"error"``. Otherwise details holds
        ``"reasoning"``.
    """
    if gold_sql:
        gold_section = f"\n**Reference SQL:**\n```sql\n{gold_sql}\n```"
    else:
        gold_section = ""

    prompt = f"""You are a SQL expert judge. Evaluate if the Generated SQL correctly answers the User Question.
{gold_section}

**User Question:** {question}

**Generated SQL:**
```sql
{generated_sql}
```

Evaluate:
1. Does the SQL retrieve the correct data to answer the question?
2. Are the right tables, columns, and filters used?
3. Are aggregations applied correctly?
4. Are JOINs correct and necessary?

Score 0.0 to 1.0:
- 1.0: Perfectly answers the question
- 0.7-0.9: Minor issues not affecting the core answer
- 0.4-0.6: Partially correct, missing key elements
- 0.0-0.3: Wrong approach or major errors

Respond EXACTLY:
Semantic_Score: [score]
Reasoning: [one sentence]"""

    try:
        verdict = llm_judge(prompt)
    except Exception as exc:
        logger.warning("LLM judge failed in semantic_equivalence: %s", exc)
        return 0.0, {"error": str(exc)}

    score, reasoning = _parse_score(verdict, "Semantic_Score")
    return score, {"reasoning": reasoning}
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def result_set_similarity(
    generated_sql: str,
    gold_sql: str,
    db_path: str,
) -> tuple[float, dict]:
    """
    RAGAS Answer Similarity for SQL agents.

    Score = 80% Jaccard similarity + 20% column-count agreement.

    The Jaccard part compares the sets of normalized result rows:
    - floats are rounded to 2 dp;
    - strings are stripped and lower-cased;
    - duplicate rows collapse.

    Each query is executed exactly once. Its rows and column metadata are read
    from the same cursor.

    Args:
        generated_sql: SQL produced by the agent
        gold_sql: Ground-truth SQL
        db_path: Path to SQLite database

    Returns:
        (similarity score 0.0–1.0, details dict). On failure the score is 0.0
        and details holds ``"error"``.
    """
    try:
        conn = _connect_readonly(db_path)
    except Exception as e:
        return 0.0, {"error": f"db_connect_failed: {e}"}
    try:
        # One execution per query: description is available right after
        # execute(), so there is no need to run the statement a second time.
        gold_cur = conn.execute(gold_sql)
        gold_desc = gold_cur.description
        gold_rows = gold_cur.fetchall()

        pred_cur = conn.execute(generated_sql)
        pred_desc = pred_cur.description
        pred_rows = pred_cur.fetchall()
    except Exception as e:
        return 0.0, {"error": str(e)}
    finally:
        conn.close()

    def _normalize_row(row):
        cells = []
        for v in row:
            if isinstance(v, float):
                cells.append(round(v, 2))
            elif isinstance(v, str):
                cells.append(v.strip().lower())
            else:
                cells.append(v)
        return tuple(cells)

    gold_set = {_normalize_row(r) for r in gold_rows}
    pred_set = {_normalize_row(r) for r in pred_rows}

    union = gold_set | pred_set
    intersection = gold_set & pred_set

    jaccard = len(intersection) / len(union) if union else 1.0

    # Column count match: full credit when equal (including both zero),
    # otherwise the ratio of the narrower to the wider result.
    gold_cols = len(gold_desc) if gold_desc else 0
    pred_cols = len(pred_desc) if pred_desc else 0
    widest = max(gold_cols, pred_cols)
    if gold_cols == pred_cols or widest == 0:
        col_match = 1.0
    else:
        col_match = min(gold_cols, pred_cols) / widest

    score = round(0.8 * jaccard + 0.2 * col_match, 4)

    return score, {
        "jaccard": round(jaccard, 4),
        "column_match": round(col_match, 4),
        "generated_row_count": len(pred_rows),
        "gold_row_count": len(gold_rows),
        "intersection_size": len(intersection),
        "union_size": len(union),
    }
|
sqlas/evaluate.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main evaluation API — single query and batch evaluation.
|
|
3
|
+
|
|
4
|
+
Author: SQLAS Contributors
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
from sqlas.core import SQLASScores, TestCase, LLMJudge, WEIGHTS, compute_composite_score
|
|
11
|
+
from sqlas.correctness import execution_accuracy, syntax_valid, semantic_equivalence, result_set_similarity
|
|
12
|
+
from sqlas.quality import sql_quality, schema_compliance, complexity_match
|
|
13
|
+
from sqlas.production import data_scan_efficiency, execution_result
|
|
14
|
+
from sqlas.response import faithfulness, answer_relevance, answer_completeness, fluency
|
|
15
|
+
from sqlas.safety import safety_score, read_only_compliance
|
|
16
|
+
from sqlas.context import context_precision, context_recall, entity_recall, noise_robustness
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def evaluate(
    question: str,
    generated_sql: str,
    llm_judge: LLMJudge,
    gold_sql: str | None = None,
    db_path: str | None = None,
    response: str | None = None,
    result_data: dict | None = None,
    valid_tables: set[str] | None = None,
    valid_columns: dict[str, set[str]] | None = None,
    schema_context: str = "",
    expected_nonempty: bool = True,
    pii_columns: list[str] | None = None,
    weights: dict | None = None,
) -> SQLASScores:
    """
    Score one SQL agent interaction across every SQLAS metric.

    Metrics whose inputs are missing are skipped (left at their defaults) or
    given the fallback values noted below.

    Args:
        question: User's natural language question
        generated_sql: SQL produced by the agent (must be a non-empty string)
        llm_judge: Callable (prompt: str) -> str used by LLM-as-judge metrics
        gold_sql: Ground-truth SQL. Enables the context metrics, and, together
            with db_path, execution accuracy and result-set similarity.
        db_path: Path to SQLite database (must exist when given)
        response: Agent's natural language answer. With result_data, enables
            faithfulness, relevance, completeness and fluency.
        result_data: Query result dict: {columns, rows, row_count, execution_time_ms}
        valid_tables: Set of valid table names (with valid_columns, enables schema compliance)
        valid_columns: Dict of {table: {col1, col2}}
        schema_context: Brief schema text for the SQL quality judge
        expected_nonempty: Whether a non-empty result is expected
        pii_columns: Custom PII column names for the safety check
        weights: Custom weight dict (defaults to SQLAS production weights).
            A warning is logged if the weights do not sum to ~1.0.

    Returns:
        SQLASScores with all metrics and overall_score. On invalid input, a
        default SQLASScores with details["error"] set.
    """
    # ── Input validation ────────────────────────────────────────────────
    if not (generated_sql and isinstance(generated_sql, str)):
        logger.error("generated_sql must be a non-empty string")
        failed = SQLASScores()
        failed.details["error"] = "generated_sql is empty or invalid"
        return failed

    if db_path and not os.path.exists(db_path):
        logger.error("db_path does not exist: %s", db_path)
        failed = SQLASScores()
        failed.details["error"] = f"db_path not found: {db_path}"
        return failed

    if weights:
        total_weight = sum(weights.values())
        if abs(total_weight - 1.0) > 0.01:
            logger.warning("Custom weights sum to %.4f (expected ~1.0)", total_weight)

    scores = SQLASScores()

    def _store(attr, outcome, detail_key=None):
        # Unpack a (value, details) metric result onto `scores`; details are
        # filed under `detail_key` when it differs from the attribute name.
        value, info = outcome
        setattr(scores, attr, value)
        scores.details[detail_key or attr] = info
        return info

    # ── 1. Core Correctness ─────────────────────────────────────────────
    scores.syntax_valid = syntax_valid(generated_sql)

    if gold_sql and db_path:
        exec_info = _store("execution_accuracy", execution_accuracy(generated_sql, gold_sql, db_path))
        # Efficiency (VES) reuses the timing measured by execution_accuracy.
        scores.efficiency_score = exec_info.get("efficiency_score", 0.0)
    else:
        fallback = 1.0 if result_data else 0.0
        scores.execution_accuracy = fallback
        scores.efficiency_score = fallback

    _store("semantic_equivalence", semantic_equivalence(question, generated_sql, llm_judge, gold_sql))

    # ── 2. Context Quality (RAGAS-mapped) ───────────────────────────────
    if gold_sql:
        _store("context_precision", context_precision(generated_sql, gold_sql))
        _store("context_recall", context_recall(generated_sql, gold_sql))
        _store("entity_recall", entity_recall(generated_sql, gold_sql))
        _store("noise_robustness", noise_robustness(generated_sql, gold_sql, valid_tables, valid_columns))
        if db_path:
            _store("result_set_similarity", result_set_similarity(generated_sql, gold_sql, db_path))

    # ── 3. SQL Quality ──────────────────────────────────────────────────
    if valid_tables and valid_columns:
        _store("schema_compliance", schema_compliance(generated_sql, valid_tables, valid_columns))
    else:
        scores.schema_compliance = 1.0  # can't check without schema

    _store("sql_quality", sql_quality(question, generated_sql, llm_judge, schema_context))
    _store("complexity_match", complexity_match(question, generated_sql, llm_judge))

    # ── 4. Production Execution ─────────────────────────────────────────
    outcome = execution_result(result_data, expected_nonempty)
    for field in (
        "execution_success",
        "execution_time_ms",
        "result_row_count",
        "empty_result_penalty",
        "row_explosion_detected",
    ):
        setattr(scores, field, outcome[field])

    _store(
        "data_scan_efficiency",
        data_scan_efficiency(generated_sql, scores.result_row_count),
        "data_scan",
    )

    # ── 5. Response Quality ─────────────────────────────────────────────
    if response and result_data:
        # The judges see the column list plus at most the first 15 rows.
        header = f"Columns: {result_data.get('columns', [])}\n"
        result_preview = header + "".join(f"{row}\n" for row in result_data.get("rows", [])[:15])

        _store("faithfulness", faithfulness(question, response, result_preview, llm_judge))
        _store("answer_relevance", answer_relevance(question, response, llm_judge))
        _store("answer_completeness", answer_completeness(question, response, result_preview, llm_judge))
        _store("fluency", fluency(response, llm_judge))

    # ── 6. Safety ───────────────────────────────────────────────────────
    scores.read_only_compliance = read_only_compliance(generated_sql)
    _store("safety_score", safety_score(generated_sql, response or "", pii_columns), "safety")

    # ── Composite ───────────────────────────────────────────────────────
    scores.overall_score = compute_composite_score(scores, weights)
    return scores
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def evaluate_batch(
    test_cases: list[dict],
    llm_judge: LLMJudge,
    db_path: str | None = None,
    valid_tables: set[str] | None = None,
    valid_columns: dict[str, set[str]] | None = None,
    schema_context: str = "",
    pii_columns: list[str] | None = None,
    weights: dict | None = None,
) -> list[SQLASScores]:
    """
    Run ``evaluate`` over each test case, in order.

    Each dict in test_cases must have ``question`` and ``generated_sql``. It
    may also have ``gold_sql``, ``response``, ``result_data``,
    ``expected_nonempty`` (default True) and ``schema_context``. A per-case
    ``schema_context`` overrides the batch-level one. All other settings are
    shared across the batch.

    Returns one SQLASScores per test case, in input order.
    """
    return [
        evaluate(
            question=case["question"],
            generated_sql=case["generated_sql"],
            llm_judge=llm_judge,
            gold_sql=case.get("gold_sql"),
            db_path=db_path,
            response=case.get("response"),
            result_data=case.get("result_data"),
            valid_tables=valid_tables,
            valid_columns=valid_columns,
            schema_context=case.get("schema_context", schema_context),
            expected_nonempty=case.get("expected_nonempty", True),
            pii_columns=pii_columns,
            weights=weights,
        )
        for case in test_cases
    ]
|
sqlas/production.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Production Execution Metrics.
|
|
3
|
+
- Data Scan Efficiency (full scan detection)
|
|
4
|
+
- Execution Result (success, empty result, row explosion)
|
|
5
|
+
|
|
6
|
+
Author: SQLAS Contributors
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def data_scan_efficiency(sql: str, result_row_count: int = 0) -> tuple[float, dict]:
    """
    Heuristically score how economically *sql* accesses data.

    Starts at 1.0 and subtracts a fixed penalty per detected pattern:
      - SELECT * (also SELECT DISTINCT * / SELECT ALL *): -0.2
      - No WHERE, GROUP BY or LIMIT (potential full scan): -0.3
      - JOIN producing more than 10,000 rows (possible cartesian product): -0.3
      - No GROUP BY or LIMIT with more than 100 rows: -0.1

    Keywords are matched case-insensitively as whole words. Identifiers such
    as ``credit_limit`` or ``joined_at`` are therefore not mistaken for LIMIT
    or JOIN. Matching is purely textual: keywords inside string literals or
    comments still count.

    Returns:
        (score clamped to >= 0.0, {"issues": [...]}). issues is ["none"] when
        nothing was flagged.
    """
    upper = sql.upper()
    issues = []
    score = 1.0

    # \s+ also catches "SELECT  *" and "SELECT\n*".
    if re.search(r"\bSELECT\s+(?:DISTINCT\s+|ALL\s+)?\*", upper):
        issues.append("SELECT * — should specify columns")
        score -= 0.2

    has_where = re.search(r"\bWHERE\b", upper) is not None
    has_group = re.search(r"\bGROUP\s+BY\b", upper) is not None
    has_limit = re.search(r"\bLIMIT\b", upper) is not None
    has_join = re.search(r"\bJOIN\b", upper) is not None

    if not has_where and not has_group and not has_limit:
        issues.append("No WHERE, GROUP BY, or LIMIT — potential full scan")
        score -= 0.3

    if result_row_count > 10000 and has_join:
        issues.append(f"Large result ({result_row_count} rows) from JOIN — possible cartesian product")
        score -= 0.3

    if not has_group and not has_limit and result_row_count > 100:
        issues.append("No LIMIT on detail query returning many rows")
        score -= 0.1

    return max(score, 0.0), {"issues": issues or ["none"]}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def execution_result(
|
|
48
|
+
data: dict | None,
|
|
49
|
+
expected_nonempty: bool = True,
|
|
50
|
+
) -> dict:
|
|
51
|
+
"""
|
|
52
|
+
Evaluate execution outcome.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
data: Query result dict with keys: row_count, execution_time_ms, truncated
|
|
56
|
+
expected_nonempty: Whether non-empty result is expected
|
|
57
|
+
"""
|
|
58
|
+
if data is None:
|
|
59
|
+
return {
|
|
60
|
+
"execution_success": 0.0,
|
|
61
|
+
"empty_result_penalty": 0.0,
|
|
62
|
+
"row_explosion_detected": False,
|
|
63
|
+
"execution_time_ms": 0,
|
|
64
|
+
"result_row_count": 0,
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
row_count = data.get("row_count", 0)
|
|
68
|
+
return {
|
|
69
|
+
"execution_success": 1.0,
|
|
70
|
+
"execution_time_ms": data.get("execution_time_ms", 0),
|
|
71
|
+
"result_row_count": row_count,
|
|
72
|
+
"empty_result_penalty": 0.0 if (expected_nonempty and row_count == 0) else 1.0,
|
|
73
|
+
"row_explosion_detected": row_count > 50000,
|
|
74
|
+
}
|
sqlas/py.typed
ADDED
|
File without changes
|