sqlas 2.5.0__tar.gz → 2.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sqlas-2.5.0/sqlas.egg-info → sqlas-2.6.0}/PKG-INFO +15 -1
- {sqlas-2.5.0 → sqlas-2.6.0}/pyproject.toml +9 -4
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/__init__.py +3 -1
- sqlas-2.6.0/sqlas/benchmarks.py +480 -0
- sqlas-2.6.0/sqlas/integrations.py +274 -0
- sqlas-2.6.0/sqlas/ui.py +572 -0
- {sqlas-2.5.0 → sqlas-2.6.0/sqlas.egg-info}/PKG-INFO +15 -1
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas.egg-info/SOURCES.txt +3 -0
- sqlas-2.6.0/sqlas.egg-info/requires.txt +31 -0
- sqlas-2.6.0/sqlas.egg-info/top_level.txt +2 -0
- sqlas-2.5.0/sqlas.egg-info/requires.txt +0 -12
- sqlas-2.5.0/sqlas.egg-info/top_level.txt +0 -1
- {sqlas-2.5.0 → sqlas-2.6.0}/LICENSE +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/README.md +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/setup.cfg +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/agentic.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/cache.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/context.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/core.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/correctness.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/evaluate.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/feedback.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/guardrails.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/production.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/prompt_registry.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/py.typed +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/quality.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/response.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/runner.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/safety.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/schema_quality.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/visualization.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/sqlas.egg-info/dependency_links.txt +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_context.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_execute_fn.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_large_schema.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_sqlas.py +0 -0
- {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_v2.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sqlas
|
|
3
|
-
Version: 2.5.0
|
|
3
|
+
Version: 2.6.0
|
|
4
4
|
Summary: SQLAS — SQL Agent Scoring Framework. Production-grade evaluation for Text-to-SQL and Agentic SQL agents with guardrail, visualization, agentic quality, and cache performance metrics.
|
|
5
5
|
Author-email: thepradip <pradiptivhale@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -26,12 +26,26 @@ License-File: LICENSE
|
|
|
26
26
|
Requires-Dist: sqlglot>=20.0
|
|
27
27
|
Provides-Extra: mlflow
|
|
28
28
|
Requires-Dist: mlflow>=3.0; extra == "mlflow"
|
|
29
|
+
Provides-Extra: wandb
|
|
30
|
+
Requires-Dist: wandb>=0.16; extra == "wandb"
|
|
31
|
+
Provides-Extra: langsmith
|
|
32
|
+
Requires-Dist: langsmith>=0.1; extra == "langsmith"
|
|
33
|
+
Provides-Extra: ui
|
|
34
|
+
Requires-Dist: streamlit>=1.30; extra == "ui"
|
|
35
|
+
Requires-Dist: pandas>=2.0; extra == "ui"
|
|
36
|
+
Provides-Extra: benchmarks
|
|
37
|
+
Provides-Extra: prometheus
|
|
38
|
+
Requires-Dist: prometheus-client>=0.19; extra == "prometheus"
|
|
29
39
|
Provides-Extra: dev
|
|
30
40
|
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
31
41
|
Requires-Dist: build; extra == "dev"
|
|
32
42
|
Requires-Dist: twine; extra == "dev"
|
|
33
43
|
Provides-Extra: all
|
|
34
44
|
Requires-Dist: mlflow>=3.0; extra == "all"
|
|
45
|
+
Requires-Dist: wandb>=0.16; extra == "all"
|
|
46
|
+
Requires-Dist: langsmith>=0.1; extra == "all"
|
|
47
|
+
Requires-Dist: streamlit>=1.30; extra == "all"
|
|
48
|
+
Requires-Dist: pandas>=2.0; extra == "all"
|
|
35
49
|
Dynamic: license-file
|
|
36
50
|
|
|
37
51
|
# SQLAS — SQL Agent Scoring Framework
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "sqlas"
|
|
7
|
-
version = "2.5.0"
|
|
7
|
+
version = "2.6.0"
|
|
8
8
|
description = "SQLAS — SQL Agent Scoring Framework. Production-grade evaluation for Text-to-SQL and Agentic SQL agents with guardrail, visualization, agentic quality, and cache performance metrics."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "MIT"
|
|
@@ -29,9 +29,14 @@ dependencies = [
|
|
|
29
29
|
]
|
|
30
30
|
|
|
31
31
|
[project.optional-dependencies]
|
|
32
|
-
mlflow = ["mlflow>=3.0"]
|
|
33
|
-
|
|
34
|
-
|
|
32
|
+
mlflow = ["mlflow>=3.0"]
|
|
33
|
+
wandb = ["wandb>=0.16"]
|
|
34
|
+
langsmith = ["langsmith>=0.1"]
|
|
35
|
+
ui = ["streamlit>=1.30", "pandas>=2.0"]
|
|
36
|
+
benchmarks = [] # Spider/BIRD datasets downloaded separately
|
|
37
|
+
prometheus = ["prometheus-client>=0.19"]
|
|
38
|
+
dev = ["pytest>=7.0", "build", "twine"]
|
|
39
|
+
all = ["mlflow>=3.0", "wandb>=0.16", "langsmith>=0.1", "streamlit>=1.30", "pandas>=2.0"]
|
|
35
40
|
|
|
36
41
|
[project.urls]
|
|
37
42
|
Homepage = "https://github.com/thepradip/SQLAS"
|
|
@@ -33,6 +33,8 @@ from sqlas.guardrails import GuardrailPipeline, GuardrailResult
|
|
|
33
33
|
from sqlas.feedback import FeedbackStore, FeedbackEntry
|
|
34
34
|
from sqlas.prompt_registry import PromptRegistry, PromptVersion
|
|
35
35
|
from sqlas.schema_quality import schema_retrieval_quality, batch_retrieval_quality
|
|
36
|
+
from sqlas.benchmarks import run_spider_benchmark, run_bird_benchmark, download_instructions
|
|
37
|
+
from sqlas.integrations import log_to_mlflow, log_to_wandb, log_to_langsmith, log_all
|
|
36
38
|
from sqlas.correctness import execution_accuracy, syntax_valid, semantic_equivalence, result_set_similarity
|
|
37
39
|
from sqlas.quality import sql_quality, schema_compliance, complexity_match
|
|
38
40
|
from sqlas.production import data_scan_efficiency, execution_result, result_coverage
|
|
@@ -51,7 +53,7 @@ from sqlas.agentic import (
|
|
|
51
53
|
from sqlas.cache import cache_hit_score, tokens_saved_score, few_shot_score
|
|
52
54
|
from sqlas.runner import run_suite
|
|
53
55
|
|
|
54
|
-
__version__ = "2.5.0"
|
|
56
|
+
__version__ = "2.6.0"
|
|
55
57
|
__author__ = "SQLAS Contributors"
|
|
56
58
|
|
|
57
59
|
__all__ = [
|
|
@@ -0,0 +1,480 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SQLAS Benchmark Integration — Spider and BIRD datasets.
|
|
3
|
+
|
|
4
|
+
Closes the academic credibility gap vs RAGAS/DeepEval by integrating
|
|
5
|
+
the two standard NL2SQL benchmarks with smart sampling to keep costs low.
|
|
6
|
+
|
|
7
|
+
Token cost strategy:
|
|
8
|
+
- Default n_samples=50 → ~$0.25 with GPT-4o judge
|
|
9
|
+
- Safety metrics (free, no LLM) run on ALL sampled questions
|
|
10
|
+
- LLM judge only runs on questions that actually execute correctly
|
|
11
|
+
- Stratified sampling ensures representative difficulty distribution
|
|
12
|
+
|
|
13
|
+
Usage:
|
|
14
|
+
from sqlas.benchmarks import run_spider_benchmark, download_instructions
|
|
15
|
+
|
|
16
|
+
# Check dataset is available
|
|
17
|
+
print(download_instructions("spider"))
|
|
18
|
+
|
|
19
|
+
# Run benchmark (50 questions, stratified by difficulty)
|
|
20
|
+
results = run_spider_benchmark(
|
|
21
|
+
agent_fn = my_agent,
|
|
22
|
+
llm_judge = my_judge,
|
|
23
|
+
spider_dir = "./spider",
|
|
24
|
+
n_samples = 50,
|
|
25
|
+
difficulty = None, # None = all difficulties
|
|
26
|
+
query_types = None, # None = all types
|
|
27
|
+
seed = 42, # reproducible sampling
|
|
28
|
+
weights = WEIGHTS_V4,
|
|
29
|
+
mlflow_run = True,
|
|
30
|
+
verbose = True,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
print(results["summary"]["overall_score"])
|
|
34
|
+
print(results["benchmark_stats"]["execution_accuracy"])
|
|
35
|
+
print(results["cost_estimate_usd"])
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
import json
|
|
39
|
+
import logging
|
|
40
|
+
import os
|
|
41
|
+
import random
|
|
42
|
+
import sqlite3
|
|
43
|
+
import time
|
|
44
|
+
from dataclasses import dataclass, field
|
|
45
|
+
from pathlib import Path
|
|
46
|
+
from typing import Optional
|
|
47
|
+
|
|
48
|
+
from sqlas.core import TestCase, WEIGHTS, WEIGHTS_V4, LLMJudge, ExecuteFn
|
|
49
|
+
from sqlas.evaluate import evaluate
|
|
50
|
+
|
|
51
|
+
logger = logging.getLogger(__name__)
|
|
52
|
+
|
|
53
|
+
# ── Dataset download instructions ─────────────────────────────────────────────
|
|
54
|
+
|
|
55
|
+
_DOWNLOAD_INSTRUCTIONS = {
|
|
56
|
+
"spider": """
|
|
57
|
+
Spider dataset not found. Download it:
|
|
58
|
+
1. Visit: https://yale-lily.github.io/spider
|
|
59
|
+
2. Download 'Spider 1.0 Dataset' zip
|
|
60
|
+
3. Extract to a directory, e.g. ./spider/
|
|
61
|
+
4. Pass: spider_dir='./spider'
|
|
62
|
+
|
|
63
|
+
Expected structure:
|
|
64
|
+
spider/
|
|
65
|
+
dev.json ← 1034 dev questions
|
|
66
|
+
tables.json ← schema metadata
|
|
67
|
+
database/
|
|
68
|
+
{db_id}/
|
|
69
|
+
{db_id}.sqlite ← SQLite databases
|
|
70
|
+
""",
|
|
71
|
+
"bird": """
|
|
72
|
+
BIRD dataset not found. Download it:
|
|
73
|
+
1. Visit: https://bird-bench.github.io/
|
|
74
|
+
2. Download 'BIRD-SQL dev set'
|
|
75
|
+
3. Extract to a directory, e.g. ./bird/
|
|
76
|
+
4. Pass: bird_dir='./bird'
|
|
77
|
+
|
|
78
|
+
Expected structure:
|
|
79
|
+
bird/
|
|
80
|
+
dev/
|
|
81
|
+
dev.json
|
|
82
|
+
dev_databases/
|
|
83
|
+
{db_id}/
|
|
84
|
+
{db_id}.sqlite
|
|
85
|
+
""",
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def download_instructions(dataset: str = "spider") -> str:
    """Return the download/setup instructions for a benchmark dataset.

    Args:
        dataset: Dataset name, e.g. ``"spider"`` or ``"bird"``. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        The instruction text registered in ``_DOWNLOAD_INSTRUCTIONS`` for the
        dataset, or ``"Unknown dataset: <dataset>"`` when it is not registered.
    """
    # Only real strings are normalised, so odd inputs (e.g. None) still fall
    # through to the "Unknown dataset" message instead of raising.
    key = dataset.strip().lower() if isinstance(dataset, str) else dataset
    return _DOWNLOAD_INSTRUCTIONS.get(key, f"Unknown dataset: {dataset}")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# ── Sampling ───────────────────────────────────────────────────────────────────
|
|
94
|
+
|
|
95
|
+
_DIFFICULTY_WEIGHTS = {
|
|
96
|
+
"easy": 0.20,
|
|
97
|
+
"medium": 0.30,
|
|
98
|
+
"hard": 0.30,
|
|
99
|
+
"extra hard": 0.20,
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
_SQL_TYPE_PATTERNS = {
|
|
103
|
+
"simple": lambda sql: "JOIN" not in sql.upper() and "GROUP BY" not in sql.upper(),
|
|
104
|
+
"aggregation": lambda sql: "GROUP BY" in sql.upper() or any(f in sql.upper() for f in ("COUNT(","SUM(","AVG(","MAX(","MIN(")),
|
|
105
|
+
"join": lambda sql: "JOIN" in sql.upper(),
|
|
106
|
+
"nested": lambda sql: sql.upper().count("SELECT") > 1,
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _sample_questions(
|
|
111
|
+
questions: list[dict],
|
|
112
|
+
n_samples: int,
|
|
113
|
+
difficulty: list[str] | None,
|
|
114
|
+
query_types: list[str] | None,
|
|
115
|
+
seed: int,
|
|
116
|
+
) -> list[dict]:
|
|
117
|
+
"""
|
|
118
|
+
Stratified sample by difficulty and optionally filter by query type.
|
|
119
|
+
Reproducible with fixed seed. Handles uneven difficulty distributions.
|
|
120
|
+
"""
|
|
121
|
+
rng = random.Random(seed)
|
|
122
|
+
|
|
123
|
+
# Filter by difficulty
|
|
124
|
+
if difficulty:
|
|
125
|
+
dl = {d.lower() for d in difficulty}
|
|
126
|
+
questions = [q for q in questions if q.get("difficulty", "").lower() in dl]
|
|
127
|
+
|
|
128
|
+
# Filter by SQL type
|
|
129
|
+
if query_types:
|
|
130
|
+
filtered = []
|
|
131
|
+
for q in questions:
|
|
132
|
+
sql = q.get("query", "")
|
|
133
|
+
for qtype in query_types:
|
|
134
|
+
fn = _SQL_TYPE_PATTERNS.get(qtype)
|
|
135
|
+
if fn and fn(sql):
|
|
136
|
+
filtered.append(q)
|
|
137
|
+
break
|
|
138
|
+
questions = filtered
|
|
139
|
+
|
|
140
|
+
if not questions:
|
|
141
|
+
return []
|
|
142
|
+
|
|
143
|
+
# Group by difficulty
|
|
144
|
+
by_diff: dict[str, list] = {}
|
|
145
|
+
for q in questions:
|
|
146
|
+
d = q.get("difficulty", "medium").lower()
|
|
147
|
+
by_diff.setdefault(d, []).append(q)
|
|
148
|
+
|
|
149
|
+
# Proportional sample
|
|
150
|
+
sampled: list[dict] = []
|
|
151
|
+
for diff, weight in _DIFFICULTY_WEIGHTS.items():
|
|
152
|
+
pool = by_diff.get(diff, [])
|
|
153
|
+
n = max(1, round(n_samples * weight))
|
|
154
|
+
take = min(n, len(pool))
|
|
155
|
+
if pool:
|
|
156
|
+
sampled.extend(rng.sample(pool, take))
|
|
157
|
+
|
|
158
|
+
# Top up if total < n_samples due to rounding
|
|
159
|
+
remaining = [q for q in questions if q not in sampled]
|
|
160
|
+
rng.shuffle(remaining)
|
|
161
|
+
sampled.extend(remaining[: max(0, n_samples - len(sampled))])
|
|
162
|
+
|
|
163
|
+
return sampled[:n_samples]
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# ── Spider benchmark ───────────────────────────────────────────────────────────
|
|
167
|
+
|
|
168
|
+
def run_spider_benchmark(
    agent_fn,
    llm_judge: LLMJudge,
    spider_dir: str = "./spider",
    n_samples: int = 50,
    difficulty: list[str] | None = None,
    query_types: list[str] | None = None,
    seed: int = 42,
    weights: dict | None = None,
    pass_threshold: float = 0.6,
    validate_chart_with_llm: bool = False,  # off by default to save tokens
    mlflow_run: bool = False,
    verbose: bool = True,
) -> dict:
    """
    Evaluate an SQL agent against a sample of the Spider dev set.

    Questions are read from ``<spider_dir>/dev.json`` and sampled with
    ``_sample_questions``; the databases are expected under
    ``<spider_dir>/database``. The reported cost is a flat estimate of
    $0.005 per sampled question (50 questions → ~$0.25).

    Args:
        agent_fn:        Function(question: str) -> {sql, response, data?}
        llm_judge:       LLM judge function (prompt: str) -> str
        spider_dir:      Path to the extracted Spider dataset
        n_samples:       Number of questions to evaluate (default 50)
        difficulty:      Filter by difficulty: ["easy","medium","hard","extra hard"]
        query_types:     Filter by type: ["simple","aggregation","join","nested"]
        seed:            Random seed for reproducible sampling
        weights:         SQLAS weight profile, forwarded to the shared runner
        pass_threshold:  Min score for PASS label
        validate_chart_with_llm: Forwarded to the shared runner (default False)
        mlflow_run:      Log the results to the "sqlas-spider-benchmark"
                         MLflow experiment
        verbose:         Print progress

    Returns:
        The result dict of ``_run_benchmark`` (``summary``, ``details``)
        extended with ``benchmark_stats``, ``cost_estimate_usd`` and
        ``sample_info``.

    Raises:
        FileNotFoundError: if ``dev.json`` is missing; the message includes
            the download instructions.
    """
    spider_path = Path(spider_dir)
    dev_file = spider_path / "dev.json"
    db_dir = spider_path / "database"

    if not dev_file.exists():
        raise FileNotFoundError(
            f"Spider dev.json not found at {dev_file}\n{download_instructions('spider')}"
        )

    # JSON is UTF-8 by specification; relying on the platform default
    # encoding breaks on systems where that default is not UTF-8.
    with open(dev_file, encoding="utf-8") as f:
        all_questions = json.load(f)

    sampled = _sample_questions(all_questions, n_samples, difficulty, query_types, seed)

    # Flat per-question estimate, used for both the banner and the result.
    cost = len(sampled) * 0.005

    if verbose:
        diff_dist: dict = {}
        for q in sampled:
            d = q.get("difficulty", "?")
            diff_dist[d] = diff_dist.get(d, 0) + 1
        print("\nSQLAS Spider Benchmark")
        print(f" Dataset : Spider dev ({len(all_questions)} total)")
        print(f" Sample : {len(sampled)} questions (seed={seed})")
        print(f" Difficulty : {diff_dist}")
        print(f" Est. cost : ~${cost:.2f} (GPT-4o)\n")

    results, benchmark_stats = _run_benchmark(
        sampled, agent_fn, llm_judge, db_dir,
        weights, pass_threshold, validate_chart_with_llm, verbose,
        dataset_name="Spider",
    )

    if mlflow_run:
        _log_to_mlflow("sqlas-spider-benchmark", results, benchmark_stats, sampled)

    return {
        **results,
        "benchmark_stats": benchmark_stats,
        "cost_estimate_usd": round(cost, 3),
        "sample_info": {
            "total_in_dataset": len(all_questions),
            "sampled": len(sampled),
            "seed": seed,
            "difficulty_filter": difficulty,
            "type_filter": query_types,
        },
    }
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# ── BIRD benchmark ─────────────────────────────────────────────────────────────
|
|
257
|
+
|
|
258
|
+
def run_bird_benchmark(
    agent_fn,
    llm_judge: LLMJudge,
    bird_dir: str = "./bird",
    n_samples: int = 50,
    difficulty: list[str] | None = None,
    seed: int = 42,
    weights: dict | None = None,
    pass_threshold: float = 0.6,
    mlflow_run: bool = False,
    verbose: bool = True,
) -> dict:
    """
    Evaluate an SQL agent against a sample of the BIRD dev set.

    Questions are read from ``<bird_dir>/dev/dev.json`` and the databases are
    expected under ``<bird_dir>/dev/dev_databases``. No query-type filter is
    applied and chart validation is always disabled. The reported cost is a
    flat estimate of $0.005 per sampled question.

    Args:
        agent_fn:        Function(question: str) -> {sql, response, data?}
        llm_judge:       LLM judge function (prompt: str) -> str
        bird_dir:        Path to the extracted BIRD dataset
        n_samples:       Number of questions to evaluate (default 50)
        difficulty:      Optional difficulty labels to keep
        seed:            Random seed for reproducible sampling
        weights:         SQLAS weight profile, forwarded to the shared runner
        pass_threshold:  Min score for PASS label
        mlflow_run:      Log the results to the "sqlas-bird-benchmark"
                         MLflow experiment
        verbose:         Print progress

    Returns:
        The result dict of ``_run_benchmark`` (``summary``, ``details``)
        extended with ``benchmark_stats``, ``cost_estimate_usd`` and
        ``sample_info``.

    Raises:
        FileNotFoundError: if ``dev.json`` is missing; the message includes
            the download instructions.
    """
    bird_path = Path(bird_dir)
    dev_file = bird_path / "dev" / "dev.json"
    db_dir = bird_path / "dev" / "dev_databases"

    if not dev_file.exists():
        raise FileNotFoundError(
            f"BIRD dev.json not found at {dev_file}\n{download_instructions('bird')}"
        )

    # JSON is UTF-8 by specification; relying on the platform default
    # encoding breaks on systems where that default is not UTF-8.
    with open(dev_file, encoding="utf-8") as f:
        all_questions = json.load(f)

    # BIRD uses "difficulty" field too (simple/moderate/challenging)
    sampled = _sample_questions(all_questions, n_samples, difficulty, None, seed)

    # Flat per-question estimate, used for both the banner and the result.
    cost = len(sampled) * 0.005

    if verbose:
        print("\nSQLAS BIRD Benchmark")
        print(f" Dataset : BIRD dev ({len(all_questions)} total)")
        print(f" Sample : {len(sampled)} questions")
        print(f" Est. cost: ~${cost:.2f}\n")

    results, benchmark_stats = _run_benchmark(
        sampled, agent_fn, llm_judge, db_dir,
        weights, pass_threshold, False, verbose,
        dataset_name="BIRD",
        db_subdir=True,  # BIRD has {db_dir}/{db_id}/{db_id}.sqlite
    )

    if mlflow_run:
        _log_to_mlflow("sqlas-bird-benchmark", results, benchmark_stats, sampled)

    return {
        **results,
        "benchmark_stats": benchmark_stats,
        "cost_estimate_usd": round(cost, 3),
        # Same bookkeeping keys as the Spider runner, so a run can be
        # reproduced from its own result dict.
        "sample_info": {
            "total_in_dataset": len(all_questions),
            "sampled": len(sampled),
            "seed": seed,
            "difficulty_filter": difficulty,
        },
    }
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# ── Shared runner ──────────────────────────────────────────────────────────────
|
|
316
|
+
|
|
317
|
+
def _classify_sql_type(sql: str) -> str:
    """Label SQL as "nested", "join", "aggregation" or "simple" via keyword checks."""
    sql_upper = sql.upper()
    if sql_upper.count("SELECT") > 1:
        return "nested"
    if "JOIN" in sql_upper:
        return "join"
    if "GROUP BY" in sql_upper or any(
        f in sql_upper for f in ("COUNT(", "SUM(", "AVG(", "MAX(", "MIN(")
    ):
        return "aggregation"
    return "simple"


def _make_execute_fn(path: str):
    """Return a callable that runs SQL against the SQLite file at *path*, read-only."""
    # as_uri() percent-encodes characters such as '?', '#' and spaces that
    # would corrupt a hand-built "file:" URI; it needs an absolute path.
    uri = Path(path).resolve().as_uri() + "?mode=ro"

    def execute_fn(sql: str) -> list[tuple]:
        conn = sqlite3.connect(uri, uri=True)
        try:
            return conn.execute(sql).fetchall()
        finally:
            conn.close()

    return execute_fn


def _run_benchmark(
    sampled: list[dict],
    agent_fn,
    llm_judge: LLMJudge,
    db_dir: Path,
    weights: dict | None,
    pass_threshold: float,
    validate_chart: bool,
    verbose: bool,
    dataset_name: str,
    db_subdir: bool = False,
) -> tuple[dict, dict]:
    """Core benchmark runner shared by Spider and BIRD.

    For every sampled question whose database file exists, the agent is
    called, its output is scored with ``evaluate`` and the scores are
    aggregated overall, per difficulty and per query type.

    Args:
        sampled: Question dicts; reads ``db_id``, ``question``, ``difficulty``
            and the gold SQL under ``query`` (or ``SQL``).
        agent_fn: Callable taking the question text and returning a dict with
            ``sql`` and optionally ``response``, ``data`` and ``success``.
        llm_judge: Judge passed to ``evaluate``. It is replaced by a fixed
            "skipped" stub when the agent reports no success and no SQL.
        db_dir: Directory holding ``{db_id}/{db_id}.sqlite``.
        weights: Weight profile; falls back to ``WEIGHTS_V4`` when falsy.
        pass_threshold: Minimum overall score counted towards ``pass_rate``.
        validate_chart: Forwarded as ``validate_chart_with_llm``.
        verbose: Print per-question progress and the final report.
        dataset_name: Label stored in the summary and used in the report.
        db_subdir: Accepted for backward compatibility. Both layouts resolve
            to ``{db_dir}/{db_id}/{db_id}.sqlite``, so it has no effect.

    Returns:
        ``({"summary": ..., "details": [scores, ...]}, benchmark_stats)``.
    """
    from sqlas.core import SQLASScores

    w = weights or WEIGHTS_V4
    all_scores: list[SQLASScores] = []
    by_difficulty: dict[str, list[float]] = {}
    by_type: dict[str, list[float]] = {}
    exec_successes = 0
    skipped = 0
    total = len(sampled)
    start = time.perf_counter()

    for i, q in enumerate(sampled, start=1):
        db_id = q.get("db_id", "")
        question = q.get("question") or ""
        gold_sql = q.get("query", q.get("SQL", "")) or ""
        # A present-but-empty difficulty would break the ":12s" format below.
        diff = str(q.get("difficulty") or "medium")

        # Locate the SQLite database
        db_path = str(db_dir / db_id / f"{db_id}.sqlite")

        if not os.path.exists(db_path):
            skipped += 1
            if verbose:
                print(f" SKIP [{i}/{total}] DB not found: {db_path}")
            continue

        if verbose:
            print(f" [{i}/{total}] {diff:12s} | {question[:60]}...")

        # Run agent
        # NOTE(review): agent_fn only receives the question text, not db_id —
        # confirm the agent can tell which database the question targets.
        try:
            result = agent_fn(question)
        except Exception as e:
            logger.warning("agent_fn failed: %s", e)
            result = {"sql": "", "response": str(e), "data": None, "success": False}

        # One malformed agent response must not abort the whole benchmark.
        if not isinstance(result, dict):
            logger.warning(
                "agent_fn returned %s instead of a dict", type(result).__name__
            )
            result = {"sql": "", "response": None, "data": None, "success": False}

        generated_sql = result.get("sql") or ""

        # Build execute_fn for this specific database
        exec_fn = _make_execute_fn(db_path)

        # Determine SQL type for breakdown
        sql_type = _classify_sql_type(gold_sql)

        # Evaluate — the real judge is skipped (to save tokens) only when the
        # agent reported no success and produced no SQL at all.
        use_llm = result.get("success", False) or bool(generated_sql.strip())
        scores = evaluate(
            question=question,
            generated_sql=generated_sql,
            llm_judge=llm_judge if use_llm else lambda p: "Semantic_Score: 0.0\nReasoning: Skipped — execution failed.",
            gold_sql=gold_sql,
            db_path=db_path,
            execute_fn=exec_fn,
            response=result.get("response"),
            result_data=result.get("data"),
            validate_chart_with_llm=validate_chart,
            weights=w,
        )

        all_scores.append(scores)
        by_difficulty.setdefault(diff, []).append(scores.overall_score)
        by_type.setdefault(sql_type, []).append(scores.overall_score)
        if scores.execution_success == 1.0:
            exec_successes += 1

    elapsed = time.perf_counter() - start
    n = len(all_scores)

    def avg(attr: str) -> float:
        # max(n, 1) keeps an empty run at 0.0 instead of dividing by zero.
        return round(sum(getattr(s, attr, 0) for s in all_scores) / max(n, 1), 4)

    summary = {
        "dataset": dataset_name,
        "n_evaluated": n,
        "n_skipped": skipped,
        "time_seconds": round(elapsed, 1),
        "overall_score": avg("overall_score"),
        "pass_rate": round(sum(1 for s in all_scores if s.overall_score >= pass_threshold) / max(n, 1), 4),
        "execution_accuracy": avg("execution_accuracy"),
        "semantic_equivalence": avg("semantic_equivalence"),
        "faithfulness": avg("faithfulness"),
        "safety_score": avg("safety_score"),
        "sql_quality": avg("sql_quality"),
        "by_difficulty": {d: round(sum(v) / len(v), 4) for d, v in by_difficulty.items()},
        "by_query_type": {t: round(sum(v) / len(v), 4) for t, v in by_type.items()},
    }

    benchmark_stats = {
        "execution_success_rate": round(exec_successes / max(n, 1), 4),
        "avg_correctness_score": avg("correctness_score"),
        "avg_quality_score": avg("quality_score"),
        "avg_safety_score": avg("safety_composite_score"),
    }

    if verbose:
        _print_benchmark_report(summary, benchmark_stats, dataset_name)

    return {"summary": summary, "details": all_scores}, benchmark_stats
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def _score_bar(score: float, width: int = 20) -> str:
    """Render *score* as a fixed-width bar; scores outside [0, 1] are clamped."""
    filled = min(width, max(0, int(score * width)))
    return "█" * filled + "░" * (width - filled)


def _print_benchmark_report(summary: dict, stats: dict, name: str):
    """Print a human-readable benchmark report to stdout.

    Args:
        summary: Aggregates to print; reads ``n_evaluated``, ``overall_score``,
            ``pass_rate``, ``execution_accuracy``, ``safety_score``,
            ``by_difficulty`` and ``by_query_type``.
        stats: Accepted for interface compatibility; not printed.
        name: Dataset name shown in the report title.
    """
    rule = "=" * 60
    n = summary["n_evaluated"]

    print(f"\n{rule}")
    print(f" SQLAS {name} Benchmark Results")
    print(rule)
    print(f" Questions evaluated : {n}")
    print(f" Overall SQLAS score : {summary['overall_score']:.4f} / 1.0")
    print(f" Pass rate : {summary['pass_rate']*100:.0f}%")
    print(f" Execution accuracy : {summary['execution_accuracy']:.4f}")
    print(f" Safety score : {summary['safety_score']:.4f}")

    # Both breakdowns share one layout: label, bar, score.
    for title, key in (("By difficulty", "by_difficulty"), ("By query type", "by_query_type")):
        print(f"\n {title}:")
        for label, score in sorted(summary[key].items()):
            # str() keeps the ":15s" format valid for non-string labels.
            print(f" {str(label):15s} [{_score_bar(score)}] {score:.4f}")

    print(f"{rule}\n")
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def _mlflow_safe_key(label) -> str:
    """Replace characters outside a conservative safe set with "_" for metric names."""
    return "".join(c if c.isalnum() or c in "_-. /" else "_" for c in str(label))


def _log_to_mlflow(experiment: str, results: dict, stats: dict, sampled: list):
    """Log benchmark results to an MLflow experiment (best effort).

    Numeric entries of ``results["summary"]`` and ``stats`` are logged as
    metrics. The ``by_difficulty`` and ``by_query_type`` breakdowns of the
    summary are logged as ``"<group>/<label>"`` metrics. ``n_samples`` and
    ``dataset`` are logged as params. Any failure, including MLflow not being
    installed, is reported as a warning and never raised.

    Args:
        experiment: MLflow experiment name.
        results: Result dict containing a ``"summary"`` dict.
        stats: Flat dict of additional metrics.
        sampled: The sampled questions; only their count is logged.
    """
    try:
        import mlflow

        # Build the payload before opening a run, so a malformed result dict
        # does not leave an empty run behind.
        summary = results["summary"]
        metrics: dict = {}
        for source in (summary, stats):
            for k, v in source.items():
                if isinstance(v, (int, float)):
                    metrics[k] = v
        # The per-bucket breakdowns are nested dicts, so the flat pass above
        # skips them; flatten them into individual metrics.
        for group in ("by_difficulty", "by_query_type"):
            for label, v in (summary.get(group) or {}).items():
                if isinstance(v, (int, float)):
                    metrics[f"{group}/{_mlflow_safe_key(label)}"] = v

        mlflow.set_experiment(experiment)
        with mlflow.start_run():
            for k, v in metrics.items():
                mlflow.log_metric(k, v)
            mlflow.log_param("n_samples", len(sampled))
            mlflow.log_param("dataset", summary.get("dataset", "unknown"))
    except ImportError:
        logger.warning("mlflow not installed — skipping benchmark logging")
    except Exception as e:
        logger.warning("mlflow logging failed: %s", e)
|