sqlas 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlas/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ SQLAS — SQL Agent Scoring Framework
3
+ A RAGAS-equivalent evaluation library for Text-to-SQL and SQL AI agents.
4
+
5
+ Author: SQLAS Contributors
6
+
7
+ Usage:
8
+ from sqlas import evaluate, SQLASScores, TestCase, WEIGHTS
9
+
10
+ scores = evaluate(
11
+ question="How many users are active?",
12
+ generated_sql="SELECT COUNT(*) FROM users WHERE active = 1",
13
+ gold_sql="SELECT COUNT(*) FROM users WHERE active = 1",
14
+ db_path="my_database.db",
15
+ llm_judge=my_llm_function,
16
+ )
17
+ print(scores.overall_score)
18
+ """
19
+
20
+ from sqlas.core import SQLASScores, TestCase, WEIGHTS, WEIGHTS_V2, compute_composite_score
21
+ from sqlas.evaluate import evaluate, evaluate_batch
22
+ from sqlas.correctness import execution_accuracy, syntax_valid, semantic_equivalence, result_set_similarity
23
+ from sqlas.quality import sql_quality, schema_compliance, complexity_match
24
+ from sqlas.production import data_scan_efficiency, execution_result
25
+ from sqlas.response import faithfulness, answer_relevance, answer_completeness, fluency
26
+ from sqlas.safety import safety_score, read_only_compliance
27
+ from sqlas.context import context_precision, context_recall, entity_recall, noise_robustness
28
+ from sqlas.runner import run_suite
29
+
30
# Distribution version — keep in sync with the wheel/package metadata.
__version__ = "1.1.0"
__author__ = "SQLAS Contributors"

# Explicit public API: every name re-exported from the sqlas submodules above.
__all__ = [
    # Core
    "SQLASScores",
    "TestCase",
    "WEIGHTS",
    "WEIGHTS_V2",
    "compute_composite_score",
    # Top-level API
    "evaluate",
    "evaluate_batch",
    "run_suite",
    # Correctness metrics
    "execution_accuracy",
    "syntax_valid",
    "semantic_equivalence",
    "result_set_similarity",
    # Quality metrics
    "sql_quality",
    "schema_compliance",
    "complexity_match",
    # Production metrics
    "data_scan_efficiency",
    "execution_result",
    # Response metrics
    "faithfulness",
    "answer_relevance",
    "answer_completeness",
    "fluency",
    # Safety metrics
    "safety_score",
    "read_only_compliance",
    # Context metrics (RAGAS-mapped)
    "context_precision",
    "context_recall",
    "entity_recall",
    "noise_robustness",
]
sqlas/context.py ADDED
@@ -0,0 +1,268 @@
1
+ """
2
+ Context Quality Metrics (RAGAS-mapped for SQL agents).
3
+ - Context Precision (schema element precision)
4
+ - Context Recall (schema element recall)
5
+ - Entity Recall (strict entity-level recall)
6
+ - Noise Robustness (irrelevant schema resistance)
7
+
8
+ Author: SQLAS Contributors
9
+ """
10
+
11
+ import sqlglot
12
+
13
+
14
+ # ── Shared AST Extraction ──────────────────────────────────────────────────
15
+
16
def _extract_sql_elements(sql: str, dialect: str = "sqlite") -> dict:
    """Extract tables, columns, literals, and functions from a SQL AST.

    Args:
        sql: SQL text to parse.
        dialect: sqlglot dialect name used for parsing.

    Returns:
        {
            "tables": set of lowered table names,
            "columns": set of lowered column names,
            "table_columns": set of (table, column) tuples ("" when unqualified),
            "literals": set of string representations of literal values,
            "functions": set of lowered function names,
        }
        All sets are empty when the SQL cannot be parsed.
    """
    try:
        parsed = sqlglot.parse_one(sql, dialect=dialect)
    except Exception:
        # Unparseable SQL: report "no elements" so callers can score it down.
        return {
            "tables": set(),
            "columns": set(),
            "table_columns": set(),
            "literals": set(),
            "functions": set(),
        }

    tables = set()
    for table in parsed.find_all(sqlglot.exp.Table):
        if table.name:
            tables.add(table.name.lower())

    columns = set()
    table_columns = set()
    for col in parsed.find_all(sqlglot.exp.Column):
        if col.name:
            col_name = col.name.lower()
            columns.add(col_name)
            # Keep the table qualifier when present; "" marks an unqualified column.
            table_columns.add((col.table.lower() if col.table else "", col_name))

    literals = set()
    for lit in parsed.find_all(sqlglot.exp.Literal):
        literals.add(str(lit.this).lower())

    functions = set()
    # exp.Func is the base class of all parsed function expressions in sqlglot,
    # so this pass captures named built-ins (ROUND, COALESCE, ...) that the
    # previous hard-coded aggregate list (COUNT/SUM/AVG/MIN/MAX) plus
    # exp.Anonymous missed. NOTE(review): this also includes function-like
    # nodes such as CAST — symmetric for generated and gold SQL, so recall
    # comparisons stay fair; confirm this matches the intended granularity.
    for func in parsed.find_all(sqlglot.exp.Func):
        if isinstance(func, sqlglot.exp.Anonymous):
            # Unknown/custom UDFs keep their literal name.
            if func.name:
                functions.add(func.name.lower())
        else:
            # `key` is sqlglot's canonical lowercase name for the expression.
            functions.add(func.key.lower())

    return {
        "tables": tables,
        "columns": columns,
        "table_columns": table_columns,
        "literals": literals,
        "functions": functions,
    }
78
+
79
+
80
+ # ── Public API ─────────────────────────────────────────────────────────────
81
+
82
def context_precision(
    generated_sql: str,
    gold_sql: str,
    dialect: str = "sqlite",
) -> tuple[float, dict]:
    """RAGAS Context Precision mapped onto SQL agents.

    Measures what fraction of the schema elements (tables + columns) that the
    generated SQL references are also referenced by the gold SQL. Referencing
    schema elements the gold query never needs lowers the score.

    Args:
        generated_sql: SQL produced by the agent
        gold_sql: Ground-truth SQL
        dialect: SQL dialect for parsing

    Returns:
        (precision score 0.0–1.0, details dict)
    """
    gen_refs = _extract_sql_elements(generated_sql, dialect)
    gold_refs = _extract_sql_elements(gold_sql, dialect)

    used = gen_refs["tables"] | gen_refs["columns"]
    needed = gold_refs["tables"] | gold_refs["columns"]

    # No elements referenced at all — vacuously precise.
    if not used:
        return 1.0, {
            "generated_elements": [],
            "gold_elements": list(needed),
            "extra_elements": [],
            "precision": 1.0,
        }

    precision = len(used & needed) / len(used)
    details = {
        "generated_elements": sorted(used),
        "gold_elements": sorted(needed),
        "extra_elements": sorted(used - needed),
        "precision": round(precision, 4),
    }
    return round(precision, 4), details
122
+
123
+
124
def context_recall(
    generated_sql: str,
    gold_sql: str,
    dialect: str = "sqlite",
) -> tuple[float, dict]:
    """RAGAS Context Recall mapped onto SQL agents.

    Measures what fraction of the schema elements (tables + columns) required
    by the gold SQL also appear in the generated SQL. Missing a required
    element lowers the score.

    Args:
        generated_sql: SQL produced by the agent
        gold_sql: Ground-truth SQL
        dialect: SQL dialect for parsing

    Returns:
        (recall score 0.0–1.0, details dict)
    """
    gen_refs = _extract_sql_elements(generated_sql, dialect)
    gold_refs = _extract_sql_elements(gold_sql, dialect)

    produced = gen_refs["tables"] | gen_refs["columns"]
    required = gold_refs["tables"] | gold_refs["columns"]

    # Gold query references nothing — vacuously complete.
    if not required:
        return 1.0, {"missing_elements": [], "recall": 1.0}

    hit = produced & required
    recall = len(hit) / len(required)
    details = {
        "missing_elements": sorted(required - produced),
        "matched_elements": sorted(hit),
        "total_gold_elements": len(required),
        "recall": round(recall, 4),
    }
    return round(recall, 4), details
163
+
164
+
165
def entity_recall(
    generated_sql: str,
    gold_sql: str,
    dialect: str = "sqlite",
) -> tuple[float, dict]:
    """
    RAGAS Context Entity Recall for SQL agents.

    Strict entity-level check: are all entities (tables, columns, literals,
    functions) from the gold SQL present in the generated SQL?

    This is stricter than context_recall because it also checks literal values
    (e.g. WHERE status = 'active') and function usage (COUNT, SUM, etc.).

    Args:
        generated_sql: SQL produced by the agent
        gold_sql: Ground-truth SQL
        dialect: SQL dialect for parsing

    Returns:
        (recall score 0.0–1.0, details dict)
    """
    gen = _extract_sql_elements(generated_sql, dialect)
    gold = _extract_sql_elements(gold_sql, dialect)

    gold_entities = gold["tables"] | gold["columns"] | gold["literals"] | gold["functions"]
    gen_entities = gen["tables"] | gen["columns"] | gen["literals"] | gen["functions"]

    # Gold query contains no entities — vacuously complete.
    if not gold_entities:
        return 1.0, {"missing_entities": [], "matched_entities": [],
                     "total_gold_entities": 0, "recall": 1.0}

    matched = gold_entities & gen_entities
    missing = gold_entities - gen_entities
    recall = len(matched) / len(gold_entities)

    return round(recall, 4), {
        "missing_entities": sorted(missing),
        "matched_entities": sorted(matched),
        "total_gold_entities": len(gold_entities),
        # Included for consistency with context_precision / context_recall,
        # whose details dicts also carry the computed score.
        "recall": round(recall, 4),
    }
206
+
207
+
208
def noise_robustness(
    generated_sql: str,
    gold_sql: str,
    valid_tables: set[str] | None = None,
    valid_columns: dict[str, set[str]] | None = None,
    dialect: str = "sqlite",
) -> tuple[float, dict]:
    """RAGAS Noise Sensitivity mapped onto SQL agents.

    Measures whether the agent resists pulling in irrelevant schema elements.
    When the full schema is supplied, an extra element only counts as noise if
    it actually exists in the schema — elements absent from the schema are
    hallucinations, not noise, and are excluded here.

    Args:
        generated_sql: SQL produced by the agent
        gold_sql: Ground-truth SQL
        valid_tables: Full set of available table names (optional)
        valid_columns: Full dict of {table: {col1, col2, ...}} (optional)
        dialect: SQL dialect for parsing

    Returns:
        (robustness score 0.0–1.0, details dict)
    """
    gen = _extract_sql_elements(generated_sql, dialect)
    gold = _extract_sql_elements(gold_sql, dialect)

    referenced = gen["tables"] | gen["columns"]
    required = gold["tables"] | gold["columns"]

    # Nothing referenced — nothing can be noisy.
    if not referenced:
        return 1.0, {"noise_tables": [], "noise_columns": [], "noise_count": 0}

    # Elements the generated SQL uses beyond what the gold SQL needs.
    surplus = referenced - required

    if valid_tables or valid_columns:
        # Restrict to elements that really exist in the provided schema.
        schema_names = set()
        if valid_tables:
            schema_names |= {t.lower() for t in valid_tables}
        if valid_columns:
            for cols in valid_columns.values():
                schema_names |= {c.lower() for c in cols}
        noise = surplus & schema_names
    else:
        noise = surplus

    score = max(0.0, 1.0 - len(noise) / len(referenced))

    return round(score, 4), {
        "noise_tables": sorted(noise & gen["tables"]),
        "noise_columns": sorted(noise & gen["columns"]),
        "noise_count": len(noise),
    }
sqlas/core.py ADDED
@@ -0,0 +1,208 @@
1
+ """
2
+ Core data structures and composite scoring for SQLAS.
3
+
4
+ Author: SQLAS Contributors
5
+ """
6
+
7
+ import re
8
+ from dataclasses import dataclass, field
9
+ from typing import Callable
10
+
11
+
12
# ── Production Composite Weights (v1 — 15 metrics) ────────────────────────
# Aligned with industry-standard SQL agent evaluation:
#   40% Execution Accuracy  — does the SQL return correct results?
#   15% Semantic Correctness — does the SQL answer the user's intent?
#   15% Cost Efficiency     — is the query efficient?
#   10% Execution Quality   — does the query execute successfully?
#   10% Task Success        — does the user get a correct, complete answer?
#   10% Safety              — is the query safe?
# The weights below sum to exactly 1.00; keep that invariant when editing.
# ────────────────────────────────────────────────────────────────────────────

WEIGHTS = {
    # 1. Execution Accuracy (40%)
    "execution_accuracy": 0.40,
    # 2. Semantic Correctness (15%)
    "semantic_equivalence": 0.15,
    # 3. Cost Efficiency (15% = 0.05 + 0.05 + 0.03 + 0.02)
    "efficiency_score": 0.05,
    "data_scan_efficiency": 0.05,
    "sql_quality": 0.03,
    "schema_compliance": 0.02,
    # 4. Execution Quality (10% = 0.05 + 0.03 + 0.02)
    "execution_success": 0.05,
    "complexity_match": 0.03,
    "empty_result_penalty": 0.02,
    # 5. Task Success (10% = 0.04 + 0.03 + 0.02 + 0.01)
    "faithfulness": 0.04,
    "answer_relevance": 0.03,
    "answer_completeness": 0.02,
    "fluency": 0.01,
    # 6. Safety (10% = 0.05 + 0.05)
    "read_only_compliance": 0.05,
    "safety_score": 0.05,
}
45
+
46
+
47
# ── Production Composite Weights (v2 — 20 metrics with context quality) ───
# Adds RAGAS-mapped context metrics for SQL agents, re-balancing the v1
# categories to make room. The weights below sum to exactly 1.00; keep that
# invariant when editing.
# ────────────────────────────────────────────────────────────────────────────

WEIGHTS_V2 = {
    # 1. Execution Accuracy (35%)
    "execution_accuracy": 0.35,
    # 2. Semantic Correctness (13%)
    "semantic_equivalence": 0.13,
    # 3. Context Quality (10% = 0.03 + 0.03 + 0.02 + 0.02) — RAGAS-mapped
    "context_precision": 0.03,
    "context_recall": 0.03,
    "entity_recall": 0.02,
    "noise_robustness": 0.02,
    # 4. Cost Efficiency (12% = 0.04 + 0.04 + 0.02 + 0.02)
    "efficiency_score": 0.04,
    "data_scan_efficiency": 0.04,
    "sql_quality": 0.02,
    "schema_compliance": 0.02,
    # 5. Execution Quality (8% = 0.04 + 0.02 + 0.02)
    "execution_success": 0.04,
    "complexity_match": 0.02,
    "empty_result_penalty": 0.02,
    # 6. Task Success (8% = 0.03 + 0.02 + 0.02 + 0.01)
    "faithfulness": 0.03,
    "answer_relevance": 0.02,
    "answer_completeness": 0.02,
    "fluency": 0.01,
    # 7. Result Similarity (4%)
    "result_set_similarity": 0.04,
    # 8. Safety (10% = 0.05 + 0.05)
    "read_only_compliance": 0.05,
    "safety_score": 0.05,
}
81
+
82
+
83
@dataclass
class TestCase:
    """A single evaluation test case."""
    # Natural-language question posed to the SQL agent.
    question: str
    # Ground-truth SQL; None when no reference query is available.
    gold_sql: str | None = None
    # Tables a correct query is expected to reference — consumed by the suite
    # runner (not visible in this module); TODO confirm exact check semantics.
    expected_tables: list[str] | None = None
    # Whether the correct answer is expected to require a JOIN.
    expects_join: bool = False
    # Whether a correct query should return at least one row.
    expected_nonempty: bool = True
    # Free-form grouping label used for reporting.
    category: str = "general"
92
+
93
+
94
@dataclass
class SQLASScores:
    """Complete production-grade evaluation scores for a single query."""

    # 1. Core SQL Correctness
    execution_accuracy: float = 0.0
    syntax_valid: float = 0.0
    semantic_equivalence: float = 0.0

    # 2. SQL Quality & Structure
    schema_compliance: float = 0.0
    sql_quality: float = 0.0
    complexity_match: float = 0.0

    # 3. Production Execution
    execution_success: float = 0.0
    execution_time_ms: float = 0.0
    efficiency_score: float = 0.0
    data_scan_efficiency: float = 0.0
    result_row_count: int = 0
    empty_result_penalty: float = 0.0
    row_explosion_detected: bool = False

    # 4. Response Quality
    faithfulness: float = 0.0
    answer_relevance: float = 0.0
    answer_completeness: float = 0.0
    fluency: float = 0.0

    # 5. Safety & Governance
    read_only_compliance: float = 0.0
    safety_score: float = 0.0

    # 6. Context Quality (RAGAS-mapped)
    context_precision: float = 0.0
    context_recall: float = 0.0
    entity_recall: float = 0.0
    noise_robustness: float = 0.0
    result_set_similarity: float = 0.0

    # Composite
    overall_score: float = 0.0
    details: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Export all scores as a flat dictionary.

        Covers every metric named in WEIGHTS or WEIGHTS_V2 (missing attributes
        default to 0.0) plus the bookkeeping fields not weighted in either.
        """
        out = {
            key: getattr(self, key, 0.0)
            for key in set(WEIGHTS) | set(WEIGHTS_V2)
        }
        out["overall_score"] = self.overall_score
        out["syntax_valid"] = self.syntax_valid
        out["execution_time_ms"] = self.execution_time_ms
        out["result_row_count"] = self.result_row_count
        out["row_explosion_detected"] = self.row_explosion_detected
        return out

    def summary(self) -> str:
        """Human-readable, category-grouped score summary."""
        sections = (
            ("Execution Accuracy", [("execution_accuracy", self.execution_accuracy)]),
            ("Semantic Correctness", [("semantic_equivalence", self.semantic_equivalence)]),
            ("Context Quality", [("context_precision", self.context_precision), ("context_recall", self.context_recall), ("entity_recall", self.entity_recall), ("noise_robustness", self.noise_robustness), ("result_similarity", self.result_set_similarity)]),
            ("Cost Efficiency", [("efficiency", self.efficiency_score), ("data_scan", self.data_scan_efficiency), ("sql_quality", self.sql_quality), ("schema", self.schema_compliance)]),
            ("Execution Quality", [("exec_success", self.execution_success), ("complexity", self.complexity_match), ("empty_result", self.empty_result_penalty)]),
            ("Task Success", [("faithfulness", self.faithfulness), ("relevance", self.answer_relevance), ("completeness", self.answer_completeness), ("fluency", self.fluency)]),
            ("Safety", [("read_only", self.read_only_compliance), ("safety", self.safety_score)]),
        )
        lines = [f"SQLAS Score: {self.overall_score:.4f} / 1.0"]
        for title, metrics in sections:
            lines.append(f"  {title}")
            lines.extend(f"    {name}: {val:.4f}" for name, val in metrics)
        return "\n".join(lines)
168
+
169
+
170
# ── LLM Judge type ──────────────────────────────────────────────────────────
# Users provide their own LLM function: (prompt: str) -> str.
# Any callable that takes the full prompt text and returns the raw model
# reply; _parse_score() below extracts structured scores from that reply.
LLMJudge = Callable[[str], str]
173
+
174
+
175
+ def _parse_score(result: str, key: str) -> tuple[float, str]:
176
+ """Shared helper to extract a score and reasoning from LLM judge output.
177
+
178
+ Looks for lines like 'Key: 0.85' and 'Reasoning: ...' in the result text.
179
+
180
+ Args:
181
+ result: Raw LLM judge output
182
+ key: The score key to look for (e.g. 'Faithfulness', 'Relevance')
183
+
184
+ Returns:
185
+ (score clamped to 0.0–1.0, reasoning string)
186
+ """
187
+ score, reasoning = 0.0, ""
188
+ for line in result.strip().split("\n"):
189
+ if line.startswith(key + ":"):
190
+ try:
191
+ score = float(re.search(r"[\d.]+", line.split(":")[-1]).group())
192
+ except Exception:
193
+ pass
194
+ if line.startswith("Reasoning:"):
195
+ reasoning = line.split(":", 1)[-1].strip()
196
+ return min(score, 1.0), reasoning
197
+
198
+
199
def compute_composite_score(scores: SQLASScores, weights: dict | None = None) -> float:
    """Compute the weighted overall SQLAS score.

    Args:
        scores: Metric container; each weighted metric is read via getattr,
            and missing attributes contribute 0.0.
        weights: Optional {metric_name: weight} mapping; defaults to WEIGHTS
            when None. Note: an explicitly passed empty dict is respected and
            yields 0.0 — the previous `weights or WEIGHTS` expression silently
            replaced any falsy mapping with the default.

    Returns:
        Weighted sum of metric values, rounded to 4 decimal places.
    """
    w = WEIGHTS if weights is None else weights
    total = 0.0
    for metric, weight in w.items():
        val = getattr(scores, metric, 0.0)
        # Boolean metrics (pass/fail flags) count as 1.0 / 0.0.
        if isinstance(val, bool):
            val = 1.0 if val else 0.0
        total += val * weight
    return round(total, 4)