sqlas 2.5.0__tar.gz → 2.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {sqlas-2.5.0/sqlas.egg-info → sqlas-2.6.0}/PKG-INFO +15 -1
  2. {sqlas-2.5.0 → sqlas-2.6.0}/pyproject.toml +9 -4
  3. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/__init__.py +3 -1
  4. sqlas-2.6.0/sqlas/benchmarks.py +480 -0
  5. sqlas-2.6.0/sqlas/integrations.py +274 -0
  6. sqlas-2.6.0/sqlas/ui.py +572 -0
  7. {sqlas-2.5.0 → sqlas-2.6.0/sqlas.egg-info}/PKG-INFO +15 -1
  8. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas.egg-info/SOURCES.txt +3 -0
  9. sqlas-2.6.0/sqlas.egg-info/requires.txt +31 -0
  10. sqlas-2.6.0/sqlas.egg-info/top_level.txt +2 -0
  11. sqlas-2.5.0/sqlas.egg-info/requires.txt +0 -12
  12. sqlas-2.5.0/sqlas.egg-info/top_level.txt +0 -1
  13. {sqlas-2.5.0 → sqlas-2.6.0}/LICENSE +0 -0
  14. {sqlas-2.5.0 → sqlas-2.6.0}/README.md +0 -0
  15. {sqlas-2.5.0 → sqlas-2.6.0}/setup.cfg +0 -0
  16. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/agentic.py +0 -0
  17. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/cache.py +0 -0
  18. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/context.py +0 -0
  19. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/core.py +0 -0
  20. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/correctness.py +0 -0
  21. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/evaluate.py +0 -0
  22. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/feedback.py +0 -0
  23. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/guardrails.py +0 -0
  24. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/production.py +0 -0
  25. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/prompt_registry.py +0 -0
  26. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/py.typed +0 -0
  27. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/quality.py +0 -0
  28. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/response.py +0 -0
  29. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/runner.py +0 -0
  30. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/safety.py +0 -0
  31. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/schema_quality.py +0 -0
  32. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas/visualization.py +0 -0
  33. {sqlas-2.5.0 → sqlas-2.6.0}/sqlas.egg-info/dependency_links.txt +0 -0
  34. {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_context.py +0 -0
  35. {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_execute_fn.py +0 -0
  36. {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_large_schema.py +0 -0
  37. {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_sqlas.py +0 -0
  38. {sqlas-2.5.0 → sqlas-2.6.0}/tests/test_v2.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sqlas
3
- Version: 2.5.0
3
+ Version: 2.6.0
4
4
  Summary: SQLAS — SQL Agent Scoring Framework. Production-grade evaluation for Text-to-SQL and Agentic SQL agents with guardrail, visualization, agentic quality, and cache performance metrics.
5
5
  Author-email: thepradip <pradiptivhale@gmail.com>
6
6
  License-Expression: MIT
@@ -26,12 +26,26 @@ License-File: LICENSE
26
26
  Requires-Dist: sqlglot>=20.0
27
27
  Provides-Extra: mlflow
28
28
  Requires-Dist: mlflow>=3.0; extra == "mlflow"
29
+ Provides-Extra: wandb
30
+ Requires-Dist: wandb>=0.16; extra == "wandb"
31
+ Provides-Extra: langsmith
32
+ Requires-Dist: langsmith>=0.1; extra == "langsmith"
33
+ Provides-Extra: ui
34
+ Requires-Dist: streamlit>=1.30; extra == "ui"
35
+ Requires-Dist: pandas>=2.0; extra == "ui"
36
+ Provides-Extra: benchmarks
37
+ Provides-Extra: prometheus
38
+ Requires-Dist: prometheus-client>=0.19; extra == "prometheus"
29
39
  Provides-Extra: dev
30
40
  Requires-Dist: pytest>=7.0; extra == "dev"
31
41
  Requires-Dist: build; extra == "dev"
32
42
  Requires-Dist: twine; extra == "dev"
33
43
  Provides-Extra: all
34
44
  Requires-Dist: mlflow>=3.0; extra == "all"
45
+ Requires-Dist: wandb>=0.16; extra == "all"
46
+ Requires-Dist: langsmith>=0.1; extra == "all"
47
+ Requires-Dist: streamlit>=1.30; extra == "all"
48
+ Requires-Dist: pandas>=2.0; extra == "all"
35
49
  Dynamic: license-file
36
50
 
37
51
  # SQLAS — SQL Agent Scoring Framework
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "sqlas"
7
- version = "2.5.0"
7
+ version = "2.6.0"
8
8
  description = "SQLAS — SQL Agent Scoring Framework. Production-grade evaluation for Text-to-SQL and Agentic SQL agents with guardrail, visualization, agentic quality, and cache performance metrics."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -29,9 +29,14 @@ dependencies = [
29
29
  ]
30
30
 
31
31
  [project.optional-dependencies]
32
- mlflow = ["mlflow>=3.0"]
33
- dev = ["pytest>=7.0", "build", "twine"]
34
- all = ["mlflow>=3.0"]
32
+ mlflow = ["mlflow>=3.0"]
33
+ wandb = ["wandb>=0.16"]
34
+ langsmith = ["langsmith>=0.1"]
35
+ ui = ["streamlit>=1.30", "pandas>=2.0"]
36
+ benchmarks = [] # Spider/BIRD datasets downloaded separately
37
+ prometheus = ["prometheus-client>=0.19"]
38
+ dev = ["pytest>=7.0", "build", "twine"]
39
+ all = ["mlflow>=3.0", "wandb>=0.16", "langsmith>=0.1", "streamlit>=1.30", "pandas>=2.0"]
35
40
 
36
41
  [project.urls]
37
42
  Homepage = "https://github.com/thepradip/SQLAS"
@@ -33,6 +33,8 @@ from sqlas.guardrails import GuardrailPipeline, GuardrailResult
33
33
  from sqlas.feedback import FeedbackStore, FeedbackEntry
34
34
  from sqlas.prompt_registry import PromptRegistry, PromptVersion
35
35
  from sqlas.schema_quality import schema_retrieval_quality, batch_retrieval_quality
36
+ from sqlas.benchmarks import run_spider_benchmark, run_bird_benchmark, download_instructions
37
+ from sqlas.integrations import log_to_mlflow, log_to_wandb, log_to_langsmith, log_all
36
38
  from sqlas.correctness import execution_accuracy, syntax_valid, semantic_equivalence, result_set_similarity
37
39
  from sqlas.quality import sql_quality, schema_compliance, complexity_match
38
40
  from sqlas.production import data_scan_efficiency, execution_result, result_coverage
@@ -51,7 +53,7 @@ from sqlas.agentic import (
51
53
  from sqlas.cache import cache_hit_score, tokens_saved_score, few_shot_score
52
54
  from sqlas.runner import run_suite
53
55
 
54
- __version__ = "2.5.0"
56
+ __version__ = "2.6.0"
55
57
  __author__ = "SQLAS Contributors"
56
58
 
57
59
  __all__ = [
@@ -0,0 +1,480 @@
1
+ """
2
+ SQLAS Benchmark Integration — Spider and BIRD datasets.
3
+
4
+ Closes the academic credibility gap vs RAGAS/DeepEval by integrating
5
+ the two standard NL2SQL benchmarks with smart sampling to keep costs low.
6
+
7
+ Token cost strategy:
8
+ - Default n_samples=50 → ~$0.25 with GPT-4o judge
9
+ - Safety metrics (free, no LLM) run on ALL sampled questions
10
+ - LLM judge is skipped only when the agent returns no SQL and does not report success
11
+ - Stratified sampling ensures representative difficulty distribution
12
+
13
+ Usage:
14
+ from sqlas.benchmarks import run_spider_benchmark, download_instructions
15
+
16
+ # Check dataset is available
17
+ print(download_instructions("spider"))
18
+
19
+ # Run benchmark (50 questions, stratified by difficulty)
20
+ results = run_spider_benchmark(
21
+ agent_fn = my_agent,
22
+ llm_judge = my_judge,
23
+ spider_dir = "./spider",
24
+ n_samples = 50,
25
+ difficulty = None, # None = all difficulties
26
+ query_types = None, # None = all types
27
+ seed = 42, # reproducible sampling
28
+ weights = WEIGHTS_V4,
29
+ mlflow_run = True,
30
+ verbose = True,
31
+ )
32
+
33
+ print(results["summary"]["overall_score"])
34
+ print(results["benchmark_stats"]["execution_success_rate"])
35
+ print(results["cost_estimate_usd"])
36
+ """
37
+
38
+ import json
39
+ import logging
40
+ import os
41
+ import random
42
+ import sqlite3
43
+ import time
44
+ from dataclasses import dataclass, field
45
+ from pathlib import Path
46
+ from typing import Optional
47
+
48
+ from sqlas.core import TestCase, WEIGHTS, WEIGHTS_V4, LLMJudge, ExecuteFn
49
+ from sqlas.evaluate import evaluate
50
+
51
+ logger = logging.getLogger(__name__)
52
+
53
+ # ── Dataset download instructions ─────────────────────────────────────────────
54
+
55
+ _DOWNLOAD_INSTRUCTIONS = {
56
+ "spider": """
57
+ Spider dataset not found. Download it:
58
+ 1. Visit: https://yale-lily.github.io/spider
59
+ 2. Download 'Spider 1.0 Dataset' zip
60
+ 3. Extract to a directory, e.g. ./spider/
61
+ 4. Pass: spider_dir='./spider'
62
+
63
+ Expected structure:
64
+ spider/
65
+ dev.json ← 1034 dev questions
66
+ tables.json ← schema metadata
67
+ database/
68
+ {db_id}/
69
+ {db_id}.sqlite ← SQLite databases
70
+ """,
71
+ "bird": """
72
+ BIRD dataset not found. Download it:
73
+ 1. Visit: https://bird-bench.github.io/
74
+ 2. Download 'BIRD-SQL dev set'
75
+ 3. Extract to a directory, e.g. ./bird/
76
+ 4. Pass: bird_dir='./bird'
77
+
78
+ Expected structure:
79
+ bird/
80
+ dev/
81
+ dev.json
82
+ dev_databases/
83
+ {db_id}/
84
+ {db_id}.sqlite
85
+ """,
86
+ }
87
+
88
+
89
def download_instructions(dataset: str = "spider") -> str:
    """Return the download/setup instructions for a benchmark dataset.

    Args:
        dataset: Dataset name. Matched against the keys of
            ``_DOWNLOAD_INSTRUCTIONS`` ignoring case and surrounding
            whitespace, so ``"Spider"`` and ``" BIRD "`` resolve like
            ``"spider"`` and ``"bird"``.

    Returns:
        The instruction text for the dataset, or
        ``"Unknown dataset: <dataset>"`` (echoing the argument as given)
        when there is no entry for it.
    """
    # Non-string arguments are looked up unchanged, as before.
    key = dataset.strip().lower() if isinstance(dataset, str) else dataset
    return _DOWNLOAD_INSTRUCTIONS.get(key, f"Unknown dataset: {dataset}")
91
+
92
+
93
+ # ── Sampling ───────────────────────────────────────────────────────────────────
94
+
95
+ _DIFFICULTY_WEIGHTS = {
96
+ "easy": 0.20,
97
+ "medium": 0.30,
98
+ "hard": 0.30,
99
+ "extra hard": 0.20,
100
+ }
101
+
102
+ _SQL_TYPE_PATTERNS = {
103
+ "simple": lambda sql: "JOIN" not in sql.upper() and "GROUP BY" not in sql.upper(),
104
+ "aggregation": lambda sql: "GROUP BY" in sql.upper() or any(f in sql.upper() for f in ("COUNT(","SUM(","AVG(","MAX(","MIN(")),
105
+ "join": lambda sql: "JOIN" in sql.upper(),
106
+ "nested": lambda sql: sql.upper().count("SELECT") > 1,
107
+ }
108
+
109
+
110
+ def _sample_questions(
111
+ questions: list[dict],
112
+ n_samples: int,
113
+ difficulty: list[str] | None,
114
+ query_types: list[str] | None,
115
+ seed: int,
116
+ ) -> list[dict]:
117
+ """
118
+ Stratified sample by difficulty and optionally filter by query type.
119
+ Reproducible with fixed seed. Handles uneven difficulty distributions.
120
+ """
121
+ rng = random.Random(seed)
122
+
123
+ # Filter by difficulty
124
+ if difficulty:
125
+ dl = {d.lower() for d in difficulty}
126
+ questions = [q for q in questions if q.get("difficulty", "").lower() in dl]
127
+
128
+ # Filter by SQL type
129
+ if query_types:
130
+ filtered = []
131
+ for q in questions:
132
+ sql = q.get("query", "")
133
+ for qtype in query_types:
134
+ fn = _SQL_TYPE_PATTERNS.get(qtype)
135
+ if fn and fn(sql):
136
+ filtered.append(q)
137
+ break
138
+ questions = filtered
139
+
140
+ if not questions:
141
+ return []
142
+
143
+ # Group by difficulty
144
+ by_diff: dict[str, list] = {}
145
+ for q in questions:
146
+ d = q.get("difficulty", "medium").lower()
147
+ by_diff.setdefault(d, []).append(q)
148
+
149
+ # Proportional sample
150
+ sampled: list[dict] = []
151
+ for diff, weight in _DIFFICULTY_WEIGHTS.items():
152
+ pool = by_diff.get(diff, [])
153
+ n = max(1, round(n_samples * weight))
154
+ take = min(n, len(pool))
155
+ if pool:
156
+ sampled.extend(rng.sample(pool, take))
157
+
158
+ # Top up if total < n_samples due to rounding
159
+ remaining = [q for q in questions if q not in sampled]
160
+ rng.shuffle(remaining)
161
+ sampled.extend(remaining[: max(0, n_samples - len(sampled))])
162
+
163
+ return sampled[:n_samples]
164
+
165
+
166
+ # ── Spider benchmark ───────────────────────────────────────────────────────────
167
+
168
def run_spider_benchmark(
    agent_fn,
    llm_judge: LLMJudge,
    spider_dir: str = "./spider",
    n_samples: int = 50,
    difficulty: list[str] | None = None,
    query_types: list[str] | None = None,
    seed: int = 42,
    weights: dict | None = None,
    pass_threshold: float = 0.6,
    validate_chart_with_llm: bool = False,  # off by default to save tokens
    mlflow_run: bool = False,
    verbose: bool = True,
) -> dict:
    """Evaluate an SQL agent on a sampled subset of the Spider dev set.

    The cost figure reported here is a flat estimate of $0.005 per sampled
    question; it is not measured from actual token usage.

    Args:
        agent_fn: Function(question: str) -> {sql, response, data?}
        llm_judge: LLM judge function (prompt: str) -> str
        spider_dir: Path to the extracted Spider dataset; ``dev.json`` and
            ``database/`` are read from it.
        n_samples: Questions to evaluate (default 50).
        difficulty: Filter by difficulty: ["easy","medium","hard","extra hard"]
        query_types: Filter by type: ["simple","aggregation","join","nested"]
        seed: Random seed for reproducible sampling.
        weights: SQLAS weight profile, forwarded to the shared runner
            (which falls back to WEIGHTS_V4 when None).
        pass_threshold: Min overall score counted towards the pass rate.
        validate_chart_with_llm: Forwarded to ``evaluate``; off by default
            to save tokens.
        mlflow_run: Log the results to the "sqlas-spider-benchmark"
            MLflow experiment.
        verbose: Print progress and the final report.

    Returns:
        {summary, details, benchmark_stats, cost_estimate_usd, sample_info}

    Raises:
        FileNotFoundError: If ``dev.json`` is missing under ``spider_dir``;
            the message includes the download instructions.
    """
    spider_path = Path(spider_dir)
    dev_file = spider_path / "dev.json"
    db_dir = spider_path / "database"

    if not dev_file.exists():
        raise FileNotFoundError(
            f"Spider dev.json not found at {dev_file}\n{download_instructions('spider')}"
        )

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(dev_file, encoding="utf-8") as f:
        all_questions = json.load(f)

    sampled = _sample_questions(all_questions, n_samples, difficulty, query_types, seed)
    cost = len(sampled) * 0.005

    if verbose:
        diff_dist: dict = {}
        for q in sampled:
            d = q.get("difficulty", "?")
            diff_dist[d] = diff_dist.get(d, 0) + 1
        print("\nSQLAS Spider Benchmark")
        print(f" Dataset : Spider dev ({len(all_questions)} total)")
        print(f" Sample : {len(sampled)} questions (seed={seed})")
        print(f" Difficulty : {diff_dist}")
        print(f" Est. cost : ~${cost:.2f} (GPT-4o)\n")

    results, benchmark_stats = _run_benchmark(
        sampled, agent_fn, llm_judge, db_dir,
        weights, pass_threshold, validate_chart_with_llm, verbose,
        dataset_name="Spider",
    )

    if mlflow_run:
        _log_to_mlflow("sqlas-spider-benchmark", results, benchmark_stats, sampled)

    return {
        **results,
        "benchmark_stats": benchmark_stats,
        "cost_estimate_usd": round(cost, 3),
        "sample_info": {
            "total_in_dataset": len(all_questions),
            "sampled": len(sampled),
            "seed": seed,
            "difficulty_filter": difficulty,
            "type_filter": query_types,
        },
    }
254
+
255
+
256
+ # ── BIRD benchmark ─────────────────────────────────────────────────────────────
257
+
258
def run_bird_benchmark(
    agent_fn,
    llm_judge: LLMJudge,
    bird_dir: str = "./bird",
    n_samples: int = 50,
    difficulty: list[str] | None = None,
    seed: int = 42,
    weights: dict | None = None,
    pass_threshold: float = 0.6,
    mlflow_run: bool = False,
    verbose: bool = True,
) -> dict:
    """Evaluate an SQL agent on a sampled subset of the BIRD dev set.

    Mirrors ``run_spider_benchmark()`` with two differences: there is no
    query-type filter, and chart validation is always disabled. No Valid
    Efficiency Score (VES) is computed here.

    Args:
        agent_fn: Function(question: str) -> {sql, response, data?}
        llm_judge: LLM judge function (prompt: str) -> str
        bird_dir: Path to the extracted BIRD dataset; ``dev/dev.json`` and
            ``dev/dev_databases/`` are read from it.
        n_samples: Questions to evaluate (default 50).
        difficulty: Optional difficulty labels to keep.
        seed: Random seed for reproducible sampling.
        weights: SQLAS weight profile, forwarded to the shared runner.
        pass_threshold: Min overall score counted towards the pass rate.
        mlflow_run: Log the results to the "sqlas-bird-benchmark"
            MLflow experiment.
        verbose: Print progress and the final report.

    Returns:
        {summary, details, benchmark_stats, cost_estimate_usd, sample_info}

    Raises:
        FileNotFoundError: If ``dev/dev.json`` is missing under ``bird_dir``;
            the message includes the download instructions.
    """
    bird_path = Path(bird_dir)
    dev_file = bird_path / "dev" / "dev.json"
    db_dir = bird_path / "dev" / "dev_databases"

    if not dev_file.exists():
        raise FileNotFoundError(
            f"BIRD dev.json not found at {dev_file}\n{download_instructions('bird')}"
        )

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(dev_file, encoding="utf-8") as f:
        all_questions = json.load(f)

    # BIRD uses "difficulty" field too (simple/moderate/challenging).
    # NOTE(review): _DIFFICULTY_WEIGHTS lists only easy/medium/hard/extra hard,
    # so these labels appear to be drawn through the sampler's random top-up
    # rather than a stratified quota — confirm this is intended.
    sampled = _sample_questions(all_questions, n_samples, difficulty, None, seed)
    cost = len(sampled) * 0.005

    if verbose:
        print("\nSQLAS BIRD Benchmark")
        print(f" Dataset : BIRD dev ({len(all_questions)} total)")
        print(f" Sample : {len(sampled)} questions")
        print(f" Est. cost: ~${cost:.2f}\n")

    results, benchmark_stats = _run_benchmark(
        sampled, agent_fn, llm_judge, db_dir,
        weights, pass_threshold, False, verbose,
        dataset_name="BIRD",
        db_subdir=True,  # BIRD has {db_dir}/{db_id}/{db_id}.sqlite
    )

    if mlflow_run:
        _log_to_mlflow("sqlas-bird-benchmark", results, benchmark_stats, sampled)

    return {
        **results,
        "benchmark_stats": benchmark_stats,
        "cost_estimate_usd": round(cost, 3),
        # Same provenance fields the Spider runner reports, so a run can be
        # reproduced from its own output.
        "sample_info": {
            "total_in_dataset": len(all_questions),
            "sampled": len(sampled),
            "seed": seed,
            "difficulty_filter": difficulty,
        },
    }
313
+
314
+
315
+ # ── Shared runner ──────────────────────────────────────────────────────────────
316
+
317
def _readonly_sqlite_executor(db_path: str):
    """Return an ``execute_fn(sql)`` that runs SQL read-only against *db_path*."""
    # Build the URI through pathlib so characters such as '?', '#', '%' or
    # spaces in the path (and Windows drive letters) are encoded correctly.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"

    def execute_fn(sql: str) -> list[tuple]:
        conn = sqlite3.connect(uri, uri=True)
        try:
            return conn.execute(sql).fetchall()
        finally:
            conn.close()

    return execute_fn


def _classify_sql(sql: str) -> str:
    """Bucket a gold query as nested, join, aggregation or simple (first match wins)."""
    upper = sql.upper()
    if upper.count("SELECT") > 1:
        return "nested"
    if "JOIN" in upper:
        return "join"
    if "GROUP BY" in upper or any(
        fn in upper for fn in ("COUNT(", "SUM(", "AVG(", "MAX(", "MIN(")
    ):
        return "aggregation"
    return "simple"


def _run_benchmark(
    sampled: list[dict],
    agent_fn,
    llm_judge: LLMJudge,
    db_dir: Path,
    weights: dict | None,
    pass_threshold: float,
    validate_chart: bool,
    verbose: bool,
    dataset_name: str,
    db_subdir: bool = False,
) -> tuple[dict, dict]:
    """Core benchmark runner shared by Spider and BIRD.

    For every sampled question: locate ``{db_dir}/{db_id}/{db_id}.sqlite``,
    call ``agent_fn(question)``, score the outcome with ``evaluate`` and
    aggregate the scores overall, per difficulty and per gold-query type.

    Args:
        sampled: Question records (``db_id``, ``question``, gold SQL under
            ``query`` or ``SQL``, optional ``difficulty``).
        agent_fn: Function(question: str) -> {sql, response, data?, success?}.
            An exception raised by it is logged and scored as an empty answer.
        llm_judge: Judge passed to ``evaluate``. It is replaced by a fixed
            zero-score reply when the agent returned no SQL and did not
            report success.
        db_dir: Directory holding one sub-directory per database.
        weights: Weight profile; WEIGHTS_V4 when None or empty.
        pass_threshold: Min overall score counted towards the pass rate.
        validate_chart: Forwarded to ``evaluate`` as ``validate_chart_with_llm``.
        verbose: Print per-question progress and the final report.
        dataset_name: Label used in the summary and the report.
        db_subdir: Accepted for backward compatibility; both supported
            layouts resolve to the same path, so it has no effect.

    Returns:
        ``({"summary": ..., "details": [scores, ...]}, benchmark_stats)``.
        Questions whose database file is missing are skipped and counted
        in ``summary["n_skipped"]``.
    """
    from sqlas.core import SQLASScores

    def skipped_judge(prompt: str) -> str:
        # Stand-in reply used when the real judge is bypassed to save tokens.
        return "Semantic_Score: 0.0\nReasoning: Skipped — execution failed."

    w = weights or WEIGHTS_V4
    all_scores: list[SQLASScores] = []
    by_difficulty: dict[str, list[float]] = {}
    by_type: dict[str, list[float]] = {}
    exec_successes = 0
    skipped = 0
    total = len(sampled)
    start = time.perf_counter()

    for position, q in enumerate(sampled, start=1):
        db_id = q.get("db_id") or ""
        question = q.get("question") or ""
        gold_sql = q.get("query") or q.get("SQL") or ""
        diff = q.get("difficulty") or "medium"

        # Locate the SQLite database
        db_path = str(db_dir / db_id / f"{db_id}.sqlite")

        if not os.path.exists(db_path):
            skipped += 1
            # Always leave a trace: with verbose=False the skip was invisible.
            logger.warning(
                "skipping question %d/%d: database not found at %s",
                position, total, db_path,
            )
            if verbose:
                print(f" SKIP [{position}/{total}] DB not found: {db_path}")
            continue

        if verbose:
            print(f" [{position}/{total}] {diff:12s} | {question[:60]}...")

        # Run agent
        try:
            result = agent_fn(question)
        except Exception as e:
            logger.warning("agent_fn failed: %s", e)
            result = {"sql": "", "response": str(e), "data": None, "success": False}

        if not hasattr(result, "get"):
            # e.g. the agent returned None instead of a result mapping.
            logger.warning("agent_fn returned %s, expected a dict", type(result).__name__)
            result = {"sql": "", "response": None, "data": None, "success": False}

        generated_sql = result.get("sql") or ""
        sql_type = _classify_sql(gold_sql)

        # Evaluate — bypass the LLM judge when there is nothing to judge
        use_llm = bool(result.get("success", False)) or bool(generated_sql.strip())
        scores = evaluate(
            question=question,
            generated_sql=generated_sql,
            llm_judge=llm_judge if use_llm else skipped_judge,
            gold_sql=gold_sql,
            db_path=db_path,
            execute_fn=_readonly_sqlite_executor(db_path),
            response=result.get("response"),
            result_data=result.get("data"),
            validate_chart_with_llm=validate_chart,
            weights=w,
        )

        all_scores.append(scores)
        by_difficulty.setdefault(diff, []).append(scores.overall_score)
        by_type.setdefault(sql_type, []).append(scores.overall_score)
        if scores.execution_success == 1.0:
            exec_successes += 1

    elapsed = time.perf_counter() - start
    n = len(all_scores)

    def avg(attr: str) -> float:
        # Mean of one score attribute; 0.0 when nothing was evaluated.
        return round(sum(getattr(s, attr, 0) for s in all_scores) / max(n, 1), 4)

    summary = {
        "dataset": dataset_name,
        "n_evaluated": n,
        "n_skipped": skipped,
        "time_seconds": round(elapsed, 1),
        "overall_score": avg("overall_score"),
        "pass_rate": round(
            sum(1 for s in all_scores if s.overall_score >= pass_threshold) / max(n, 1), 4
        ),
        "execution_accuracy": avg("execution_accuracy"),
        "semantic_equivalence": avg("semantic_equivalence"),
        "faithfulness": avg("faithfulness"),
        "safety_score": avg("safety_score"),
        "sql_quality": avg("sql_quality"),
        "by_difficulty": {d: round(sum(v) / len(v), 4) for d, v in by_difficulty.items()},
        "by_query_type": {t: round(sum(v) / len(v), 4) for t, v in by_type.items()},
    }

    benchmark_stats = {
        "execution_success_rate": round(exec_successes / max(n, 1), 4),
        "avg_correctness_score": avg("correctness_score"),
        "avg_quality_score": avg("quality_score"),
        "avg_safety_score": avg("safety_composite_score"),
    }

    if verbose:
        _print_benchmark_report(summary, benchmark_stats, dataset_name)

    return {"summary": summary, "details": all_scores}, benchmark_stats
439
+
440
+
441
def _score_bar(score: float, width: int = 20) -> str:
    """Render *score* as a fixed-width bar; scores outside [0, 1] are clamped."""
    filled = min(width, max(0, int(score * width)))
    return "█" * filled + "░" * (width - filled)


def _print_benchmark_report(summary: dict, stats: dict, name: str):
    """Print a human-readable benchmark report to stdout.

    Args:
        summary: Summary dict; reads ``n_evaluated``, ``overall_score``,
            ``pass_rate``, ``execution_accuracy``, ``safety_score``,
            ``by_difficulty`` and ``by_query_type``.
        stats: Benchmark stats; accepted for interface symmetry, not printed.
        name: Dataset name shown in the header.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f" SQLAS {name} Benchmark Results")
    print(rule)
    print(f" Questions evaluated : {summary['n_evaluated']}")
    print(f" Overall SQLAS score : {summary['overall_score']:.4f} / 1.0")
    print(f" Pass rate : {summary['pass_rate']*100:.0f}%")
    print(f" Execution accuracy : {summary['execution_accuracy']:.4f}")
    print(f" Safety score : {summary['safety_score']:.4f}")

    # Both breakdowns share one layout: label, bar, score — sorted by label.
    for heading, key in (
        ("\n By difficulty:", "by_difficulty"),
        ("\n By query type:", "by_query_type"),
    ):
        print(heading)
        for label, score in sorted(summary[key].items()):
            print(f" {label:15s} [{_score_bar(score)}] {score:.4f}")
    print(f"{rule}\n")
460
+
461
+
462
def _log_to_mlflow(experiment: str, results: dict, stats: dict, sampled: list):
    """Log benchmark results to an MLflow experiment (best effort).

    Numeric entries of ``results["summary"]`` and ``stats`` are logged as
    metrics, together with the ``n_samples`` and ``dataset`` params. The
    per-difficulty and per-query-type breakdowns are logged afterwards as
    ``by_difficulty/<label>`` and ``by_query_type/<label>``.

    Never raises: a missing ``mlflow`` package or any logging failure is
    reported through the module logger as a warning.

    Args:
        experiment: MLflow experiment name.
        results: Runner output; only ``results["summary"]`` is read.
        stats: Benchmark stats dict.
        sampled: The sampled questions; only their count is logged.
    """
    try:
        import mlflow
    except ImportError:
        logger.warning("mlflow not installed — skipping benchmark logging")
        return

    try:
        summary = results["summary"]
        numeric = (int, float)

        metrics = {k: v for k, v in summary.items() if isinstance(v, numeric)}
        metrics.update({k: v for k, v in stats.items() if isinstance(v, numeric)})

        # The breakdowns are nested dicts, so the scalar filter above drops
        # them; flatten them into namespaced metric keys instead.
        breakdowns = {
            f"{group}/{label}": value
            for group in ("by_difficulty", "by_query_type")
            for label, value in (summary.get(group) or {}).items()
            if isinstance(value, numeric)
        }

        mlflow.set_experiment(experiment)
        with mlflow.start_run():
            mlflow.log_metrics(metrics)
            mlflow.log_param("n_samples", len(sampled))
            mlflow.log_param("dataset", summary.get("dataset", "unknown"))
            # Logged last: an unusual label cannot block the core metrics.
            if breakdowns:
                mlflow.log_metrics(breakdowns)
    except Exception as e:
        # Deliberate best effort: tracking must not fail the benchmark.
        logger.warning("mlflow logging failed: %s", e)