schema-search 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- schema_search/__init__.py +26 -0
- schema_search/chunkers/__init__.py +6 -0
- schema_search/chunkers/base.py +95 -0
- schema_search/chunkers/factory.py +31 -0
- schema_search/chunkers/llm.py +54 -0
- schema_search/chunkers/markdown.py +25 -0
- schema_search/embedding_cache/__init__.py +5 -0
- schema_search/embedding_cache/base.py +40 -0
- schema_search/embedding_cache/bm25.py +63 -0
- schema_search/embedding_cache/factory.py +20 -0
- schema_search/embedding_cache/inmemory.py +122 -0
- schema_search/graph_builder.py +69 -0
- schema_search/mcp_server.py +81 -0
- schema_search/metrics.py +33 -0
- schema_search/rankers/__init__.py +5 -0
- schema_search/rankers/base.py +45 -0
- schema_search/rankers/cross_encoder.py +40 -0
- schema_search/rankers/factory.py +11 -0
- schema_search/schema_extractor.py +135 -0
- schema_search/schema_search.py +276 -0
- schema_search/search/__init__.py +15 -0
- schema_search/search/base.py +85 -0
- schema_search/search/bm25.py +48 -0
- schema_search/search/factory.py +61 -0
- schema_search/search/fuzzy.py +56 -0
- schema_search/search/hybrid.py +82 -0
- schema_search/search/semantic.py +49 -0
- schema_search/types.py +57 -0
- schema_search/utils/__init__.py +0 -0
- schema_search/utils/lazy_import.py +26 -0
- schema_search-0.1.10.dist-info/METADATA +308 -0
- schema_search-0.1.10.dist-info/RECORD +40 -0
- schema_search-0.1.10.dist-info/WHEEL +5 -0
- schema_search-0.1.10.dist-info/entry_points.txt +2 -0
- schema_search-0.1.10.dist-info/licenses/LICENSE +21 -0
- schema_search-0.1.10.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_integration.py +352 -0
- tests/test_llm_sql_generation.py +320 -0
- tests/test_spider_eval.py +488 -0
|
@@ -0,0 +1,488 @@
|
|
|
1
|
+
"""Spider benchmark evaluation.
|
|
2
|
+
|
|
3
|
+
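Warning: this test intentionally creates and drops PostgreSQL databases
named after Spider db_ids, and before it starts it also drops every database
whose name consists only of lowercase letters, digits, and underscores (other
than the built-in postgres/template databases and the admin database named in
DATABASE_URL). Only run it against an isolated server you are comfortable
wiping.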
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
import time
|
|
10
|
+
import json
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from typing import List, Set, Dict, Any
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pytest
|
|
17
|
+
from datasets import load_dataset
|
|
18
|
+
from sqlalchemy import create_engine, text
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
from schema_search import SchemaSearch
|
|
21
|
+
from schema_search.types import SearchType
|
|
22
|
+
|
|
23
|
+
# WARNING: keep this URL hard-coded instead of reading it from an environment
# variable. The benchmark creates and drops databases, so it must only ever
# target an explicit, throwaway Postgres instance.
DATABASE_URL: str = "postgresql://user:pass@localhost/db"

# Spider databases whose schema has fewer tables than this are skipped.
MIN_TABLES: int = 10
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def parse_spider_schema(schema_str: str) -> Dict[str, List[tuple]]:
    """Turn a spider-schema string into ``{table_name: [(column, sql_type), ...]}``.

    The input looks like ``"t1 : c1 (number) , c2 (text) | t2 : ..."``:
    tables are separated by ``" | "``, a table name is separated from its
    columns by the first ``" : "``, and columns are separated by ``" , "``.
    Columns typed ``number`` become ``INTEGER``; every other type becomes
    ``TEXT``. Segments without ``" : "``, columns not of the form
    ``name (type)``, and tables left with no parsable columns are dropped.
    """
    column_re = re.compile(r"(.+?)\s*\((\w+)\)")
    parsed: Dict[str, List[tuple]] = {}

    for segment in schema_str.split(" | "):
        name_part, sep, column_part = segment.partition(" : ")
        if not sep:
            continue

        column_defs = []
        for raw_column in column_part.split(" , "):
            hit = column_re.match(raw_column.strip())
            if hit is None:
                continue
            spider_type = hit.group(2).strip()
            sql_type = "INTEGER" if spider_type == "number" else "TEXT"
            column_defs.append((hit.group(1).strip(), sql_type))

        if column_defs:
            parsed[name_part.strip()] = column_defs

    return parsed
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def extract_tables_from_sql(sql: str) -> Set[str]:
    """Extract table names from a SQL query using regexes.

    Handles aliases (``t AS t1`` / ``t t1``), quoted names, schema-qualified
    names (``schema.table``), and comma-separated joins in the first FROM
    clause. Matching is done on the lowercased query, so names are returned
    lowercased.

    Args:
        sql: SQL text to scan.

    Returns:
        Set of lowercased table names found after FROM / JOIN keywords.
    """
    sql_lower = sql.lower()

    def clean_table_name(name: str) -> str:
        """Drop the alias and schema prefix, then strip quotes/brackets."""
        parts = name.split()
        if not parts:
            return ""
        # Remove the alias first so a trailing quote on the name itself
        # (e.g. "`singer` as s") is still stripped below.
        name = parts[0]
        name = name.split(".")[-1]
        return name.strip('`"[]')

    tables = set()

    from_pattern = r"\bfrom\s+([`\"[]?[\w]+[`\"\]]?(?:\.[\w]+)?(?:\s+(?:as\s+)?\w+)?)"
    join_pattern = r"\bjoin\s+([`\"[]?[\w]+[`\"\]]?(?:\.[\w]+)?(?:\s+(?:as\s+)?\w+)?)"

    for pattern in (from_pattern, join_pattern):
        for match in re.finditer(pattern, sql_lower):
            table = clean_table_name(match.group(1))
            if table:
                tables.add(table)

    # Capture the first FROM clause to pick up comma-separated joins. The
    # capture must stop at anything that can end a table list: without the
    # set operators, HAVING, a nested SELECT, or a closing paren as
    # terminators it runs into the next query and the comma split picks up
    # column names (e.g. "FROM a UNION SELECT x, y FROM b" -> "y").
    from_clause_match = re.search(
        r"\bfrom\s+([^;]+?)(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b"
        r"|\bhaving\b|\bunion\b|\bintersect\b|\bexcept\b|\bselect\b|\)|$)",
        sql_lower,
    )
    if from_clause_match:
        from_clause = from_clause_match.group(1)
        for table_expr in re.split(r"\s*,\s*", from_clause):
            table = clean_table_name(table_expr)
            # Only accept plain identifiers; fragments like "(" are discarded.
            if table and re.match(r"^[a-zA-Z_][\w]*$", table):
                tables.add(table)

    return tables
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def create_database_from_schema(db_id: str, schema_str: str):
    """Create a fresh PostgreSQL database from a spider-schema string.

    Any existing database with the same normalized name is dropped first,
    after terminating other sessions connected to it. Tables are then created
    (empty) from ``parse_spider_schema(schema_str)``.

    Args:
        db_id: Spider database id. It is lowercased and ``-`` is replaced by
            ``_`` to form the PostgreSQL database name.
        schema_str: spider-schema definition string.

    Returns:
        tuple: (db_engine, num_tables). ``db_engine`` is bound to the new
        database; ``num_tables`` is the number of tables created.

    Raises:
        RuntimeError: if ``DATABASE_URL`` does not point at localhost.
    """
    from sqlalchemy.engine.url import make_url

    def quote_ident(name: str) -> str:
        # PostgreSQL identifier quoting: wrap in double quotes, double any inner quote.
        return '"' + name.replace('"', '""') + '"'

    safe_db_id = db_id.lower().replace("-", "_")

    # Safety guard: this function drops databases. Check the parsed host (a
    # substring test would accept e.g. a user named "localhost") and raise
    # explicitly so the guard survives `python -O`, which strips asserts.
    admin_url = make_url(DATABASE_URL)
    if admin_url.host != "localhost":
        raise RuntimeError("DATABASE_URL must be a test database on localhost")

    admin_engine = create_engine(
        admin_url, isolation_level="AUTOCOMMIT", pool_pre_ping=True, pool_recycle=60
    )
    try:
        with admin_engine.connect() as conn:
            conn.execute(
                text(
                    "SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
                    "WHERE datname = :name AND pid != pg_backend_pid()"
                ),
                {"name": safe_db_id},
            )
            conn.execute(text(f"DROP DATABASE IF EXISTS {quote_ident(safe_db_id)}"))
            conn.execute(text(f"CREATE DATABASE {quote_ident(safe_db_id)}"))
    finally:
        admin_engine.dispose()

    # Reuse the admin URL's driver, credentials, host and port; swap only the
    # database name, so changing DATABASE_URL is enough to retarget everything.
    db_engine = create_engine(
        admin_url.set(database=safe_db_id),
        pool_pre_ping=True,
        pool_recycle=60,
    )
    tables = parse_spider_schema(schema_str)

    try:
        with db_engine.connect() as conn:
            for table_name, columns in tables.items():
                col_defs = ", ".join(
                    f"{quote_ident(col_name)} {col_type}" for col_name, col_type in columns
                )
                conn.execute(text(f"CREATE TABLE {quote_ident(table_name)} ({col_defs})"))
            conn.commit()
    except Exception:
        # Don't leak the pool when schema creation fails; the caller never gets it.
        db_engine.dispose()
        raise

    return db_engine, len(tables)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def calculate_recall_at_k(
    predicted: List[str], ground_truth: Set[str], k: int
) -> float:
    """Return Recall@k, the share of ground-truth items found in the top-k predictions.

    Comparison is case-insensitive and duplicates collapse. Example: with
    ground_truth={A,B,C} and predicted=[A,...], Recall@1 = 1/3 = 0.33.
    An empty ground truth yields 0.0.
    """
    relevant = {item.lower() for item in ground_truth}
    if not relevant:
        return 0.0

    top_k = {item.lower() for item in predicted[:k]}
    return len(top_k & relevant) / len(relevant)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def calculate_precision_at_k(
    predicted: List[str], ground_truth: Set[str], k: int
) -> float:
    """Return Precision@k, the number of relevant items in the top-k predictions divided by k.

    Matching is case-insensitive and duplicate predictions count once. The
    denominator is always k, even when fewer than k predictions exist, and
    k == 0 yields 0.0. Example: ground_truth={A,B,C}, predicted=[A,...] gives
    Precision@1 = 1/1 = 1.0.
    """
    if k == 0:
        return 0.0

    top_k = {item.lower() for item in predicted[:k]}
    hits = top_k.intersection(item.lower() for item in ground_truth)
    return len(hits) / k
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def calculate_mrr(predicted: List[str], ground_truth: Set[str]) -> float:
    """Return the reciprocal rank (1/rank) of the first relevant prediction.

    Ranks are 1-based and matching is case-insensitive. Returns 0.0 when the
    ground truth is empty or no prediction is relevant. Averaging this value
    over queries gives the Mean Reciprocal Rank.
    """
    if not ground_truth:
        return 0.0

    wanted = {item.lower() for item in ground_truth}
    first_rank = next(
        (rank for rank, item in enumerate(predicted, start=1) if item.lower() in wanted),
        None,
    )
    return 0.0 if first_rank is None else 1.0 / first_rank
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def save_benchmark_results(results_by_strategy, index_latencies, strategies):
    """Save aggregated benchmark results as JSON.

    Reads ``config.yml`` from two directories above this file to record
    whether a reranker model is configured. Then writes a summary to
    ``img/spider_benchmark_with_reranker.json`` or
    ``img/spider_benchmark_without_reranker.json`` under the same directory.

    Args:
        results_by_strategy: Mapping of strategy name to a dict holding
            ``num_queries`` and per-query metric value lists.
        index_latencies: Per-database indexing times in seconds.
        strategies: Strategy names to include. Strategies with zero queries
            are skipped.
    """
    import yaml

    repo_root = Path(__file__).parent.parent

    with open(repo_root / "config.yml", encoding="utf-8") as f:
        # safe_load returns None for an empty file; treat that as an empty config.
        config = yaml.safe_load(f) or {}

    # A present-but-null ``reranker:`` section means no reranker is configured
    # (previously this raised AttributeError after the whole benchmark ran).
    reranker_model = (config.get("reranker") or {}).get("model")
    has_reranker = reranker_model is not None

    img_dir = repo_root / "img"
    img_dir.mkdir(exist_ok=True)

    output = {
        "reranker_enabled": has_reranker,
        "reranker_model": reranker_model,
        "indexing": {},
        "strategies": {},
    }

    if index_latencies:
        output["indexing"] = {
            "num_databases": len(index_latencies),
            "mean_latency": float(np.mean(index_latencies)),
            "std_latency": float(np.std(index_latencies)),
        }

    metric_names = [
        "recall_at_1",
        "recall_at_3",
        "recall_at_5",
        "mrr",
        "precision_at_1",
        "precision_at_3",
        "precision_at_5",
        "latency",
    ]

    for strategy in strategies:
        stats = results_by_strategy[strategy]
        if stats["num_queries"] == 0:
            continue

        strategy_results = {"num_queries": stats["num_queries"]}
        for metric_name in metric_names:
            if metric_name in stats and stats[metric_name]:
                values = np.array(stats[metric_name])
                strategy_results[metric_name] = {
                    "mean": float(np.mean(values)),
                    "std": float(np.std(values)),
                }

        output["strategies"][strategy] = strategy_results

    output_filename = (
        "spider_benchmark_with_reranker.json"
        if has_reranker
        else "spider_benchmark_without_reranker.json"
    )
    output_path = img_dir / output_filename

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print(f"\nBenchmark results saved to: {output_path}")
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def cleanup_spider_databases():
    """Drop leftover benchmark databases and clear the schema-search cache.

    Connects with ``DATABASE_URL`` and drops every database whose name
    consists only of lowercase letters, digits, and underscores, except
    ``postgres``, ``template0``, ``template1``, ``db``, and the database
    ``DATABASE_URL`` itself connects to. NOTE: this is a name-shape filter,
    not a check against actual Spider db_ids, so any database of that shape
    on the server is dropped. Afterwards ``/tmp/.schema_search_cache`` is
    removed if it exists.
    """
    import shutil

    admin_engine = create_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
    # Never drop (or kill sessions on) the database the admin connection uses.
    admin_db = admin_engine.url.database

    try:
        with admin_engine.connect() as conn:
            result = conn.execute(
                text(
                    "SELECT datname FROM pg_database WHERE datname NOT IN ('postgres', 'template0', 'template1', 'db')"
                )
            )
            databases = [row[0] for row in result]

            for db_name in databases:
                if db_name == admin_db or not re.fullmatch(r"[a-z0-9_]+", db_name):
                    continue
                conn.execute(
                    text(
                        "SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
                        "WHERE datname = :name AND pid != pg_backend_pid()"
                    ),
                    {"name": db_name},
                )
                # Quoted so names starting with a digit are still valid; the
                # name is already restricted to [a-z0-9_], so quoting is safe.
                conn.execute(text(f'DROP DATABASE IF EXISTS "{db_name}"'))
    finally:
        admin_engine.dispose()

    cache_path = Path("/tmp/.schema_search_cache")
    if cache_path.exists():
        shutil.rmtree(cache_path)
        print(f"Cleared cache: {cache_path}")
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@pytest.fixture(scope="module")
def spider_data():
    """Provide ``(spider_train, schema_by_db_id)`` for the benchmark.

    Leftover benchmark databases and the search cache are dropped first. Then
    the Spider training split (questions + gold SQL) is loaded, along with the
    ``richardr1126/spider-schema`` rows keyed by their ``db_id``.
    """
    cleanup_spider_databases()

    questions = load_dataset("spider", split="train")
    schema_rows = load_dataset("richardr1126/spider-schema", split="train")

    schema_by_db = {}
    for row in schema_rows:
        schema_by_db[row["db_id"]] = row

    return questions, schema_by_db
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def test_spider_evaluation(spider_data):
    """Evaluate schema search on the Spider benchmark.

    Tests table-level retrieval accuracy across multiple databases.
    Metrics: Recall@1, Recall@3, Recall@5, MRR, Precision@k, Latency

    Examples are processed in dataset order. Whenever the ``db_id`` changes,
    the previous engine is disposed. Databases whose spider-schema has fewer
    than ``MIN_TABLES`` tables are skipped. Otherwise a PostgreSQL database is
    (re)created, indexed once, and every strategy is queried with each
    example's question. Ground-truth tables are extracted from the gold SQL.
    """
    spider, schema_map = spider_data

    results_by_strategy = defaultdict(
        lambda: {
            "recall_at_1": [],
            "recall_at_3": [],
            "recall_at_5": [],
            "mrr": [],
            "precision_at_1": [],
            "precision_at_3": [],
            "precision_at_5": [],
            "latency": [],
            "num_queries": 0,
        }
    )

    strategies: List[SearchType] = ["hybrid", "fuzzy", "bm25", "semantic"]

    current_db_id = None
    search_engine = None
    db_engine = None
    index_latencies = []
    num_dbs_filtered = 0
    total_queries_evaluated = 0

    try:
        for example in tqdm(spider, desc="Evaluating Spider"):
            db_id = example["db_id"]
            question = example["question"]
            sql = example["query"]

            if db_id not in schema_map:
                continue

            if db_id != current_db_id:
                # Switching databases: release the previous connection pool first.
                if db_engine is not None:
                    db_engine.dispose()
                    db_engine = None
                search_engine = None
                current_db_id = db_id

                schema_str = schema_map[db_id]["Schema (values (type))"]

                # Count tables from the parsed schema so small databases are
                # skipped without a CREATE/DROP round-trip on the server.
                if len(parse_spider_schema(schema_str)) < MIN_TABLES:
                    num_dbs_filtered += 1
                    continue

                db_engine, _ = create_database_from_schema(db_id, schema_str)
                search_engine = SchemaSearch(db_engine)

                # perf_counter is monotonic and high-resolution, unlike time.time.
                index_start = time.perf_counter()
                search_engine.index(force=False)
                index_latencies.append(time.perf_counter() - index_start)

            if search_engine is None:
                continue

            ground_truth_tables = extract_tables_from_sql(sql)
            if not ground_truth_tables:
                continue

            for strategy in strategies:
                _record_strategy_metrics(
                    results_by_strategy[strategy],
                    search_engine,
                    strategy,
                    question,
                    ground_truth_tables,
                )

            total_queries_evaluated += 1
    finally:
        # Release the last database's pool even if evaluation raised.
        if db_engine is not None:
            db_engine.dispose()

    print(f"\nFiltered out {num_dbs_filtered} databases with <{MIN_TABLES} tables")
    print(
        f"Evaluated {total_queries_evaluated} queries on {len(index_latencies)} databases"
    )

    _print_final_results(results_by_strategy, index_latencies, strategies)

    # Persist before asserting so a failed sanity check does not discard a long run.
    save_benchmark_results(results_by_strategy, index_latencies, strategies)

    assert total_queries_evaluated > 0, "No Spider queries were evaluated"
    for strategy in strategies:
        stats = results_by_strategy[strategy]
        if stats["num_queries"] == 0:
            continue
        mean_recall_1 = float(np.mean(stats["recall_at_1"]))
        mean_mrr = float(np.mean(stats["mrr"]))
        assert 0.0 <= mean_recall_1 <= 1.0, f"{strategy}: Invalid recall@1"
        assert 0.0 <= mean_mrr <= 1.0, f"{strategy}: Invalid MRR"


def _record_strategy_metrics(stats, search_engine, strategy, question, ground_truth_tables):
    """Run one search for ``question`` with ``strategy`` and append its metrics to ``stats``."""
    start_time = time.perf_counter()
    response = search_engine.search(
        question, search_type=strategy, limit=10, hops=1
    )
    latency = time.perf_counter() - start_time

    predicted_tables = [r["table"] for r in response["results"]]

    for k in (1, 3, 5):
        stats[f"recall_at_{k}"].append(
            calculate_recall_at_k(predicted_tables, ground_truth_tables, k)
        )
        stats[f"precision_at_{k}"].append(
            calculate_precision_at_k(predicted_tables, ground_truth_tables, k)
        )

    stats["mrr"].append(calculate_mrr(predicted_tables, ground_truth_tables))
    stats["latency"].append(latency)
    stats["num_queries"] += 1


def _print_final_results(results_by_strategy, index_latencies, strategies):
    """Print indexing latency and per-strategy metric mean ± std to stdout."""
    print(f"\n{'='*80}")
    print("FINAL RESULTS")
    print(f"{'='*80}")

    if index_latencies:
        mean_index = np.mean(index_latencies)
        std_index = np.std(index_latencies)
        print("\nINDEXING")
        print(f" Databases indexed: {len(index_latencies)}")
        print(f" Index latency: {mean_index:.3f}s ± {std_index:.3f}s")

    metric_labels = {
        "recall_at_1": "Recall@1",
        "recall_at_3": "Recall@3",
        "recall_at_5": "Recall@5",
        "mrr": "MRR",
        "precision_at_1": "Precision@1",
        "precision_at_3": "Precision@3",
        "precision_at_5": "Precision@5",
        "latency": "Latency",
    }

    for strategy in strategies:
        stats = results_by_strategy[strategy]
        if stats["num_queries"] == 0:
            continue

        print(f"\n{strategy.upper()}")
        print(f" Queries: {stats['num_queries']}")

        # dict preserves insertion order, so this iterates in display order.
        for metric_name, label in metric_labels.items():
            if metric_name in stats and stats[metric_name]:
                values = np.array(stats[metric_name])
                mean = np.mean(values)
                std = np.std(values)
                if metric_name == "latency":
                    print(f" {label}: {mean:.3f}s ± {std:.3f}s")
                else:
                    print(f" {label}: {mean:.3f} ± {std:.3f}")
|