visual-rag-toolkit 0.1.1 (visual_rag_toolkit-0.1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
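
The public API that the demo code below exercises is small: `VisualEmbedder` (visual_rag/embedding/visual_embedder.py) and `MultiVectorRetriever` (visual_rag/retrieval/multi_vector.py). The sketch below only repeats the call pattern visible in demo/evaluation.py further down; the model name matches the demo's default, while the collection name, Qdrant URL, and query text are placeholders, and keyword arguments not shown here are assumed to have defaults.

# Sketch only: mirrors the call pattern in demo/evaluation.py below; the
# collection name, Qdrant URL, and query text are placeholders.
import numpy as np
import torch

from visual_rag import VisualEmbedder
from visual_rag.retrieval import MultiVectorRetriever

embedder = VisualEmbedder(
    model_name="vidore/colpali-v1.3",    # default model used by the demo
    torch_dtype=torch.float16,           # dtype used to encode queries
    output_dtype=np.float16,             # dtype of vectors sent to Qdrant
)
embedder._load_model()                   # the demo loads weights eagerly

retriever = MultiVectorRetriever(
    collection_name="my_collection",     # placeholder
    model_name="vidore/colpali-v1.3",
    qdrant_url="http://localhost:6333",  # placeholder
    qdrant_api_key=None,
    prefer_grpc=True,
    embedder=embedder,
)

query_vecs = embedder.embed_queries(["example query"], show_progress=False)
qemb = query_vecs[0]
qemb_np = qemb.detach().cpu().numpy() if isinstance(qemb, torch.Tensor) else np.asarray(qemb)
results = retriever.search_embedded(
    query_embedding=qemb_np,
    top_k=10,
    mode="single_full",                  # simplest mode; other kwargs assumed to default
)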
demo/evaluation.py ADDED
@@ -0,0 +1,602 @@
+ """Evaluation runner with UI updates."""
+
+ import hashlib
+ import importlib.util
+ import json
+ import logging
+ import time
+ import traceback
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import streamlit as st
+ import torch
+
+ from visual_rag import VisualEmbedder
+
+
+ TORCH_DTYPE_MAP = {
+     "float16": torch.float16,
+     "float32": torch.float32,
+     "bfloat16": torch.bfloat16,
+ }
+ from qdrant_client.models import Filter, FieldCondition, MatchValue
+
+ from visual_rag.retrieval import MultiVectorRetriever
+
+
+ def _load_local_benchmark_module(module_filename: str):
+     """
+     Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
+
+     Motivation:
+     - Some environments (notably containers / Spaces) can have a third-party
+       `benchmarks` package installed, causing `import benchmarks...` to resolve
+       to the wrong module.
+     - This fallback guarantees we load the repo's benchmark utilities.
+     """
+     root = Path(__file__).resolve().parents[1]  # demo/.. = repo root
+     target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
+     if not target.exists():
+         raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
+
+     name = f"_visual_rag_toolkit_local_{target.stem}"
+     spec = importlib.util.spec_from_file_location(name, str(target))
+     if spec is None or spec.loader is None:
+         raise ModuleNotFoundError(f"Could not load module spec for: {target}")
+     mod = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(mod)  # type: ignore[attr-defined]
+     return mod
+
+
+ try:
+     # Preferred: normal import
+     from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+     from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
+ except ModuleNotFoundError:
+     # Robust fallback: load from local file paths
+     _dl = _load_local_benchmark_module("dataset_loader.py")
+     _mx = _load_local_benchmark_module("metrics.py")
+     load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
+     ndcg_at_k = _mx.ndcg_at_k
+     mrr_at_k = _mx.mrr_at_k
+     recall_at_k = _mx.recall_at_k
+
+ from demo.qdrant_utils import get_qdrant_credentials
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
+
+
+ def _stable_uuid(text: str) -> str:
+     """Generate a stable UUID from text (same as benchmark script)."""
+     hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
+     return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
+
+
+ def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
+     """Generate union point ID (same as benchmark script)."""
+     ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
+     return _stable_uuid(f"{ns}::{source_doc_id}")
+
+
+ def _remap_qrels_to_union_ids(
+     qrels: Dict[str, Dict[str, int]],
+     corpus: List[Any],
+     dataset_name: str,
+     collection_name: str,
+ ) -> Dict[str, Dict[str, int]]:
+     """Remap qrels doc_ids from original format to union_doc_id format (matching benchmark)."""
+     id_map: Dict[str, str] = {}
+     for doc in corpus:
+         source_doc_id = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
+         id_map[str(doc.doc_id)] = _union_point_id(
+             dataset_name=dataset_name,
+             source_doc_id=source_doc_id,
+             union_namespace=collection_name,
+         )
+
+     remapped: Dict[str, Dict[str, int]] = {}
+     for qid, rels in qrels.items():
+         out_rels: Dict[str, int] = {}
+         for did, score in rels.items():
+             mapped = id_map.get(str(did))
+             if mapped:
+                 out_rels[mapped] = int(score)
+         if out_rels:
+             remapped[qid] = out_rels
+     return remapped
+
+
+ def get_doc_id_from_result(r: Dict[str, Any], use_original: bool = True) -> str:
+     """Extract document ID from search result.
+
+     Args:
+         r: Search result dict with 'id' and 'payload'
+         use_original: If True, prefer original doc_id for matching with qrels.
+             If False, prefer union_doc_id (Qdrant point ID).
+     """
+     payload = r.get("payload", {})
+     if use_original:
+         doc_id = (
+             payload.get("doc_id")
+             or payload.get("source_doc_id")
+             or payload.get("corpus-id")
+             or payload.get("union_doc_id")
+             or str(r.get("id", ""))
+         )
+     else:
+         doc_id = (
+             payload.get("union_doc_id")
+             or str(r.get("id", ""))
+             or payload.get("doc_id")
+         )
+     return str(doc_id)
+
+
+ def run_evaluation_with_ui(config: Dict[str, Any]):
+     st.divider()
+
+     print("=" * 60)
+     print("[EVAL] Starting evaluation via UI")
+     print("=" * 60)
+
+     url, api_key = get_qdrant_credentials()
+     if not url:
+         st.error("QDRANT_URL not configured")
+         return
+
+     datasets = config.get("datasets", [])
+     collection = config["collection"]
+     model = config.get("model", "vidore/colpali-v1.3")
+     mode = config.get("mode", "single_full")
+     top_k = config.get("top_k", 100)
+     prefetch_k = config.get("prefetch_k", 256)
+     stage1_mode = config.get("stage1_mode", "tokens_vs_tiles")
+     stage1_k = config.get("stage1_k", 1000)
+     stage2_k = config.get("stage2_k", 300)
+     prefer_grpc = config.get("prefer_grpc", True)
+     torch_dtype = config.get("torch_dtype", "float16")
+     evaluation_scope = config.get("evaluation_scope", "union")
+
+     print(f"[EVAL] ═══════════════════════════════════════════════════")
+     print(f"[EVAL] Collection: {collection}")
+     print(f"[EVAL] Model: {model}")
+     print(f"[EVAL] Mode: {mode}, Scope: {evaluation_scope}")
+     print(f"[EVAL] Datasets: {datasets}")
+     print(f"[EVAL] Query embedding dtype: {torch_dtype} (vectors already indexed)")
+     print(f"[EVAL] ═══════════════════════════════════════════════════")
+
+     phase1_container = st.container()
+     phase2_container = st.container()
+     phase3_container = st.container()
+     results_container = st.container()
+
+     try:
+         with phase1_container:
+             st.markdown("##### 🤖 Phase 1: Loading Model")
+             model_status = st.empty()
+             model_status.info(f"Loading `{model.split('/')[-1]}`...")
+
+             print(f"[EVAL] Loading embedder: {model}")
+             torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
+             qdrant_dtype = config.get("qdrant_vector_dtype", "float16")
+             output_dtype_obj = np.float16 if qdrant_dtype == "float16" else np.float32
+             embedder = VisualEmbedder(
+                 model_name=model,
+                 torch_dtype=torch_dtype_obj,
+                 output_dtype=output_dtype_obj,
+             )
+             embedder._load_model()
+             print(f"[EVAL] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_dtype})")
+
+             model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
+
+             retriever_status = st.empty()
+             retriever_status.info(f"Connecting to collection `{collection}`...")
+
+             print(f"[EVAL] Connecting to Qdrant collection: {collection}")
+             retriever = MultiVectorRetriever(
+                 collection_name=collection,
+                 model_name=model,
+                 qdrant_url=url,
+                 qdrant_api_key=api_key,
+                 prefer_grpc=prefer_grpc,
+                 embedder=embedder,
+             )
+             print(f"[EVAL] Connected to Qdrant")
+             retriever_status.success(f"✅ Connected to `{collection}`")
+
+         with phase2_container:
+             st.markdown("##### 📚 Phase 2: Loading Datasets")
+
+             dataset_data = {}
+             total_queries = 0
+             max_queries_per_ds = config.get("max_queries")
+
+             for ds_name in datasets:
+                 ds_status = st.empty()
+                 ds_short = ds_name.split("/")[-1]
+                 ds_status.info(f"Loading `{ds_short}`...")
+
+                 print(f"[EVAL] Loading dataset: {ds_name}")
+                 corpus, queries, qrels = load_vidore_beir_dataset(ds_name)
+
+                 print(f"[EVAL] Remapping qrels to union_doc_id format for collection={collection}")
+                 remapped_qrels = _remap_qrels_to_union_ids(qrels, corpus, ds_name, collection)
+                 print(f"[EVAL] Remapped {len(qrels)} -> {len(remapped_qrels)} queries with valid rels")
+
+                 if evaluation_scope == "per_dataset" and max_queries_per_ds:
+                     queries = queries[:max_queries_per_ds]
+
+                 dataset_data[ds_name] = {
+                     "queries": queries,
+                     "qrels": remapped_qrels,
+                     "num_docs": len(corpus),
+                 }
+                 total_queries += len(queries)
+                 print(f"[EVAL] Loaded {ds_name}: {len(corpus)} docs, {len(queries)} queries")
+                 ds_status.success(f"✅ `{ds_short}`: {len(corpus)} docs, {len(queries)} queries")
+
+             if evaluation_scope == "union" and max_queries_per_ds and max_queries_per_ds < total_queries:
+                 total_queries = max_queries_per_ds
+                 print(f"[EVAL] Will limit to {total_queries} total queries (union mode)")
+
+             embed_status = st.empty()
+             embed_status.info(f"Embedding queries...")
+
+         with phase3_container:
+             st.markdown("##### 🎯 Phase 3: Running Evaluation")
+
+             metrics_collectors = {
+                 "ndcg@5": [], "ndcg@10": [],
+                 "recall@5": [], "recall@10": [],
+                 "mrr@5": [], "mrr@10": [],
+             }
+             latencies = []
+             log_lines = []
+             metrics_by_dataset = {}
+
+             if evaluation_scope == "per_dataset":
+                 overall_progress = st.progress(0.0)
+                 datasets_done = 0
+
+                 for ds_name, ds_info in dataset_data.items():
+                     ds_short = ds_name.split("/")[-1]
+                     st.markdown(f"**Evaluating `{ds_short}`**")
+
+                     queries = ds_info["queries"]
+                     qrels = ds_info["qrels"]
+
+                     if not queries:
+                         continue
+
+                     print(f"[EVAL] Embedding {len(queries)} queries for {ds_short}...")
+                     query_texts = [q.text for q in queries]
+                     query_embeddings = embedder.embed_queries(query_texts, show_progress=False)
+                     print(f"[EVAL] Queries embedded for {ds_short}")
+
+                     ds_filter = Filter(
+                         must=[FieldCondition(key="dataset", match=MatchValue(value=ds_name))]
+                     )
+                     print(f"[EVAL] Using filter: dataset={ds_name}")
+
+                     progress_bar = st.progress(0.0)
+                     eval_status = st.empty()
+                     log_area = st.empty()
+
+                     ds_metrics = {"ndcg@5": [], "ndcg@10": [], "recall@5": [], "recall@10": [], "mrr@5": [], "mrr@10": []}
+                     ds_latencies = []
+                     ds_log_lines = []
+
+                     eval_status.info(f"Evaluating {len(queries)} queries...")
+                     print(f"[EVAL] Starting per-dataset evaluation: {ds_short}, {len(queries)} queries")
+
+                     for i, (q, qemb) in enumerate(zip(queries, query_embeddings)):
+                         start = time.time()
+
+                         if isinstance(qemb, torch.Tensor):
+                             qemb_np = qemb.detach().cpu().numpy()
+                         else:
+                             qemb_np = qemb.numpy() if hasattr(qemb, 'numpy') else np.array(qemb)
+
+                         results = retriever.search_embedded(
+                             query_embedding=qemb_np,
+                             top_k=max(100, top_k),
+                             mode=mode,
+                             prefetch_k=prefetch_k,
+                             stage1_mode=stage1_mode,
+                             stage1_k=stage1_k,
+                             stage2_k=stage2_k,
+                             filter_obj=ds_filter,
+                         )
+                         ds_latencies.append((time.time() - start) * 1000)
+                         latencies.append(ds_latencies[-1])
+
+                         ranking = [str(r["id"]) for r in results]
+                         rels = qrels.get(q.query_id, {})
+
+                         if i == 0:
+                             print(f"[EVAL] First query for {ds_short} - query_id: {q.query_id}")
+                             print(f"[EVAL] Top 3 retrieved doc_ids: {ranking[:3]}")
+                             print(f"[EVAL] Expected doc_ids (qrels): {list(rels.keys())[:3]}")
+                             print(f"[EVAL] qrels has {len(qrels)} queries, this query in qrels: {q.query_id in qrels}")
+                             if results:
+                                 r0 = results[0]
+                                 print(f"[EVAL] Sample result payload keys: {list(r0.get('payload', {}).keys())}")
+                                 p = r0.get("payload", {})
+                                 print(f"[EVAL] Sample payload doc_id={p.get('doc_id')}, union_doc_id={p.get('union_doc_id')}, source_doc_id={p.get('source_doc_id')}")
+                             has_match = any(rid in rels for rid in ranking[:10])
+                             print(f"[EVAL] Any match in top-10? {has_match}")
+
+                         for k_name, k_val in [("ndcg@5", 5), ("ndcg@10", 10)]:
+                             ds_metrics[k_name].append(ndcg_at_k(ranking, rels, k=k_val))
+                         for k_name, k_val in [("recall@5", 5), ("recall@10", 10)]:
+                             ds_metrics[k_name].append(recall_at_k(ranking, rels, k=k_val))
+                         for k_name, k_val in [("mrr@5", 5), ("mrr@10", 10)]:
+                             ds_metrics[k_name].append(mrr_at_k(ranking, rels, k=k_val))
+
+                         progress = (i + 1) / len(queries)
+                         progress_bar.progress(progress)
+                         eval_status.info(f"🎯 {i+1}/{len(queries)} ({int(progress*100)}%) — latency: {np.mean(ds_latencies):.0f}ms")
+
+                         log_interval = max(5, len(queries) // 10)
+                         if (i + 1) % log_interval == 0 and i > 0:
+                             cur_ndcg = np.mean(ds_metrics["ndcg@10"])
+                             cur_lat = np.mean(ds_latencies[1:]) if len(ds_latencies) > 1 else ds_latencies[0]
+                             ds_log_lines.append(f"[{i+1}/{len(queries)}] NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
+                             log_area.code("\n".join(ds_log_lines[-5:]), language="text")
+                             print(f"[EVAL] {ds_short} {i+1}/{len(queries)}: NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
+
+                     progress_bar.progress(1.0)
+                     ds_final = {k: float(np.mean(v)) for k, v in ds_metrics.items()}
+                     ds_final["avg_latency_ms"] = float(np.mean(ds_latencies))
+                     ds_final["num_queries"] = len(queries)
+                     metrics_by_dataset[ds_name] = ds_final
+
+                     for k, v in ds_metrics.items():
+                         metrics_collectors[k].extend(v)
+
+                     eval_status.success(f"✅ `{ds_short}`: NDCG@10={ds_final['ndcg@10']:.4f}, latency={ds_final['avg_latency_ms']:.0f}ms")
+                     print(f"[EVAL] {ds_short} DONE: NDCG@10={ds_final['ndcg@10']:.4f}")
+
+                     datasets_done += 1
+                     overall_progress.progress(datasets_done / len(datasets))
+
+                 overall_progress.progress(1.0)
+                 embed_status.success(f"✅ All queries embedded")
+                 total_queries = sum(d["num_queries"] for d in metrics_by_dataset.values())
+
+             else:
+                 all_queries = []
+                 all_qrels = {}
+                 for ds_name, ds_info in dataset_data.items():
+                     all_queries.extend(ds_info["queries"])
+                     for qid, rels in ds_info["qrels"].items():
+                         all_qrels[qid] = rels
+
+                 sample_qrel_keys = list(all_qrels.keys())[:3]
+                 sample_doc_ids = []
+                 for qid in sample_qrel_keys:
+                     sample_doc_ids.extend(list(all_qrels[qid].keys())[:2])
+                 print(f"[EVAL] all_qrels built: {len(all_qrels)} queries")
+                 print(f"[EVAL] Sample qrel query_ids: {sample_qrel_keys}")
+                 print(f"[EVAL] Sample qrel doc_ids: {sample_doc_ids[:5]}")
+
+                 max_q = config.get("max_queries")
+                 if max_q and max_q < len(all_queries):
+                     all_queries = all_queries[:max_q]
+                 total_queries = len(all_queries)
+
+                 print(f"[EVAL] Embedding {total_queries} queries (union mode)...")
+                 query_texts = [q.text for q in all_queries]
+                 query_embeddings = embedder.embed_queries(query_texts, show_progress=False)
+                 print(f"[EVAL] Queries embedded")
+                 embed_status.success(f"✅ {total_queries} queries embedded")
+
+                 progress_bar = st.progress(0.0)
+                 eval_status = st.empty()
+                 log_area = st.empty()
+
+                 eval_status.info(f"Evaluating {total_queries} queries in `{mode}` mode...")
+                 print(f"[EVAL] Starting union evaluation: {total_queries} queries, mode={mode}")
+
+                 for i, (q, qemb) in enumerate(zip(all_queries, query_embeddings)):
+                     start = time.time()
+
+                     if isinstance(qemb, torch.Tensor):
+                         qemb_np = qemb.detach().cpu().numpy()
+                     else:
+                         qemb_np = qemb.numpy() if hasattr(qemb, 'numpy') else np.array(qemb)
+
+                     results = retriever.search_embedded(
+                         query_embedding=qemb_np,
+                         top_k=max(100, top_k),
+                         mode=mode,
+                         prefetch_k=prefetch_k,
+                         stage1_mode=stage1_mode,
+                         stage1_k=stage1_k,
+                         stage2_k=stage2_k,
+                     )
+                     latencies.append((time.time() - start) * 1000)
+
+                     ranking = [str(r["id"]) for r in results]
+                     rels = all_qrels.get(q.query_id, {})
+
+                     if i == 0:
+                         print(f"[EVAL] First query - query_id: {q.query_id}")
+                         print(f"[EVAL] Top 3 retrieved doc_ids: {ranking[:3]}")
+                         print(f"[EVAL] Expected doc_ids (qrels): {list(rels.keys())[:3]}")
+                         print(f"[EVAL] all_qrels has {len(all_qrels)} queries, this query in qrels: {q.query_id in all_qrels}")
+                         if results:
+                             r0 = results[0]
+                             print(f"[EVAL] Sample result payload keys: {list(r0.get('payload', {}).keys())}")
+                             p = r0.get("payload", {})
+                             print(f"[EVAL] Sample payload doc_id={p.get('doc_id')}, union_doc_id={p.get('union_doc_id')}, source_doc_id={p.get('source_doc_id')}")
+                         has_match = any(rid in rels for rid in ranking[:10])
+                         print(f"[EVAL] Any match in top-10? {has_match}")
+
+                     metrics_collectors["ndcg@5"].append(ndcg_at_k(ranking, rels, k=5))
+                     metrics_collectors["ndcg@10"].append(ndcg_at_k(ranking, rels, k=10))
+                     metrics_collectors["recall@5"].append(recall_at_k(ranking, rels, k=5))
+                     metrics_collectors["recall@10"].append(recall_at_k(ranking, rels, k=10))
+                     metrics_collectors["mrr@5"].append(mrr_at_k(ranking, rels, k=5))
+                     metrics_collectors["mrr@10"].append(mrr_at_k(ranking, rels, k=10))
+
+                     progress = (i + 1) / total_queries
+                     progress_bar.progress(progress)
+                     eval_status.info(f"🎯 {i+1}/{total_queries} ({int(progress*100)}%) — latency: {np.mean(latencies):.0f}ms")
+
+                     log_interval = max(10, total_queries // 10)
+                     if (i + 1) % log_interval == 0 and i > 0:
+                         cur_ndcg = np.mean(metrics_collectors["ndcg@10"])
+                         cur_lat = np.mean(latencies[1:]) if len(latencies) > 1 else latencies[0]
+                         log_lines.append(f"[{i+1}/{total_queries}] NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
+                         log_area.code("\n".join(log_lines[-10:]), language="text")
+                         print(f"[EVAL] Progress {i+1}/{total_queries}: NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
+
+                 progress_bar.progress(1.0)
+                 eval_status.success(f"✅ Evaluation complete! ({total_queries} queries)")
+
+         with results_container:
+             st.markdown("##### 📊 Results")
+
+             p95_latency = float(np.percentile(latencies, 95))
+             eval_time_s = sum(latencies) / 1000
+             qps = total_queries / eval_time_s if eval_time_s > 0 else 0
+
+             final_metrics = {
+                 "ndcg@5": float(np.mean(metrics_collectors["ndcg@5"])),
+                 "ndcg@10": float(np.mean(metrics_collectors["ndcg@10"])),
+                 "recall@5": float(np.mean(metrics_collectors["recall@5"])),
+                 "recall@10": float(np.mean(metrics_collectors["recall@10"])),
+                 "mrr@5": float(np.mean(metrics_collectors["mrr@5"])),
+                 "mrr@10": float(np.mean(metrics_collectors["mrr@10"])),
+                 "avg_latency_ms": float(np.mean(latencies)),
+                 "p95_latency_ms": p95_latency,
+                 "qps": qps,
+                 "eval_time_s": eval_time_s,
+                 "num_queries": total_queries,
+             }
+
+             print("=" * 60)
+             print("[EVAL] FINAL RESULTS:")
+             print(f"[EVAL] NDCG@5: {final_metrics['ndcg@5']:.4f}")
+             print(f"[EVAL] NDCG@10: {final_metrics['ndcg@10']:.4f}")
+             print(f"[EVAL] Recall@5: {final_metrics['recall@5']:.4f}")
+             print(f"[EVAL] Recall@10: {final_metrics['recall@10']:.4f}")
+             print(f"[EVAL] MRR@5: {final_metrics['mrr@5']:.4f}")
+             print(f"[EVAL] MRR@10: {final_metrics['mrr@10']:.4f}")
+             print(f"[EVAL] Avg Latency: {final_metrics['avg_latency_ms']:.1f}ms")
+             print(f"[EVAL] P95 Latency: {final_metrics['p95_latency_ms']:.1f}ms")
+             print(f"[EVAL] QPS: {final_metrics['qps']:.2f}")
+             print(f"[EVAL] Queries: {final_metrics['num_queries']}")
+             print("=" * 60)
+
+             st.markdown("**Retrieval Metrics**")
+             c1, c2, c3 = st.columns(3)
+             with c1:
+                 st.metric("NDCG@5", f"{final_metrics['ndcg@5']:.4f}")
+                 st.metric("NDCG@10", f"{final_metrics['ndcg@10']:.4f}")
+             with c2:
+                 st.metric("Recall@5", f"{final_metrics['recall@5']:.4f}")
+                 st.metric("Recall@10", f"{final_metrics['recall@10']:.4f}")
+             with c3:
+                 st.metric("MRR@5", f"{final_metrics['mrr@5']:.4f}")
+                 st.metric("MRR@10", f"{final_metrics['mrr@10']:.4f}")
+
+             st.markdown("**Performance**")
+             c4, c5, c6, c7 = st.columns(4)
+             c4.metric("Avg Latency", f"{final_metrics['avg_latency_ms']:.0f}ms")
+             c5.metric("P95 Latency", f"{final_metrics['p95_latency_ms']:.0f}ms")
+             c6.metric("QPS", f"{final_metrics['qps']:.2f}")
+             c7.metric("Eval Time", f"{final_metrics['eval_time_s']:.1f}s")
+
+             with st.expander("📋 Full Results JSON"):
+                 st.json(final_metrics)
+
+             detailed_report = {
+                 "generated_at": datetime.now().isoformat(),
+                 "config": {
+                     "collection": collection,
+                     "model": model,
+                     "datasets": datasets,
+                     "mode": mode,
+                     "top_k": top_k,
+                     "evaluation_scope": config.get("evaluation_scope", "union"),
+                     "prefer_grpc": prefer_grpc,
+                     "torch_dtype": torch_dtype,
+                     "max_queries": config.get("max_queries"),
+                 },
+                 "retrieval_metrics": {
+                     "ndcg@5": final_metrics["ndcg@5"],
+                     "ndcg@10": final_metrics["ndcg@10"],
+                     "recall@5": final_metrics["recall@5"],
+                     "recall@10": final_metrics["recall@10"],
+                     "mrr@5": final_metrics["mrr@5"],
+                     "mrr@10": final_metrics["mrr@10"],
+                 },
+                 "performance": {
+                     "avg_latency_ms": final_metrics["avg_latency_ms"],
+                     "p95_latency_ms": final_metrics["p95_latency_ms"],
+                     "qps": final_metrics["qps"],
+                     "eval_time_s": final_metrics["eval_time_s"],
+                     "num_queries": final_metrics["num_queries"],
+                 },
+             }
+
+             if mode == "two_stage":
+                 detailed_report["config"]["stage1_mode"] = stage1_mode
+                 detailed_report["config"]["prefetch_k"] = prefetch_k
+             elif mode == "three_stage":
+                 detailed_report["config"]["stage1_k"] = stage1_k
+                 detailed_report["config"]["stage2_k"] = stage2_k
+
+             if evaluation_scope == "per_dataset" and metrics_by_dataset:
+                 detailed_report["metrics_by_dataset"] = metrics_by_dataset
+
+                 st.markdown("**Per-Dataset Results**")
+                 for ds_name, ds_metrics in metrics_by_dataset.items():
+                     ds_short = ds_name.split("/")[-1]
+                     with st.expander(f"📁 {ds_short}"):
+                         dc1, dc2, dc3, dc4 = st.columns(4)
+                         dc1.metric("NDCG@10", f"{ds_metrics['ndcg@10']:.4f}")
+                         dc2.metric("Recall@10", f"{ds_metrics['recall@10']:.4f}")
+                         dc3.metric("MRR@10", f"{ds_metrics['mrr@10']:.4f}")
+                         dc4.metric("Latency", f"{ds_metrics['avg_latency_ms']:.0f}ms")
+
+             report_json = json.dumps(detailed_report, indent=2)
+             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+             filename = f"eval_report__{collection}__{mode}__{timestamp}.json"
+
+             st.download_button(
+                 label="📥 Download Detailed Report",
+                 data=report_json,
+                 file_name=filename,
+                 mime="application/json",
+                 use_container_width=True,
+             )
+
+         st.session_state["last_eval_metrics"] = final_metrics
+
+     except Exception as e:
+         error_msg = str(e)
+
+         if "not configured in this collection" in error_msg:
+             vector_name = error_msg.split("name ")[-1].split(" is")[0] if "name " in error_msg else "unknown"
+             st.error(f"❌ **Collection Mismatch**: Vector `{vector_name}` not found in collection `{collection}`")
+             st.warning(f"""
+ **The selected mode `{mode}` requires vectors that don't exist in this collection.**
+
+ **Suggestions:**
+ - Try `single_full` mode (works with basic collections)
+ - Use a collection indexed with the Visual RAG Toolkit
+ - Check that the collection has the required vector types for `{mode}` mode
+             """)
+         else:
+             st.error(f"❌ Error: {e}")
+
+         with st.expander("🔍 Full Error Details"):
+             st.code(traceback.format_exc(), language="text")
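
For orientation, run_evaluation_with_ui is driven by a single config dict; the keys below are exactly the ones the function reads, shown with their in-code defaults, while the collection and dataset values are placeholders. Because the function renders Streamlit widgets, this sketch assumes it is called from a page executed with streamlit run.

# Hypothetical invocation from a Streamlit page; dataset and collection values
# are placeholders, the keys and defaults come from the function body above.
from demo.evaluation import run_evaluation_with_ui

config = {
    "collection": "my_union_collection",      # Qdrant collection to evaluate
    "datasets": ["org/some-vidore-dataset"],  # placeholder BEIR-style dataset ids
    "model": "vidore/colpali-v1.3",
    "mode": "single_full",                    # or "two_stage" / "three_stage"
    "top_k": 100,
    "prefetch_k": 256,
    "stage1_mode": "tokens_vs_tiles",
    "stage1_k": 1000,
    "stage2_k": 300,
    "prefer_grpc": True,
    "torch_dtype": "float16",
    "qdrant_vector_dtype": "float16",
    "evaluation_scope": "union",              # or "per_dataset"
    "max_queries": 50,                        # optional cap for quick runs
}

run_evaluation_with_ui(config)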
demo/example_metadata_mapping_sigir.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "filenames": {
+     "sigir2025-llms": {
+       "year": 2025,
+       "source": "Conference Paper",
+       "district": null,
+       "doc_type": "paper",
+       "project": "sigir-demo",
+       "tags": ["llms", "retrieval"]
+     },
+     "sigir2025-ginger": {
+       "year": 2025,
+       "source": "Conference Paper",
+       "district": null,
+       "doc_type": "paper",
+       "project": "sigir-demo",
+       "tags": ["ginger", "case-study"]
+     },
+     "2505.15859v1": {
+       "year": 2025,
+       "source": "arXiv",
+       "district": null,
+       "doc_type": "preprint",
+       "project": "sigir-demo",
+       "tags": ["arxiv", "ranking"]
+     },
+     "2507.04942v2": {
+       "year": 2025,
+       "source": "arXiv",
+       "district": null,
+       "doc_type": "preprint",
+       "project": "sigir-demo",
+       "tags": ["arxiv", "rag"]
+     }
+   }
+ }
+
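
This mapping keys document metadata by PDF file stem. The indexing code that consumes it (demo/indexing.py, demo/ui/upload.py) is not part of this excerpt, so the snippet below is only a guess at the obvious lookup pattern; both helper names are hypothetical.

# Hypothetical consumer of this mapping; helper names are illustrative only.
import json
from pathlib import Path


def load_metadata_mapping(path: str) -> dict:
    """Return the file-stem -> metadata dict stored under "filenames"."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)["filenames"]


def metadata_for_pdf(mapping: dict, pdf_path: str) -> dict:
    """Look up extra payload fields for a PDF by its file stem (empty if unmapped)."""
    return mapping.get(Path(pdf_path).stem, {})


mapping = load_metadata_mapping("demo/example_metadata_mapping_sigir.json")
extra = metadata_for_pdf(mapping, "uploads/2505.15859v1.pdf")
# e.g. merge into each page's payload before upserting points:
# payload = {**page_payload, **extra}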