sqlas 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlas/correctness.py ADDED
@@ -0,0 +1,289 @@
1
+ """
2
+ Core SQL Correctness Metrics.
3
+ - Execution Accuracy (output + structure + efficiency)
4
+ - Syntax Validity (sqlglot parse)
5
+ - Semantic Equivalence (LLM-as-Judge)
6
+ - Result Set Similarity (Jaccard on normalized result sets)
7
+
8
+ Author: SQLAS Contributors
9
+ """
10
+
11
+ import logging
12
+ import time
13
+ import sqlite3
14
+
15
+ import sqlglot
16
+
17
+ from sqlas.core import LLMJudge, _parse_score
18
+
19
logger = logging.getLogger(__name__)

# Default SQL execution timeout in seconds
_DEFAULT_TIMEOUT_S = 30
_PROGRESS_INTERVAL = 1_000_000  # check timeout every N SQLite VM instructions


def _connect_readonly(db_path: str, timeout_s: int = _DEFAULT_TIMEOUT_S) -> sqlite3.Connection:
    """
    Open a read-only SQLite connection with a wall-clock timeout guard.

    Args:
        db_path: Filesystem path to the SQLite database. The path is turned
            into a percent-encoded ``file:`` URI, so characters that are
            special in SQLite URIs (``?``, ``#``, ``%``, spaces) are opened
            verbatim instead of being parsed as query/fragment delimiters.
        timeout_s: Seconds after which running statements are aborted via the
            progress handler; also passed to ``sqlite3.connect`` as the
            lock-wait timeout.

    Returns:
        A connection opened with ``mode=ro`` (writes are refused by SQLite).

    Note:
        The deadline starts when the connection is opened and is shared by
        every statement executed on it. Callers that need an independent
        budget per query should open one connection per query.
    """
    from pathlib import Path

    # as_uri() requires an absolute path and percent-encodes it; appending
    # the query string afterwards keeps mode=ro from being swallowed by a
    # '#' or '?' that happens to appear in the file name.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True, timeout=timeout_s)
    deadline = time.monotonic() + timeout_s

    def _progress_handler() -> int:
        # Called every _PROGRESS_INTERVAL VM instructions; a non-zero return
        # tells SQLite to abort the running statement.
        return 1 if time.monotonic() > deadline else 0

    conn.set_progress_handler(_progress_handler, _PROGRESS_INTERVAL)
    return conn
38
+
39
+
40
+ # ── Helpers ─────────────────────────────────────────────────────────────────
41
+
42
def _extract_row_numbers(row) -> list[float]:
    """
    Collect the numeric cells of one result row.

    Every ``int``/``float`` cell is converted to float and rounded to two
    decimals; other cells (text, blobs, NULL) are ignored. The values are
    returned in ascending order so rows can be compared column-order-free.
    """
    numbers: list[float] = []
    for cell in row:
        if isinstance(cell, (int, float)):
            numbers.append(round(float(cell), 2))
    numbers.sort()
    return numbers
44
+
45
+
46
def _values_found_in(needle: list[float], haystack: list[float], tol: float = 0.5) -> float:
    """
    Return the fraction of ``needle`` values that have a partner in ``haystack``.

    Matching is greedy and one-to-one: each needle value, in order, takes the
    closest haystack value not yet taken (the first one on ties), and the
    pair counts only when the gap is at most ``tol``. An empty ``needle``
    scores 1.0.
    """
    if not needle:
        return 1.0

    unused = list(haystack)
    found = 0
    for value in needle:
        pick = None
        pick_gap = float("inf")
        for idx, candidate in enumerate(unused):
            gap = abs(value - candidate)
            if gap < pick_gap:
                pick, pick_gap = idx, gap
        if pick is not None and pick_gap <= tol:
            unused.pop(pick)
            found += 1
    return found / len(needle)
61
+
62
+
63
def _row_values_match(pred_nums: list[float], gold_nums: list[float], tol: float = 0.5) -> float:
    """
    Score how well one predicted row's numbers match one gold row's numbers.

    - no numbers on either side -> 1.0
    - gold has no numbers but the prediction does -> 0.8 (partial credit)
    - prediction has no numbers but gold does -> 0.0
    - prediction has fewer numbers than gold -> fraction of the prediction's
      numbers found in gold, lifted to 1.0 once it reaches 0.99 (a row that
      dropped numeric columns, e.g. a code shown as a text label, can still
      score fully)
    - otherwise -> fraction of gold's numbers found in the prediction
    """
    if not gold_nums:
        return 0.8 if pred_nums else 1.0
    if not pred_nums:
        return 0.0
    if len(pred_nums) >= len(gold_nums):
        return _values_found_in(gold_nums, pred_nums, tol)
    coverage = _values_found_in(pred_nums, gold_nums, tol)
    return coverage if coverage < 0.99 else 1.0
74
+
75
+
76
def _match_result_sets(pred_rows: list, gold_rows: list) -> float:
    """
    Average, over gold rows, of the best numeric match among predicted rows.

    Each gold row (in order) greedily claims the unclaimed predicted row with
    the highest ``_row_values_match`` score (first one on ties); a predicted
    row can be claimed once. Gold rows with no positive-scoring partner
    contribute 0. Extra predicted rows are not penalised here.

    Edge cases: no gold rows -> 1.0 if there are no predicted rows, else 0.5.

    NOTE: this compares every gold row against every unclaimed predicted
    row, so cost grows with ``len(gold_rows) * len(pred_rows)``.
    """
    if not gold_rows:
        return 0.5 if pred_rows else 1.0

    pred_numeric = [_extract_row_numbers(row) for row in pred_rows]
    claimed: set[int] = set()
    total = 0.0

    for gold_row in gold_rows:
        gold_numeric = _extract_row_numbers(gold_row)
        winner, winner_score = -1, 0.0
        for idx, pred_vals in enumerate(pred_numeric):
            if idx in claimed:
                continue
            candidate_score = _row_values_match(pred_vals, gold_numeric)
            if candidate_score > winner_score:
                winner, winner_score = idx, candidate_score
        if winner >= 0:
            claimed.add(winner)
        total += winner_score

    return total / len(gold_rows)
95
+
96
+
97
+ # ── Public API ──────────────────────────────────────────────────────────────
98
+
99
def _timed_fetchall(db_path: str, sql: str) -> tuple[list, float, dict | None]:
    """
    Run ``sql`` on its own read-only connection and time it.

    Returns:
        ``(rows, elapsed_ms, error)``. On success ``error`` is None and
        ``elapsed_ms`` (wall time of execute + fetchall) is floored at 0.01
        so it is safe to divide by. On failure ``rows`` is empty and
        ``error`` is the details dict ``execution_accuracy`` returns.
    """
    try:
        conn = _connect_readonly(db_path)
    except Exception as e:
        return [], 0.0, {"error": f"db_connect_failed: {e}"}
    try:
        start = time.perf_counter()
        rows = conn.execute(sql).fetchall()
        elapsed_ms = max((time.perf_counter() - start) * 1000, 0.01)
    except Exception as e:
        return [], 0.0, {"error": str(e)}
    finally:
        conn.close()
    return rows, elapsed_ms, None


def execution_accuracy(generated_sql: str, gold_sql: str, db_path: str) -> tuple[float, dict]:
    """
    Semantic execution accuracy.

    Formula: 60% Output Match + 20% Structure Match + 20% Efficiency

    Output Match: Row-by-row numeric comparison. Ignores label differences
                  (0 vs 'Male'), tolerates ROUND, handles extra columns.
    Structure:    Same row count (partial credit = smaller / larger count).
    Efficiency:   Gold query time / generated query time, capped at 1.0.

    Args:
        generated_sql: SQL produced by the agent
        gold_sql: Ground-truth SQL
        db_path: Path to SQLite database (or any sqlite3-compatible path)

    Returns:
        (score, details) where score is 0.0–1.0. If connecting or running
        either query fails, returns (0.0, {"error": ...}).
    """
    # One connection per query: the timeout guard from _connect_readonly
    # counts from connection open, so a shared connection would let a slow
    # gold query use up the generated query's time budget. It also stops the
    # generated query from reusing pages the gold query pulled into SQLite's
    # per-connection page cache, which would skew the efficiency ratio.
    gold_result, gold_time, error = _timed_fetchall(db_path, gold_sql)
    if error is not None:
        return 0.0, error
    pred_result, pred_time, error = _timed_fetchall(db_path, generated_sql)
    if error is not None:
        return 0.0, error

    output_score = _match_result_sets(pred_result, gold_result)

    if len(pred_result) == len(gold_result):
        struct_score = 1.0
    elif pred_result and gold_result:
        struct_score = min(len(pred_result), len(gold_result)) / max(len(pred_result), len(gold_result))
    else:
        struct_score = 0.0

    # pred_time is floored at 0.01 ms by _timed_fetchall, so this never
    # divides by zero; being faster than gold earns no extra credit.
    efficiency = min(gold_time / pred_time, 1.0)

    final = round(0.6 * output_score + 0.2 * struct_score + 0.2 * efficiency, 4)

    return final, {
        "output_score": round(output_score, 4),
        "structural_score": round(struct_score, 4),
        "efficiency_score": round(efficiency, 4),
        "predicted_rows": len(pred_result),
        "gold_rows": len(gold_result),
    }
155
+
156
+
157
def syntax_valid(sql: str, dialect: str = "sqlite") -> float:
    """
    Report whether ``sql`` parses with sqlglot in the given dialect.

    Returns 1.0 when parsing succeeds and yields a first statement, 0.0 when
    sqlglot raises or produces no (or an empty) first statement.
    """
    try:
        statements = sqlglot.parse(sql, dialect=dialect)
    except Exception:
        return 0.0
    if not statements or statements[0] is None:
        return 0.0
    return 1.0
164
+
165
+
166
def semantic_equivalence(
    question: str,
    generated_sql: str,
    llm_judge: LLMJudge,
    gold_sql: str | None = None,
) -> tuple[float, dict]:
    """
    Ask an LLM judge whether the generated SQL answers the user's question.

    The prompt asks the judge to check data, tables/columns/filters,
    aggregations and JOINs, and to reply as ``Semantic_Score: ...`` /
    ``Reasoning: ...``; the reply is parsed with ``_parse_score``. When
    ``gold_sql`` is given it is embedded in the prompt as a reference.

    Args:
        question: User's natural language question
        generated_sql: SQL produced by the agent
        llm_judge: Callable taking the prompt string and returning the reply
        gold_sql: Optional reference SQL for comparison

    Returns:
        (score, details). ``score`` comes from ``_parse_score`` (the prompt
        requests 0.0–1.0) and details holds ``reasoning``. If the judge call
        raises, returns (0.0, {"error": ...}) and logs a warning.
    """
    reference_block = ""
    if gold_sql:
        reference_block = f"\n**Reference SQL:**\n```sql\n{gold_sql}\n```"

    prompt = f"""You are a SQL expert judge. Evaluate if the Generated SQL correctly answers the User Question.
{reference_block}

**User Question:** {question}

**Generated SQL:**
```sql
{generated_sql}
```

Evaluate:
1. Does the SQL retrieve the correct data to answer the question?
2. Are the right tables, columns, and filters used?
3. Are aggregations applied correctly?
4. Are JOINs correct and necessary?

Score 0.0 to 1.0:
- 1.0: Perfectly answers the question
- 0.7-0.9: Minor issues not affecting the core answer
- 0.4-0.6: Partially correct, missing key elements
- 0.0-0.3: Wrong approach or major errors

Respond EXACTLY:
Semantic_Score: [score]
Reasoning: [one sentence]"""

    try:
        reply = llm_judge(prompt)
    except Exception as exc:
        logger.warning("LLM judge failed in semantic_equivalence: %s", exc)
        return 0.0, {"error": str(exc)}

    score, reasoning = _parse_score(reply, "Semantic_Score")
    return score, {"reasoning": reasoning}
221
+
222
+
223
def result_set_similarity(
    generated_sql: str,
    gold_sql: str,
    db_path: str,
) -> tuple[float, dict]:
    """
    RAGAS Answer Similarity for SQL agents.

    Score = 80% Jaccard similarity of the normalized result rows
          + 20% column-count agreement (smaller / larger count).
    Normalization rounds floats to 2 decimals and strips + lowercases
    strings. Rows are compared as sets, so duplicate rows collapse.

    Args:
        generated_sql: SQL produced by the agent
        gold_sql: Ground-truth SQL
        db_path: Path to SQLite database

    Returns:
        (similarity score 0.0–1.0, details dict). On connection or query
        failure returns (0.0, {"error": ...}).
    """
    try:
        conn = _connect_readonly(db_path)
    except Exception as e:
        return 0.0, {"error": f"db_connect_failed: {e}"}
    try:
        # Run each query exactly once and read column metadata from the same
        # cursor. Re-executing just to get .description doubles the work and,
        # because the timeout guard is per-connection, shrinks the time left
        # for the queries that follow.
        gold_cursor = conn.execute(gold_sql)
        gold_desc = gold_cursor.description
        gold_rows = gold_cursor.fetchall()
        pred_cursor = conn.execute(generated_sql)
        pred_desc = pred_cursor.description
        pred_rows = pred_cursor.fetchall()
    except Exception as e:
        return 0.0, {"error": str(e)}
    finally:
        conn.close()

    def _normalize_row(row):
        # Make rows comparable despite float noise and case/whitespace
        # differences in text; other values (int, bytes, None) pass through.
        cells = []
        for v in row:
            if isinstance(v, float):
                cells.append(round(v, 2))
            elif isinstance(v, str):
                cells.append(v.strip().lower())
            else:
                cells.append(v)
        return tuple(cells)

    gold_set = {_normalize_row(r) for r in gold_rows}
    pred_set = {_normalize_row(r) for r in pred_rows}

    union = gold_set | pred_set
    intersection = gold_set & pred_set

    # Two empty results are treated as identical.
    jaccard = len(intersection) / len(union) if union else 1.0

    # Column count match; no columns on either side counts as a full match.
    gold_cols = len(gold_desc) if gold_desc else 0
    pred_cols = len(pred_desc) if pred_desc else 0
    widest = max(gold_cols, pred_cols)
    col_match = min(gold_cols, pred_cols) / widest if widest else 1.0

    score = round(0.8 * jaccard + 0.2 * col_match, 4)

    return score, {
        "jaccard": round(jaccard, 4),
        "column_match": round(col_match, 4),
        "generated_row_count": len(pred_rows),
        "gold_row_count": len(gold_rows),
        "intersection_size": len(intersection),
        "union_size": len(union),
    }
sqlas/evaluate.py ADDED
@@ -0,0 +1,218 @@
1
+ """
2
+ Main evaluation API — single query and batch evaluation.
3
+
4
+ Author: SQLAS Contributors
5
+ """
6
+
7
+ import logging
8
+ import os
9
+
10
+ from sqlas.core import SQLASScores, TestCase, LLMJudge, WEIGHTS, compute_composite_score
11
+ from sqlas.correctness import execution_accuracy, syntax_valid, semantic_equivalence, result_set_similarity
12
+ from sqlas.quality import sql_quality, schema_compliance, complexity_match
13
+ from sqlas.production import data_scan_efficiency, execution_result
14
+ from sqlas.response import faithfulness, answer_relevance, answer_completeness, fluency
15
+ from sqlas.safety import safety_score, read_only_compliance
16
+ from sqlas.context import context_precision, context_recall, entity_recall, noise_robustness
17
+
18
logger = logging.getLogger(__name__)

# Maximum number of result rows shown to the response-quality LLM judges.
_PREVIEW_MAX_ROWS = 15


def _format_result_preview(result_data: dict, max_rows: int = _PREVIEW_MAX_ROWS) -> str:
    """
    Render the result columns and the first ``max_rows`` rows as plain text.

    The output has one ``Columns: ...`` line followed by one line per row,
    each newline-terminated. A missing or ``None`` ``rows`` entry is treated
    as no rows.
    """
    lines = [f"Columns: {result_data.get('columns', [])}"]
    lines.extend(f"{row}" for row in (result_data.get("rows") or [])[:max_rows])
    return "\n".join(lines) + "\n"


def evaluate(
    question: str,
    generated_sql: str,
    llm_judge: LLMJudge,
    gold_sql: str | None = None,
    db_path: str | None = None,
    response: str | None = None,
    result_data: dict | None = None,
    valid_tables: set[str] | None = None,
    valid_columns: dict[str, set[str]] | None = None,
    schema_context: str = "",
    expected_nonempty: bool = True,
    pii_columns: list[str] | None = None,
    weights: dict | None = None,
) -> SQLASScores:
    """
    Evaluate a single SQL agent query across all SQLAS metrics.

    Metrics whose optional inputs are missing are either left at the
    ``SQLASScores()`` default or given a fixed fallback value (see the
    per-section comments below).

    Args:
        question: User's natural language question
        generated_sql: SQL produced by the agent (must contain non-whitespace)
        llm_judge: Function (prompt: str) -> str for LLM-as-judge metrics
        gold_sql: Ground-truth SQL (optional, enables execution accuracy & context metrics)
        db_path: Path to SQLite database (required for execution accuracy)
        response: Agent's natural language response (optional, enables faithfulness/relevance)
        result_data: Query result dict: {columns, rows, row_count, execution_time_ms}
        valid_tables: Set of valid table names (enables schema compliance)
        valid_columns: Dict of {table: {col1, col2}} (enables schema compliance)
        schema_context: Brief schema text for SQL quality judge
        expected_nonempty: Whether non-empty result is expected
        pii_columns: Custom PII column names for safety check
        weights: Custom weight dict (defaults to SQLAS production weights)

    Returns:
        SQLASScores with all metrics and overall_score. If ``generated_sql``
        is not a string or is blank, or ``db_path`` does not exist, a default
        SQLASScores with ``details["error"]`` set is returned and no metric
        (and no LLM call) is run.
    """
    # ── Input validation ────────────────────────────────────────────────
    # Whitespace-only SQL is rejected as well: it would otherwise be run
    # against the database and sent to every LLM judge for nothing.
    if not isinstance(generated_sql, str) or not generated_sql.strip():
        logger.error("generated_sql must be a non-empty string")
        scores = SQLASScores()
        scores.details["error"] = "generated_sql is empty or invalid"
        return scores

    if db_path and not os.path.exists(db_path):
        logger.error("db_path does not exist: %s", db_path)
        scores = SQLASScores()
        scores.details["error"] = f"db_path not found: {db_path}"
        return scores

    if weights:
        weight_sum = sum(weights.values())
        # Only a warning: the composite is still computed with these weights.
        if abs(weight_sum - 1.0) > 0.01:
            logger.warning("Custom weights sum to %.4f (expected ~1.0)", weight_sum)

    scores = SQLASScores()

    # ── 1. Core Correctness ─────────────────────────────────────────────
    scores.syntax_valid = syntax_valid(generated_sql)

    if gold_sql and db_path:
        ex_acc, ex_details = execution_accuracy(generated_sql, gold_sql, db_path)
        scores.execution_accuracy = ex_acc
        scores.details["execution_accuracy"] = ex_details

        # Efficiency (VES) — reuse timing from execution_accuracy
        scores.efficiency_score = ex_details.get("efficiency_score", 0.0)
    else:
        # Without a gold query, fall back to "did the agent produce a result".
        scores.execution_accuracy = 1.0 if result_data else 0.0
        scores.efficiency_score = 1.0 if result_data else 0.0

    sem, sem_details = semantic_equivalence(question, generated_sql, llm_judge, gold_sql)
    scores.semantic_equivalence = sem
    scores.details["semantic_equivalence"] = sem_details

    # ── 2. Context Quality (RAGAS-mapped) ───────────────────────────────
    if gold_sql:
        cp, cp_details = context_precision(generated_sql, gold_sql)
        scores.context_precision = cp
        scores.details["context_precision"] = cp_details

        cr, cr_details = context_recall(generated_sql, gold_sql)
        scores.context_recall = cr
        scores.details["context_recall"] = cr_details

        er, er_details = entity_recall(generated_sql, gold_sql)
        scores.entity_recall = er
        scores.details["entity_recall"] = er_details

        nr, nr_details = noise_robustness(generated_sql, gold_sql, valid_tables, valid_columns)
        scores.noise_robustness = nr
        scores.details["noise_robustness"] = nr_details

        if db_path:
            rs, rs_details = result_set_similarity(generated_sql, gold_sql, db_path)
            scores.result_set_similarity = rs
            scores.details["result_set_similarity"] = rs_details

    # ── 3. SQL Quality ──────────────────────────────────────────────────
    if valid_tables and valid_columns:
        sc, sc_details = schema_compliance(generated_sql, valid_tables, valid_columns)
        scores.schema_compliance = sc
        scores.details["schema_compliance"] = sc_details
    else:
        scores.schema_compliance = 1.0  # can't check without schema

    sq, sq_details = sql_quality(question, generated_sql, llm_judge, schema_context)
    scores.sql_quality = sq
    scores.details["sql_quality"] = sq_details

    cm, cm_details = complexity_match(question, generated_sql, llm_judge)
    scores.complexity_match = cm
    scores.details["complexity_match"] = cm_details

    # ── 4. Production Execution ─────────────────────────────────────────
    exec_eval = execution_result(result_data, expected_nonempty)
    scores.execution_success = exec_eval["execution_success"]
    scores.execution_time_ms = exec_eval["execution_time_ms"]
    scores.result_row_count = exec_eval["result_row_count"]
    scores.empty_result_penalty = exec_eval["empty_result_penalty"]
    scores.row_explosion_detected = exec_eval["row_explosion_detected"]

    scan, scan_details = data_scan_efficiency(generated_sql, scores.result_row_count)
    scores.data_scan_efficiency = scan
    scores.details["data_scan"] = scan_details

    # ── 5. Response Quality ─────────────────────────────────────────────
    # Needs both the agent's answer and the data it was based on.
    if response and result_data:
        result_preview = _format_result_preview(result_data)

        f_score, f_details = faithfulness(question, response, result_preview, llm_judge)
        scores.faithfulness = f_score
        scores.details["faithfulness"] = f_details

        r_score, r_details = answer_relevance(question, response, llm_judge)
        scores.answer_relevance = r_score
        scores.details["answer_relevance"] = r_details

        c_score, c_details = answer_completeness(question, response, result_preview, llm_judge)
        scores.answer_completeness = c_score
        scores.details["answer_completeness"] = c_details

        fl_score, fl_details = fluency(response, llm_judge)
        scores.fluency = fl_score
        scores.details["fluency"] = fl_details

    # ── 6. Safety ───────────────────────────────────────────────────────
    scores.read_only_compliance = read_only_compliance(generated_sql)

    safety, safety_details = safety_score(generated_sql, response or "", pii_columns)
    scores.safety_score = safety
    scores.details["safety"] = safety_details

    # ── Composite ───────────────────────────────────────────────────────
    scores.overall_score = compute_composite_score(scores, weights)

    return scores
179
+
180
+
181
def evaluate_batch(
    test_cases: list[dict],
    llm_judge: LLMJudge,
    db_path: str | None = None,
    valid_tables: set[str] | None = None,
    valid_columns: dict[str, set[str]] | None = None,
    schema_context: str = "",
    pii_columns: list[str] | None = None,
    weights: dict | None = None,
) -> list[SQLASScores]:
    """
    Evaluate a batch of test cases, one ``evaluate`` call per case, in order.

    Each dict in ``test_cases`` must provide ``question`` and
    ``generated_sql``. It may also provide ``gold_sql``, ``response``,
    ``result_data``, ``expected_nonempty`` (default True) and
    ``schema_context``, which overrides the batch-level value for that
    case. All remaining arguments are shared by every case.

    Returns:
        One SQLASScores per test case, in input order. A case missing a
        required key raises KeyError, which stops the batch.
    """
    shared_kwargs = {
        "llm_judge": llm_judge,
        "db_path": db_path,
        "valid_tables": valid_tables,
        "valid_columns": valid_columns,
        "pii_columns": pii_columns,
        "weights": weights,
    }
    return [
        evaluate(
            question=case["question"],
            generated_sql=case["generated_sql"],
            gold_sql=case.get("gold_sql"),
            response=case.get("response"),
            result_data=case.get("result_data"),
            schema_context=case.get("schema_context", schema_context),
            expected_nonempty=case.get("expected_nonempty", True),
            **shared_kwargs,
        )
        for case in test_cases
    ]
sqlas/production.py ADDED
@@ -0,0 +1,74 @@
1
+ """
2
+ Production Execution Metrics.
3
+ - Data Scan Efficiency (full scan detection)
4
+ - Execution Result (success, empty result, row explosion)
5
+
6
+ Author: SQLAS Contributors
7
+ """
8
+
9
+ import re
10
+
11
+
12
# Token-aware keyword patterns for data_scan_efficiency. Word boundaries keep
# identifiers such as ``nowhere_flag`` or ``limit_value`` from being read as
# keywords, and ``\s+`` accepts any whitespace (including newlines) between
# multi-word keywords.
_SELECT_STAR_RE = re.compile(r"\bSELECT\s+(?:(?:DISTINCT|ALL)\s+)?\*", re.IGNORECASE)
_WHERE_RE = re.compile(r"\bWHERE\b", re.IGNORECASE)
_GROUP_BY_RE = re.compile(r"\bGROUP\s+BY\b", re.IGNORECASE)
_LIMIT_RE = re.compile(r"\bLIMIT\b", re.IGNORECASE)
_JOIN_RE = re.compile(r"\bJOIN\b", re.IGNORECASE)


def data_scan_efficiency(sql: str, result_row_count: int = 0) -> tuple[float, dict]:
    """
    Heuristically score how efficiently ``sql`` accesses data (1.0 = no issues).

    Cumulative penalties (score floored at 0.0):
      - 0.2  ``SELECT *`` (any spacing, also ``SELECT DISTINCT *``/``ALL *``)
      - 0.3  no WHERE, GROUP BY, or LIMIT — potential full scan
      - 0.3  a JOIN returning more than 10,000 rows — possible cartesian product
      - 0.1  no GROUP BY or LIMIT while returning more than 100 rows

    Keywords are matched as whole words anywhere in the text, so they also
    count inside subqueries and string literals.

    Args:
        sql: SQL text to inspect (None is treated as empty)
        result_row_count: Number of rows the query returned

    Returns:
        (score, {"issues": [...]}), with ``["none"]`` when nothing was found.
    """
    text = sql or ""
    issues = []
    score = 1.0

    if _SELECT_STAR_RE.search(text):
        issues.append("SELECT * — should specify columns")
        score -= 0.2

    has_where = _WHERE_RE.search(text) is not None
    has_group = _GROUP_BY_RE.search(text) is not None
    has_limit = _LIMIT_RE.search(text) is not None

    if not has_where and not has_group and not has_limit:
        issues.append("No WHERE, GROUP BY, or LIMIT — potential full scan")
        score -= 0.3

    if result_row_count > 10000 and _JOIN_RE.search(text):
        issues.append(f"Large result ({result_row_count} rows) from JOIN — possible cartesian product")
        score -= 0.3

    if not has_group and not has_limit and result_row_count > 100:
        issues.append("No LIMIT on detail query returning many rows")
        score -= 0.1

    return max(score, 0.0), {"issues": issues or ["none"]}
45
+
46
+
47
+ def execution_result(
48
+ data: dict | None,
49
+ expected_nonempty: bool = True,
50
+ ) -> dict:
51
+ """
52
+ Evaluate execution outcome.
53
+
54
+ Args:
55
+ data: Query result dict with keys: row_count, execution_time_ms, truncated
56
+ expected_nonempty: Whether non-empty result is expected
57
+ """
58
+ if data is None:
59
+ return {
60
+ "execution_success": 0.0,
61
+ "empty_result_penalty": 0.0,
62
+ "row_explosion_detected": False,
63
+ "execution_time_ms": 0,
64
+ "result_row_count": 0,
65
+ }
66
+
67
+ row_count = data.get("row_count", 0)
68
+ return {
69
+ "execution_success": 1.0,
70
+ "execution_time_ms": data.get("execution_time_ms", 0),
71
+ "result_row_count": row_count,
72
+ "empty_result_penalty": 0.0 if (expected_nonempty and row_count == 0) else 1.0,
73
+ "row_explosion_detected": row_count > 50000,
74
+ }
sqlas/py.typed ADDED
File without changes