visual-rag-toolkit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,1365 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
13
+ from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
14
+ from visual_rag import VisualEmbedder
15
+ from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
16
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
17
+ from visual_rag.retrieval import MultiVectorRetriever
18
+
19
+
20
+ def _maybe_load_dotenv() -> None:
21
+ try:
22
+ from dotenv import load_dotenv
23
+ except ImportError:
24
+ return
25
+ if Path(".env").exists():
26
+ load_dotenv(".env")
27
+
28
+
29
+ def _torch_dtype_to_str(dtype) -> str:
30
+ if dtype is None:
31
+ return "auto"
32
+ s = str(dtype)
33
+ return s.replace("torch.", "")
34
+
35
+
36
+ def _parse_torch_dtype(dtype_str: str):
37
+ if dtype_str == "auto":
38
+ return None
39
+ import torch
40
+
41
+ mapping = {
42
+ "float32": torch.float32,
43
+ "float16": torch.float16,
44
+ "bfloat16": torch.bfloat16,
45
+ }
46
+ return mapping[dtype_str]
47
+
48
+
49
+ def _stable_uuid(text: str) -> str:
50
+ import hashlib
51
+
52
+ hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
53
+ return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
54
+
55
+
56
+ def _sample_list(items: List[Any], *, k: int, strategy: str, seed: int) -> List[Any]:
57
+ if not k or k <= 0:
58
+ return items
59
+ if k >= len(items):
60
+ return items
61
+ if strategy == "first":
62
+ return items[:k]
63
+ if strategy == "random":
64
+ import random
65
+
66
+ rng = random.Random(int(seed))
67
+ indices = rng.sample(range(len(items)), k)
68
+ return [items[i] for i in indices]
69
+ raise ValueError("sample strategy must be 'first' or 'random'")
70
+
71
+
72
+ def _parse_payload_indexes(values: List[str]) -> List[Dict[str, str]]:
73
+ indexes: List[Dict[str, str]] = []
74
+ for raw in values or []:
75
+ if ":" not in raw:
76
+ raise ValueError("payload index must be in field:type format")
77
+ field, type_str = raw.split(":", 1)
78
+ field = field.strip()
79
+ type_str = type_str.strip()
80
+ if not field or not type_str:
81
+ raise ValueError("payload index must be in field:type format")
82
+ indexes.append({"field": field, "type": type_str})
83
+ return indexes
84
+
85
+
86
def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
    """Build the stable Qdrant point ID for a document in a union collection.

    An optional *union_namespace* prefixes the dataset name so that several
    benchmark runs can share one collection without ID collisions.
    """
    if union_namespace:
        scope = f"{union_namespace}::{dataset_name}"
    else:
        scope = dataset_name
    return _stable_uuid(f"{scope}::{source_doc_id}")
89
+
90
+
91
+ def _filter_qrels(qrels: Dict[str, Dict[str, int]], query_ids: List[str]) -> Dict[str, Dict[str, int]]:
92
+ keep = set(query_ids)
93
+ return {qid: rels for qid, rels in qrels.items() if qid in keep}
94
+
95
def _failed_log_path(*, collection_name: str, dataset_name: str) -> Path:
    """Location of the JSONL log recording documents that failed to index."""
    base = Path("results") / _safe_filename(collection_name)
    return base / f"index_failures__{_safe_filename(dataset_name)}.jsonl"
98
+
99
+
100
def _resolve_output_path(raw_output: str, *, collection_name: str) -> Path:
    """Decide where a report file should actually be written.

    Rules, in order:
      - paths under the legacy ``results/reports/`` prefix are redirected
        into ``results/{collection_name}/`` (keeping the filename)
      - absolute paths are respected as-is
      - bare filenames land in ``results/{collection_name}/``
      - any other relative directory is respected as-is
    """
    candidate = Path(str(raw_output))
    collection_dir = Path("results") / _safe_filename(collection_name)

    if str(candidate).startswith("results/reports/"):
        return collection_dir / candidate.name
    if candidate.is_absolute():
        return candidate
    if candidate.parent == Path("."):
        return collection_dir / candidate.name
    return candidate
117
+
118
+
119
def _default_output_filename(*, args, datasets: List[str]) -> str:
    """Compose a descriptive report filename from the CLI configuration.

    The name encodes the model, retrieval mode (with its mode-specific
    knobs), top-k, evaluation scope, dataset count and crop-empty settings,
    joined by ``__`` and suffixed with ``.json``.
    """
    mode = str(args.mode)
    tags = [
        _safe_filename(str(args.model).split("/")[-1]),
        _safe_filename(mode),
    ]
    if mode == "two_stage":
        tags.append(_safe_filename(str(args.stage1_mode)))
        tags.append(f"pk{int(args.prefetch_k)}")
    if mode == "three_stage":
        tags += [
            "tokens_vs_global",
            f"s1k{int(args.stage1_k)}",
            "tokens_vs_experimental",
            f"s2k{int(args.stage2_k)}",
        ]
    tags += [
        f"top{int(args.top_k)}",
        _safe_filename(str(args.evaluation_scope)),
        f"{len(datasets)}ds",
    ]

    if bool(args.crop_empty):
        pct = int(round(float(args.crop_empty_percentage_to_remove) * 100))
        tags.append(f"crop{pct}")

    return "__".join([t for t in tags if t]) + ".json"
143
+
144
+
145
+ def _append_jsonl(path: Path, obj: Dict[str, Any]) -> None:
146
+ path.parent.mkdir(parents=True, exist_ok=True)
147
+ with path.open("a") as f:
148
+ f.write(json.dumps(obj, ensure_ascii=False) + "\n")
149
+
150
+
151
+ def _load_failed_ids(path: Path) -> set:
152
+ if not path.exists():
153
+ return set()
154
+ ids = set()
155
+ with path.open("r") as f:
156
+ for line in f:
157
+ line = (line or "").strip()
158
+ if not line:
159
+ continue
160
+ try:
161
+ obj = json.loads(line)
162
+ except Exception:
163
+ continue
164
+ for key in ("union_doc_id", "doc_id"):
165
+ v = obj.get(key)
166
+ if v:
167
+ ids.add(str(v))
168
+ return ids
169
+
170
+
171
def _load_failed_union_ids(
    path: Path,
    *,
    dataset_name: str,
    union_namespace: Optional[str],
) -> set:
    """Collect union point IDs for previously failed documents.

    Older logs may contain a ``union_doc_id`` computed without the current
    *union_namespace*, so each recorded ``source_doc_id`` is re-hashed with
    the current namespace in addition to keeping any logged
    ``union_doc_id`` verbatim.  This keeps retries consistent with the IDs
    actually stored in Qdrant.
    """
    collected: set = set()
    if not path.exists():
        return collected
    with path.open("r") as fh:
        for raw_line in fh:
            stripped = (raw_line or "").strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except Exception:
                continue
            source_id = record.get("source_doc_id")
            if source_id:
                recomputed = _union_point_id(
                    dataset_name=str(dataset_name),
                    source_doc_id=str(source_id),
                    union_namespace=union_namespace,
                )
                collected.add(recomputed)
            logged_union = record.get("union_doc_id")
            if logged_union:
                collected.add(str(logged_union))
    return collected
208
+
209
+
210
+ def _remove_failed_from_qrels(qrels: Dict[str, Dict[str, int]], failed_ids: set) -> Tuple[Dict[str, Dict[str, int]], int]:
211
+ removed = 0
212
+ if not failed_ids:
213
+ return qrels, 0
214
+ out: Dict[str, Dict[str, int]] = {}
215
+ for qid, rels in (qrels or {}).items():
216
+ new_rels: Dict[str, int] = {}
217
+ for did, score in (rels or {}).items():
218
+ if str(did) in failed_ids:
219
+ removed += 1
220
+ continue
221
+ new_rels[str(did)] = int(score)
222
+ out[str(qid)] = new_rels
223
+ return out, removed
224
+
225
+
226
def _evaluate(
    *,
    queries,
    qrels: Dict[str, Dict[str, int]],
    retriever: MultiVectorRetriever,
    embedder: VisualEmbedder,
    top_k: int,
    prefetch_k: int,
    mode: str,
    stage1_mode: str,
    stage1_k: int,
    stage2_k: int,
    max_queries: int,
    drop_empty_queries: bool,
    filter_obj=None,
) -> Dict[str, float]:
    """Run retrieval for every query and aggregate ranking + latency metrics.

    Query texts are embedded up front in one call, then each query is
    searched individually so per-query latency can be measured.  Returns a
    flat dict with nDCG/MRR/recall at cutoffs 1/5/10/100 plus latency,
    wall-time, and QPS figures.

    NOTE(review): items in *queries* are assumed to expose ``.query_id``
    and ``.text`` — confirm against the dataset loader's query type.
    """
    eval_started_at = time.time()
    # Optionally restrict to queries that have at least one positive judgement.
    if drop_empty_queries:
        queries = [q for q in queries if any(v > 0 for v in qrels.get(q.query_id, {}).values())]
    if max_queries and max_queries > 0:
        queries = queries[:max_queries]
    if not queries:
        # Keep the result schema stable even when there is nothing to evaluate.
        return {
            "ndcg@1": 0.0,
            "ndcg@5": 0.0,
            "ndcg@10": 0.0,
            "ndcg@100": 0.0,
            "mrr@1": 0.0,
            "mrr@5": 0.0,
            "mrr@10": 0.0,
            "mrr@100": 0.0,
            "recall@1": 0.0,
            "recall@5": 0.0,
            "recall@10": 0.0,
            "recall@100": 0.0,
            "avg_latency_ms": 0.0,
            "p95_latency_ms": 0.0,
            "eval_wall_time_s": 0.0,
            "eval_search_time_s": 0.0,
            "qps": 0.0,
            "num_queries_eval": 0,
        }

    # Per-query metric accumulators; averaged at the end.
    ndcg1: List[float] = []
    ndcg5: List[float] = []
    ndcg10: List[float] = []
    ndcg100: List[float] = []
    mrr1: List[float] = []
    mrr5: List[float] = []
    mrr10: List[float] = []
    mrr100: List[float] = []
    recall1: List[float] = []
    recall5: List[float] = []
    recall10: List[float] = []
    recall100: List[float] = []
    latencies_ms: List[float] = []

    # Always retrieve at least 100 results so the @100 metrics are meaningful.
    retrieve_k = max(100, top_k)

    # Embed all query texts before the timed search loop.
    query_texts = [q.text for q in queries]
    query_embeddings = embedder.embed_queries(
        query_texts,
        batch_size=getattr(embedder, "batch_size", None),
        show_progress=False,
    )

    # Optional progress bar; plain iteration when tqdm is unavailable.
    iterator = queries
    try:
        from tqdm import tqdm

        iterator = tqdm(queries, desc="Searching", unit="q")
    except ImportError:
        pass

    for q, qemb in zip(iterator, query_embeddings):
        start = time.time()
        # Convert the query embedding to numpy; torch is optional here.
        try:
            import torch
        except ImportError:
            torch = None
        if torch is not None and isinstance(qemb, torch.Tensor):
            qemb_np = qemb.detach().cpu().numpy()
        else:
            qemb_np = qemb.numpy()

        results = retriever.search_embedded(
            query_embedding=qemb_np,
            top_k=retrieve_k,
            mode=mode,
            prefetch_k=prefetch_k,
            stage1_mode=stage1_mode,
            stage1_k=int(stage1_k),
            stage2_k=int(stage2_k),
            filter_obj=filter_obj,
        )
        # Latency covers the tensor conversion plus the search round-trip.
        latencies_ms.append((time.time() - start) * 1000.0)

        ranking = [str(r["id"]) for r in results]
        rels = qrels.get(q.query_id, {})

        ndcg1.append(ndcg_at_k(ranking, rels, k=1))
        ndcg5.append(ndcg_at_k(ranking, rels, k=5))
        ndcg10.append(ndcg_at_k(ranking, rels, k=10))
        ndcg100.append(ndcg_at_k(ranking, rels, k=100))
        mrr1.append(mrr_at_k(ranking, rels, k=1))
        mrr5.append(mrr_at_k(ranking, rels, k=5))
        mrr10.append(mrr_at_k(ranking, rels, k=10))
        mrr100.append(mrr_at_k(ranking, rels, k=100))
        recall1.append(recall_at_k(ranking, rels, k=1))
        recall5.append(recall_at_k(ranking, rels, k=5))
        recall10.append(recall_at_k(ranking, rels, k=10))
        recall100.append(recall_at_k(ranking, rels, k=100))

    # Wall time includes query embedding; search time sums per-query latencies.
    eval_wall_time_s = float(max(time.time() - eval_started_at, 0.0))
    eval_search_time_s = float(np.sum(latencies_ms) / 1000.0) if latencies_ms else 0.0
    qps = float(len(queries) / eval_wall_time_s) if eval_wall_time_s > 0 else 0.0
    return {
        "ndcg@1": float(np.mean(ndcg1)),
        "ndcg@5": float(np.mean(ndcg5)),
        "ndcg@10": float(np.mean(ndcg10)),
        "ndcg@100": float(np.mean(ndcg100)),
        "mrr@1": float(np.mean(mrr1)),
        "mrr@5": float(np.mean(mrr5)),
        "mrr@10": float(np.mean(mrr10)),
        "mrr@100": float(np.mean(mrr100)),
        "recall@1": float(np.mean(recall1)),
        "recall@5": float(np.mean(recall5)),
        "recall@10": float(np.mean(recall10)),
        "recall@100": float(np.mean(recall100)),
        "avg_latency_ms": float(np.mean(latencies_ms)),
        "p95_latency_ms": float(np.percentile(latencies_ms, 95)),
        "eval_wall_time_s": eval_wall_time_s,
        "eval_search_time_s": eval_search_time_s,
        "qps": qps,
        "num_queries_eval": int(len(queries)),
    }
362
+
363
+
364
+ def _write_json_atomic(path: Path, data: Dict[str, Any]) -> None:
365
+ path.parent.mkdir(parents=True, exist_ok=True)
366
+ fd, tmp_path = tempfile.mkstemp(prefix=path.name + ".", dir=str(path.parent))
367
+ try:
368
+ with os.fdopen(fd, "w") as f:
369
+ json.dump(data, f, indent=2)
370
+ os.replace(tmp_path, path)
371
+ finally:
372
+ try:
373
+ if os.path.exists(tmp_path):
374
+ os.unlink(tmp_path)
375
+ except Exception:
376
+ pass
377
+
378
+
379
+ def _safe_filename(text: str) -> str:
380
+ out = []
381
+ for ch in str(text):
382
+ if ch.isalnum() or ch in ("-", "_", "."):
383
+ out.append(ch)
384
+ else:
385
+ out.append("_")
386
+ return "".join(out).strip("_")
387
+
388
+
389
def _index_beir_corpus(
    *,
    dataset_name: str,
    corpus,
    embedder: VisualEmbedder,
    collection_name: str,
    prefer_grpc: bool,
    qdrant_vector_dtype: str,
    recreate: bool,
    indexing_threshold: int,
    batch_size: int,
    upload_batch_size: int,
    upload_workers: int,
    upsert_wait: bool,
    max_corpus_docs: int,
    sample_corpus_docs: int,
    sample_corpus_strategy: str,
    sample_seed: int,
    payload_indexes: List[Dict[str, str]],
    union_namespace: Optional[str],
    model_name: str,
    resume: bool,
    qdrant_timeout: int,
    full_scan_threshold: int,
    crop_empty: bool,
    crop_empty_percentage_to_remove: float,
    crop_empty_remove_page_number: bool,
    crop_empty_preserve_border_px: int,
    crop_empty_uniform_std_threshold: float,
    no_cloudinary: bool,
    cloudinary_folder: str,
    retry_failures: bool,
    only_failures: bool,
) -> None:
    """Embed a BEIR image corpus and upsert it into a Qdrant collection.

    Pipeline per batch: optional empty-margin cropping, batched image
    embedding (with per-document retry on failure), pooling into
    tile/experimental/global vectors, optional Cloudinary image upload,
    then (optionally threaded) upsert to Qdrant.  Failures are appended to
    a per-dataset JSONL log so later runs can skip or retry them via
    ``resume`` / ``retry_failures`` / ``only_failures``.

    NOTE(review): this body was reconstructed from an indentation-stripped
    diff; the nesting of the embed-failure recovery path (the
    ``pbar.update(batch_total)`` + ``continue`` after the per-doc retry
    loop) is the most self-consistent reading but should be confirmed
    against the original source.
    """
    # Qdrant endpoint: first matching env var wins.
    qdrant_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    if not qdrant_url:
        raise ValueError("QDRANT_URL not set")
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )

    # Create (or recreate) the collection and its payload indexes.
    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=prefer_grpc,
        vector_datatype=qdrant_vector_dtype,
        timeout=int(qdrant_timeout),
    )
    indexer.create_collection(
        force_recreate=recreate,
        indexing_threshold=indexing_threshold,
        full_scan_threshold=int(full_scan_threshold),
    )
    indexer.create_payload_indexes(fields=payload_indexes)

    # Cloudinary upload is best-effort; any construction error disables it.
    cloudinary_uploader: Optional[CloudinaryUploader] = None
    if not bool(no_cloudinary):
        try:
            cloudinary_uploader = CloudinaryUploader(folder=str(cloudinary_folder))
        except Exception:
            cloudinary_uploader = None

    # IDs that failed in earlier runs (union-namespace aware).
    failure_log = _failed_log_path(collection_name=collection_name, dataset_name=dataset_name)
    failed_ids = _load_failed_union_ids(failure_log, dataset_name=dataset_name, union_namespace=union_namespace)
    previously_failed_ids = set(failed_ids)

    # In resume mode, scroll the collection to learn which points already exist.
    existing_ids = set()
    if resume:
        offset = None
        while True:
            points, next_offset = indexer.client.scroll(
                collection_name=collection_name,
                limit=1000,
                offset=offset,
                with_payload=False,
                with_vectors=False,
            )
            for p in points:
                existing_ids.add(str(p.id))
            if not next_offset or not points:
                break
            offset = next_offset
        # Unless explicitly retrying, treat previously failed docs as handled.
        if not bool(retry_failures) and not bool(only_failures):
            existing_ids |= failed_ids

    # only_failures: restrict indexing to the previously failed IDs.
    target_ids = None
    if bool(only_failures):
        target_ids = set(str(x) for x in failed_ids)
        if not target_ids:
            print(f"No failed ids found for dataset={dataset_name}; nothing to retry.")
            return

    # Optional corpus subsampling / truncation.
    if sample_corpus_docs and sample_corpus_docs > 0:
        corpus = _sample_list(
            list(corpus),
            k=int(sample_corpus_docs),
            strategy=str(sample_corpus_strategy),
            seed=int(sample_seed),
        )
    elif max_corpus_docs and max_corpus_docs > 0:
        corpus = corpus[:max_corpus_docs]
    total_docs = len(corpus)
    points_buffer: List[Dict[str, Any]] = []

    def _safe_public_id(s: str) -> str:
        # Cloudinary public IDs: sanitized and capped at 180 chars.
        out = _safe_filename(str(s))
        return out[:180] if len(out) > 180 else out

    def _ensure_pil(img):
        # Best-effort conversion to a PIL image; returns the input unchanged
        # when PIL is unavailable or conversion fails.
        try:
            from PIL import Image
        except Exception:
            return img
        if img is None:
            return None
        if isinstance(img, Image.Image):
            return img
        try:
            return img.convert("RGB")
        except Exception:
            return img

    def _resized_for_display(img):
        # Downscale a copy to fit 1024x1024 for display/upload purposes.
        from PIL import Image

        if img is None or not isinstance(img, Image.Image):
            return None
        out = img.copy()
        out.thumbnail((1024, 1024), Image.BICUBIC)
        return out

    # Throughput bookkeeping for the progress bar postfix.
    uploaded_docs = 0
    skipped_docs = 0
    start_time = time.time()
    last_tick_time = start_time
    last_tick_docs = 0
    last_s_per_doc = 0.0

    pbar = None
    try:
        from tqdm import tqdm

        pbar = tqdm(total=total_docs, desc="📦 Indexing corpus", unit="doc")
    except ImportError:
        pass

    import threading
    from concurrent.futures import ThreadPoolExecutor, wait as futures_wait, FIRST_EXCEPTION

    # Optional threaded uploads; stop_event lets Ctrl-C abort in-flight upserts.
    stop_event = threading.Event()
    executor = ThreadPoolExecutor(max_workers=int(upload_workers)) if upload_workers and upload_workers > 0 else None
    futures = []

    def _upload(points: List[Dict[str, Any]]) -> int:
        # Upsert one chunk; if the whole chunk failed, log each point once.
        uploaded = int(indexer.upload_batch(points, delay_between_batches=0.0, wait=upsert_wait, stop_event=stop_event) or 0)
        if uploaded <= 0 and points:
            for p in points:
                pid = str(p.get("id") or "")
                if pid and pid not in failed_ids:
                    _append_jsonl(
                        failure_log,
                        {
                            "dataset": dataset_name,
                            "collection": collection_name,
                            "model": model_name,
                            "source_doc_id": str((p.get("metadata") or {}).get("source_doc_id") or ""),
                            "doc_id": str((p.get("metadata") or {}).get("doc_id") or ""),
                            "union_doc_id": pid,
                            "error": "Qdrant upsert failed (all retries exhausted)",
                        },
                    )
                    failed_ids.add(pid)
        return uploaded

    def _drain(block: bool) -> None:
        # Harvest finished upload futures; block=True waits for all of them.
        nonlocal uploaded_docs
        if not futures:
            return
        done, _ = futures_wait(futures, return_when=FIRST_EXCEPTION, timeout=None if block else 0)
        for d in list(done):
            futures.remove(d)
            uploaded_docs += int(d.result() or 0)

    try:
        for start in range(0, total_docs, batch_size):
            batch = corpus[start : start + batch_size]
            batch_total = len(batch)
            # Filter the batch: only_failures restricts to target_ids,
            # otherwise resume-mode skips already-present points.
            if target_ids is not None:
                filtered = []
                for d in batch:
                    source_doc_id = str((d.payload or {}).get("source_doc_id") or d.doc_id)
                    union_doc_id = _union_point_id(
                        dataset_name=dataset_name,
                        source_doc_id=source_doc_id,
                        union_namespace=union_namespace,
                    )
                    if union_doc_id in existing_ids:
                        skipped_docs += 1
                        continue
                    if union_doc_id in target_ids:
                        filtered.append(d)
                    else:
                        skipped_docs += 1
                batch = filtered
            elif existing_ids:
                filtered = []
                for d in batch:
                    source_doc_id = str((d.payload or {}).get("source_doc_id") or d.doc_id)
                    union_doc_id = _union_point_id(
                        dataset_name=dataset_name,
                        source_doc_id=source_doc_id,
                        union_namespace=union_namespace,
                    )
                    if union_doc_id in existing_ids:
                        skipped_docs += 1
                    else:
                        filtered.append(d)
                batch = filtered

            if not batch:
                # Whole batch skipped: still advance the bar and refresh stats.
                if pbar is not None:
                    pbar.update(batch_total)
                    now = time.time()
                    done_docs = int(pbar.n)
                    elapsed = max(now - start_time, 1e-9)
                    avg_s_per_doc = elapsed / max(done_docs, 1)
                    delta_docs = done_docs - last_tick_docs
                    delta_t = max(now - last_tick_time, 1e-9)
                    if delta_docs > 0:
                        last_s_per_doc = delta_t / delta_docs
                        last_tick_time = now
                        last_tick_docs = done_docs
                    pbar.set_postfix(
                        {
                            "avg_s/doc": f"{avg_s_per_doc:.2f}",
                            "last_s/doc": f"{last_s_per_doc:.2f}",
                            "upl": uploaded_docs,
                            "skip": skipped_docs,
                        }
                    )
                continue

            # Optional empty-margin cropping before embedding.
            if crop_empty:
                from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty as _crop_empty

                crop_cfg = CropEmptyConfig(
                    percentage_to_remove=float(crop_empty_percentage_to_remove),
                    remove_page_number=bool(crop_empty_remove_page_number),
                    preserve_border_px=int(crop_empty_preserve_border_px),
                    uniform_rowcol_std_threshold=float(crop_empty_uniform_std_threshold),
                )
                crop_metas = []
                images = []
                original_images = []
                for d in batch:
                    original_img = _ensure_pil(d.image)
                    original_images.append(original_img)
                    cropped, meta = _crop_empty(original_img, config=crop_cfg)
                    images.append(cropped)
                    crop_metas.append(meta)
            else:
                original_images = [_ensure_pil(d.image) for d in batch]
                images = original_images
                crop_metas = [None for _ in batch]
            try:
                embeddings, token_infos = embedder.embed_images(
                    images,
                    batch_size=batch_size,
                    return_token_info=True,
                    show_progress=False,
                )
            except Exception as e:
                # Retry per-doc to isolate flaky backend / corrupted sample issues.
                embeddings = []
                token_infos = []
                for doc_i, img_i, crop_meta_i in zip(batch, images, crop_metas):
                    try:
                        e1, t1 = embedder.embed_images(
                            [img_i],
                            batch_size=1,
                            return_token_info=True,
                            show_progress=False,
                        )
                        embeddings.append(e1[0])
                        token_infos.append(t1[0])
                    except Exception as e_single:
                        # Log the culprit once and mark it as handled.
                        source_doc_id_i = str((doc_i.payload or {}).get("source_doc_id") or doc_i.doc_id)
                        union_doc_id_i = _union_point_id(
                            dataset_name=dataset_name,
                            source_doc_id=source_doc_id_i,
                            union_namespace=union_namespace,
                        )
                        if str(union_doc_id_i) not in failed_ids:
                            _append_jsonl(
                                failure_log,
                                {
                                    "dataset": dataset_name,
                                    "collection": collection_name,
                                    "model": model_name,
                                    "source_doc_id": str(source_doc_id_i),
                                    "doc_id": str(getattr(doc_i, "doc_id", "")),
                                    "union_doc_id": str(union_doc_id_i),
                                    "error": str(e_single),
                                },
                            )
                            failed_ids.add(str(union_doc_id_i))
                        existing_ids.add(str(union_doc_id_i))
                        skipped_docs += 1
                # NOTE(review): under this reconstruction the whole batch is
                # abandoned after the diagnostic retry, so docs that DID
                # re-embed successfully here are not indexed in this run —
                # they will be picked up again by a later (resume) run.
                if pbar is not None:
                    pbar.update(batch_total)
                continue

            # Per-document pooling, optional Cloudinary upload, point assembly.
            for doc, emb, token_info, crop_meta, original_img, embed_img in zip(
                batch, embeddings, token_infos, crop_metas, original_images, images
            ):
                try:
                    emb_np = emb.cpu().float().numpy() if hasattr(emb, "cpu") else np.array(emb, dtype=np.float32)
                    # Fall back to all rows when the embedder gives no indices.
                    visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
                    visual_embedding = emb_np[visual_indices].astype(np.float32)
                    tile_pooled = embedder.mean_pool_visual_embedding(visual_embedding, token_info, target_vectors=32)
                    experimental_pooled = embedder.experimental_pool_visual_embedding(
                        visual_embedding, token_info, target_vectors=32, mean_pool=tile_pooled
                    )
                    global_pooled = embedder.global_pool_from_mean_pool(tile_pooled)
                except Exception as e_single:
                    # Pooling failed for this doc: log once and skip it.
                    source_doc_id_i = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
                    union_doc_id_i = _union_point_id(
                        dataset_name=dataset_name,
                        source_doc_id=source_doc_id_i,
                        union_namespace=union_namespace,
                    )
                    if str(union_doc_id_i) not in failed_ids:
                        _append_jsonl(
                            failure_log,
                            {
                                "dataset": dataset_name,
                                "collection": collection_name,
                                "model": model_name,
                                "source_doc_id": str(source_doc_id_i),
                                "doc_id": str(getattr(doc, "doc_id", "")),
                                "union_doc_id": str(union_doc_id_i),
                                "error": str(e_single),
                            },
                        )
                        failed_ids.add(str(union_doc_id_i))
                    existing_ids.add(str(union_doc_id_i))
                    skipped_docs += 1
                    continue

                num_tiles = int(tile_pooled.shape[0])
                patches_per_tile = int(visual_embedding.shape[0] // max(num_tiles, 1)) if num_tiles else 0

                source_doc_id = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
                union_doc_id = _union_point_id(
                    dataset_name=dataset_name,
                    source_doc_id=source_doc_id,
                    union_namespace=union_namespace,
                )

                # Best-effort image hosting; URLs stay empty on any failure.
                resized_img = _resized_for_display(embed_img) or embed_img
                original_url = ""
                cropped_url = ""
                resized_url = ""
                if cloudinary_uploader is not None and original_img is not None and resized_img is not None:
                    base_public_id = _safe_public_id(f"{dataset_name}__{union_doc_id}")
                    try:
                        if crop_empty:
                            o_url, c_url, r_url = cloudinary_uploader.upload_original_cropped_and_resized(
                                original_img,
                                embed_img if embed_img is not None and embed_img is not original_img else None,
                                resized_img,
                                base_public_id,
                            )
                            original_url = o_url or ""
                            cropped_url = c_url or ""
                            resized_url = r_url or ""
                        else:
                            o_url, r_url = cloudinary_uploader.upload_original_and_resized(
                                original_img,
                                resized_img,
                                base_public_id,
                            )
                            original_url = o_url or ""
                            resized_url = r_url or ""
                    except Exception:
                        pass

                payload = {
                    "dataset": dataset_name,
                    "doc_id": doc.doc_id,
                    "union_doc_id": union_doc_id,
                    "page": resized_url or original_url or "",
                    "original_url": original_url,
                    "cropped_url": cropped_url,
                    "resized_url": resized_url,
                    "original_width": int(original_img.width) if original_img is not None else None,
                    "original_height": int(original_img.height) if original_img is not None else None,
                    "cropped_width": int(embed_img.width) if embed_img is not None else None,
                    "cropped_height": int(embed_img.height) if embed_img is not None else None,
                    "resized_width": int(resized_img.width) if resized_img is not None else None,
                    "resized_height": int(resized_img.height) if resized_img is not None else None,
                    "num_tiles": int(num_tiles),
                    "patches_per_tile": int(patches_per_tile),
                    "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                    "model_name": model_name,
                    "crop_empty_enabled": bool(crop_empty),
                    "crop_empty_crop_box": (crop_meta or {}).get("crop_box") if crop_empty else None,
                    "crop_empty_remove_page_number": bool(crop_empty_remove_page_number) if crop_empty else None,
                    "crop_empty_percentage_to_remove": float(crop_empty_percentage_to_remove) if crop_empty else None,
                    "index_recovery_previously_failed": bool(union_doc_id in previously_failed_ids),
                    "index_recovery_mode": (
                        "only_failures" if bool(only_failures) else ("retry_failures" if bool(retry_failures) else None)
                    ),
                    "index_recovery_pooling_inferred_tiles": bool(
                        (token_info or {}).get("num_tiles") is None and (token_info or {}).get("n_rows") is None and (token_info or {}).get("n_cols") is None
                    ),
                    "index_recovery_num_visual_tokens": int(visual_embedding.shape[0]),
                    **(doc.payload or {}),
                }

                points_buffer.append(
                    {
                        "id": union_doc_id,
                        "visual_embedding": visual_embedding,
                        "tile_pooled_embedding": tile_pooled,
                        "experimental_pooled_embedding": experimental_pooled,
                        "global_pooled_embedding": global_pooled,
                        "metadata": payload,
                    }
                )

                # Flush the buffer once it reaches the upload batch size.
                if len(points_buffer) >= upload_batch_size:
                    chunk = points_buffer
                    points_buffer = []
                    if executor is None:
                        uploaded_docs += int(_upload(chunk) or 0)
                    else:
                        futures.append(executor.submit(_upload, chunk))
                        # Back-pressure: block once too many chunks are in flight.
                        _drain(block=len(futures) >= int(upload_workers) * 2)

            if pbar is not None:
                pbar.update(len(batch))
                now = time.time()
                done_docs = int(pbar.n)
                elapsed = max(now - start_time, 1e-9)
                avg_s_per_doc = elapsed / max(done_docs, 1)
                delta_docs = done_docs - last_tick_docs
                delta_t = max(now - last_tick_time, 1e-9)
                if delta_docs > 0:
                    last_s_per_doc = delta_t / delta_docs
                    last_tick_time = now
                    last_tick_docs = done_docs
                pbar.set_postfix(
                    {
                        "avg_s/doc": f"{avg_s_per_doc:.2f}",
                        "last_s/doc": f"{last_s_per_doc:.2f}",
                        "upl": uploaded_docs,
                        "skip": skipped_docs,
                    }
                )

            if executor is not None:
                _drain(block=False)
    except KeyboardInterrupt:
        # Abort in-flight uploads and propagate the interrupt.
        stop_event.set()
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        raise

    # Flush any remaining buffered points.
    if points_buffer:
        if executor is None:
            uploaded_docs += int(_upload(points_buffer) or 0)
        else:
            futures.append(executor.submit(_upload, points_buffer))

    if executor is not None:
        _drain(block=True)
        executor.shutdown(wait=True)

    if pbar is not None:
        pbar.close()
873
+
874
+
875
+ def main() -> None:
876
+ parser = argparse.ArgumentParser()
877
+ parser.add_argument("--dataset", type=str, default=None)
878
+ parser.add_argument("--datasets", type=str, nargs="+", default=None)
879
+ parser.add_argument("--collection", type=str, required=True)
880
+ parser.add_argument("--model", type=str, default="vidore/colSmol-500M")
881
+ parser.add_argument(
882
+ "--processor-speed",
883
+ type=str,
884
+ default="fast",
885
+ choices=["fast", "slow", "auto"],
886
+ help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
887
+ )
888
+ parser.add_argument(
889
+ "--torch-dtype",
890
+ type=str,
891
+ default="auto",
892
+ choices=["auto", "float32", "float16", "bfloat16"],
893
+ )
894
+ parser.add_argument(
895
+ "--qdrant-vector-dtype",
896
+ type=str,
897
+ default="float16",
898
+ choices=["float16", "float32"],
899
+ )
900
+ grpc_group = parser.add_mutually_exclusive_group()
901
+ grpc_group.add_argument("--prefer-grpc", dest="prefer_grpc", action="store_true", default=True)
902
+ grpc_group.add_argument("--no-prefer-grpc", dest="prefer_grpc", action="store_false")
903
+ parser.add_argument("--index", action="store_true")
904
+ parser.add_argument("--recreate", action="store_true")
905
+ parser.add_argument("--indexing-threshold", type=int, default=0)
906
+ parser.add_argument("--full-scan-threshold", type=int, default=0)
907
+ parser.add_argument("--batch-size", type=int, default=4)
908
+ parser.add_argument("--upload-batch-size", type=int, default=8)
909
+ parser.add_argument("--upload-workers", type=int, default=0)
910
+ parser.add_argument("--upsert-wait", action="store_true")
911
+ parser.add_argument("--max-corpus-docs", type=int, default=0)
912
+ parser.add_argument("--sample-corpus-docs", type=int, default=0)
913
+ parser.add_argument("--sample-corpus-strategy", type=str, default="first", choices=["first", "random"])
914
+ parser.add_argument("--sample-seed", type=int, default=0)
915
+ parser.add_argument("--sample-queries", type=int, default=0)
916
+ parser.add_argument("--sample-query-strategy", type=str, default="first", choices=["first", "random"])
917
+ parser.add_argument("--sample-query-seed", type=int, default=0)
918
+ parser.add_argument("--index-from-queries", action="store_true", default=False)
919
+ parser.add_argument("--resume", action="store_true", default=False)
920
+ parser.add_argument("--qdrant-timeout", type=int, default=120)
921
+ parser.add_argument("--qdrant-retries", type=int, default=3)
922
+ parser.add_argument("--qdrant-retry-sleep", type=float, default=0.5)
923
+ parser.add_argument("--crop-empty", action="store_true", default=False)
924
+ parser.add_argument("--crop-empty-percentage-to-remove", type=float, default=0.9)
925
+ parser.add_argument("--crop-empty-remove-page-number", action="store_true", default=False)
926
+ parser.add_argument("--crop-empty-preserve-border-px", type=int, default=1)
927
+ parser.add_argument("--crop-empty-uniform-std-threshold", type=float, default=0.0)
928
+ payload_group = parser.add_mutually_exclusive_group()
929
+ payload_group.add_argument("--index-common-metadata", action="store_true", default=True)
930
+ payload_group.add_argument("--no-index-common-metadata", dest="index_common_metadata", action="store_false")
931
+ parser.add_argument("--payload-index", action="append", default=[])
932
+ parser.add_argument(
933
+ "--no-cloudinary",
934
+ action="store_true",
935
+ help="Disable Cloudinary uploads during indexing (default: enabled).",
936
+ )
937
+ parser.add_argument(
938
+ "--cloudinary-folder",
939
+ type=str,
940
+ default="vidore-beir",
941
+ help="Cloudinary base folder for uploads (default: vidore-beir).",
942
+ )
943
+ parser.add_argument(
944
+ "--retry-failures",
945
+ action="store_true",
946
+ default=False,
947
+ help="On --resume, retry documents listed in index_failures__<collection>__<dataset>.jsonl (default: skip them).",
948
+ )
949
+ parser.add_argument(
950
+ "--only-failures",
951
+ action="store_true",
952
+ default=False,
953
+ help="Index only documents listed in index_failures__<collection>__<dataset>.jsonl.",
954
+ )
955
+
956
+ parser.add_argument("--top-k", type=int, default=100, help="Retrieve top-k results (default: 100 to calculate metrics at all cutoffs)")
957
+ parser.add_argument("--prefetch-k", type=int, default=200, help="Prefetch candidates for two-stage (default: 200)")
958
+ parser.add_argument(
959
+ "--no-eval",
960
+ action="store_true",
961
+ default=False,
962
+ help="If set, run indexing only and skip evaluation.",
963
+ )
964
+ parser.add_argument(
965
+ "--mode",
966
+ type=str,
967
+ default="single_full",
968
+ choices=["single_full", "single_tiles", "single_global", "two_stage", "three_stage"],
969
+ )
970
+ parser.add_argument(
971
+ "--stage1-mode",
972
+ type=str,
973
+ default="tokens_vs_tiles",
974
+ choices=[
975
+ "pooled_query_vs_tiles",
976
+ "tokens_vs_tiles",
977
+ "pooled_query_vs_experimental",
978
+ "tokens_vs_experimental",
979
+ "pooled_query_vs_global",
980
+ ],
981
+ )
982
+ parser.add_argument("--stage1-k", type=int, default=1000, help="Three-stage stage1 top_k (default: 1000)")
983
+ parser.add_argument("--stage2-k", type=int, default=300, help="Three-stage stage2 top_k (default: 300)")
984
+ parser.add_argument("--max-queries", type=int, default=0)
985
+ drop_group = parser.add_mutually_exclusive_group()
986
+ drop_group.add_argument("--drop-empty-queries", dest="drop_empty_queries", action="store_true", default=True)
987
+ drop_group.add_argument("--no-drop-empty-queries", dest="drop_empty_queries", action="store_false")
988
+ parser.add_argument(
989
+ "--evaluation-scope",
990
+ type=str,
991
+ default="union",
992
+ choices=["union", "per_dataset"],
993
+ help="Evaluation scope: 'union' searches over the whole collection (cross-dataset distractors). "
994
+ "'per_dataset' applies a Qdrant filter so each dataset's queries search only its own subset (leaderboard-comparable).",
995
+ )
996
+ cont_group = parser.add_mutually_exclusive_group()
997
+ cont_group.add_argument(
998
+ "--continue-on-error",
999
+ dest="continue_on_error",
1000
+ action="store_true",
1001
+ default=True,
1002
+ help="Continue evaluating remaining datasets if one dataset fails (default: true).",
1003
+ )
1004
+ cont_group.add_argument(
1005
+ "--no-continue-on-error",
1006
+ dest="continue_on_error",
1007
+ action="store_false",
1008
+ help="Stop the run immediately on the first dataset evaluation failure.",
1009
+ )
1010
+ parser.add_argument("--output", type=str, default="auto")
1011
+
1012
+ args = parser.parse_args()
1013
+
1014
+ _maybe_load_dotenv()
1015
+
1016
+ if args.recreate:
1017
+ args.index = True
1018
+
1019
+ if args.sample_corpus_docs and int(args.sample_corpus_docs) > 0 and args.max_corpus_docs and int(args.max_corpus_docs) > 0:
1020
+ raise ValueError("Use only one of --sample-corpus-docs or --max-corpus-docs (not both).")
1021
+ if args.sample_queries and int(args.sample_queries) > 0 and args.index_from_queries:
1022
+ if (args.sample_corpus_docs and int(args.sample_corpus_docs) > 0) or (args.max_corpus_docs and int(args.max_corpus_docs) > 0):
1023
+ raise ValueError("Use --index-from-queries with --sample-queries only (do not combine with corpus sampling).")
1024
+
1025
+ if args.upsert_wait:
1026
+ print("Qdrant upserts wait for completion (wait=True).")
1027
+ else:
1028
+ print("Qdrant upserts are async (wait=False).")
1029
+ print(f"Qdrant request timeout: {int(args.qdrant_timeout)}s, retries: {int(args.qdrant_retries)}.")
1030
+
1031
+ datasets: List[str] = []
1032
+ if args.datasets:
1033
+ datasets = list(args.datasets)
1034
+ elif args.dataset:
1035
+ datasets = [args.dataset]
1036
+ else:
1037
+ raise ValueError("Provide --dataset (single) or --datasets (one or more)")
1038
+
1039
+ if str(args.output).strip().lower() == "auto":
1040
+ args.output = _default_output_filename(args=args, datasets=datasets)
1041
+
1042
+ loaded: List[Tuple[str, Any, Any, Dict[str, Dict[str, int]]]] = []
1043
+ for ds_name in datasets:
1044
+ corpus, queries, qrels = load_vidore_beir_dataset(ds_name)
1045
+ loaded.append((ds_name, corpus, queries, qrels))
1046
+
1047
+ output_dtype = np.float16 if args.qdrant_vector_dtype == "float16" else np.float32
1048
+ embedder = VisualEmbedder(
1049
+ model_name=args.model,
1050
+ batch_size=args.batch_size,
1051
+ torch_dtype=_parse_torch_dtype(args.torch_dtype),
1052
+ output_dtype=output_dtype,
1053
+ processor_speed=str(args.processor_speed),
1054
+ )
1055
+
1056
+ selected: List[Tuple[str, Any, Any, Dict[str, Dict[str, int]]]] = []
1057
+ for ds_name, corpus, queries, qrels in loaded:
1058
+ if args.sample_queries and int(args.sample_queries) > 0:
1059
+ queries = _sample_list(
1060
+ list(queries),
1061
+ k=int(args.sample_queries),
1062
+ strategy=str(args.sample_query_strategy),
1063
+ seed=int(args.sample_query_seed),
1064
+ )
1065
+ qrels = _filter_qrels(qrels, [q.query_id for q in queries])
1066
+
1067
+ if args.index_from_queries and args.sample_queries and int(args.sample_queries) > 0:
1068
+ rel_doc_ids = set()
1069
+ for q in queries:
1070
+ for did, score in qrels.get(q.query_id, {}).items():
1071
+ if score > 0:
1072
+ rel_doc_ids.add(str(did))
1073
+ corpus = [d for d in corpus if str(d.doc_id) in rel_doc_ids]
1074
+ else:
1075
+ if args.sample_corpus_docs and int(args.sample_corpus_docs) > 0:
1076
+ corpus = _sample_list(
1077
+ list(corpus),
1078
+ k=int(args.sample_corpus_docs),
1079
+ strategy=str(args.sample_corpus_strategy),
1080
+ seed=int(args.sample_seed),
1081
+ )
1082
+ elif args.max_corpus_docs and int(args.max_corpus_docs) > 0:
1083
+ corpus = corpus[: int(args.max_corpus_docs)]
1084
+
1085
+ selected.append((ds_name, corpus, queries, qrels))
1086
+
1087
+ payload_indexes: List[Dict[str, str]] = []
1088
+ if args.index:
1089
+ payload_indexes = _parse_payload_indexes(args.payload_index)
1090
+ if args.index_common_metadata:
1091
+ payload_indexes.extend(
1092
+ [
1093
+ {"field": "dataset", "type": "keyword"},
1094
+ {"field": "source_doc_id", "type": "keyword"},
1095
+ {"field": "doc_id", "type": "keyword"},
1096
+ {"field": "filename", "type": "keyword"},
1097
+ {"field": "page", "type": "keyword"},
1098
+ {"field": "page_number", "type": "integer"},
1099
+ {"field": "total_pages", "type": "integer"},
1100
+ {"field": "has_text", "type": "bool"},
1101
+ {"field": "text", "type": "text"},
1102
+ {"field": "original_url", "type": "keyword"},
1103
+ {"field": "resized_url", "type": "keyword"},
1104
+ {"field": "original_width", "type": "integer"},
1105
+ {"field": "original_height", "type": "integer"},
1106
+ {"field": "resized_width", "type": "integer"},
1107
+ {"field": "resized_height", "type": "integer"},
1108
+ {"field": "crop_empty_enabled", "type": "bool"},
1109
+ {"field": "crop_empty_remove_page_number", "type": "bool"},
1110
+ {"field": "crop_empty_percentage_to_remove", "type": "float"},
1111
+ {"field": "num_tiles", "type": "integer"},
1112
+ {"field": "tile_rows", "type": "integer"},
1113
+ {"field": "tile_cols", "type": "integer"},
1114
+ {"field": "patches_per_tile", "type": "integer"},
1115
+ {"field": "num_visual_tokens", "type": "integer"},
1116
+ {"field": "processor_version", "type": "keyword"},
1117
+ {"field": "year", "type": "integer"},
1118
+ {"field": "source", "type": "keyword"},
1119
+ ]
1120
+ )
1121
+ for i, (ds_name, corpus, queries, _qrels) in enumerate(selected):
1122
+ print(f"Indexing {ds_name}: corpus_docs={len(corpus)} queries={len(queries)}")
1123
+ _index_beir_corpus(
1124
+ dataset_name=ds_name,
1125
+ corpus=corpus,
1126
+ embedder=embedder,
1127
+ collection_name=args.collection,
1128
+ prefer_grpc=args.prefer_grpc,
1129
+ qdrant_vector_dtype=args.qdrant_vector_dtype,
1130
+ recreate=bool(args.recreate and i == 0),
1131
+ indexing_threshold=args.indexing_threshold,
1132
+ batch_size=args.batch_size,
1133
+ upload_batch_size=args.upload_batch_size,
1134
+ upload_workers=args.upload_workers,
1135
+ upsert_wait=bool(args.upsert_wait),
1136
+ max_corpus_docs=0,
1137
+ sample_corpus_docs=0,
1138
+ sample_corpus_strategy=str(args.sample_corpus_strategy),
1139
+ sample_seed=int(args.sample_seed),
1140
+ payload_indexes=payload_indexes,
1141
+ union_namespace=args.collection,
1142
+ model_name=args.model,
1143
+ resume=bool(args.resume),
1144
+ qdrant_timeout=int(args.qdrant_timeout),
1145
+ full_scan_threshold=int(args.full_scan_threshold),
1146
+ crop_empty=bool(args.crop_empty),
1147
+ crop_empty_percentage_to_remove=float(args.crop_empty_percentage_to_remove),
1148
+ crop_empty_remove_page_number=bool(args.crop_empty_remove_page_number),
1149
+ crop_empty_preserve_border_px=int(args.crop_empty_preserve_border_px),
1150
+ crop_empty_uniform_std_threshold=float(args.crop_empty_uniform_std_threshold),
1151
+ no_cloudinary=bool(args.no_cloudinary),
1152
+ cloudinary_folder=str(args.cloudinary_folder),
1153
+ retry_failures=bool(args.retry_failures),
1154
+ only_failures=bool(args.only_failures),
1155
+ )
1156
+
1157
+ out_path = _resolve_output_path(args.output, collection_name=str(args.collection))
1158
+
1159
+ if bool(args.no_eval):
1160
+ dataset_index_failures: Dict[str, Dict[str, Any]] = {}
1161
+ dataset_counts: Dict[str, Dict[str, int]] = {}
1162
+ for ds_name, corpus, queries, _qrels in selected:
1163
+ dataset_counts[ds_name] = {"corpus_docs": int(len(corpus)), "queries": int(len(queries)), "queries_eval": 0}
1164
+ failed_path = _failed_log_path(collection_name=args.collection, dataset_name=ds_name)
1165
+ failed_ids = _load_failed_union_ids(failed_path, dataset_name=ds_name, union_namespace=args.collection)
1166
+ dataset_index_failures[ds_name] = {
1167
+ "failed_log_path": str(failed_path),
1168
+ "failed_ids_count": int(len(failed_ids)),
1169
+ "qrels_removed": None,
1170
+ }
1171
+ _write_json_atomic(
1172
+ out_path,
1173
+ {
1174
+ "command": " ".join(sys.argv),
1175
+ "dataset": datasets[0] if len(datasets) == 1 else None,
1176
+ "datasets": datasets,
1177
+ "protocol": "beir",
1178
+ "collection": args.collection,
1179
+ "model": args.model,
1180
+ "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
1181
+ "qdrant_vector_dtype": args.qdrant_vector_dtype,
1182
+ "mode": args.mode,
1183
+ "stage1_mode": args.stage1_mode if args.mode == "two_stage" else None,
1184
+ "prefetch_k": args.prefetch_k if args.mode == "two_stage" else None,
1185
+ "top_k": args.top_k,
1186
+ "evaluation_scope": str(args.evaluation_scope),
1187
+ "dataset_counts": dataset_counts,
1188
+ "dataset_errors": {},
1189
+ "dataset_index_failures": dataset_index_failures,
1190
+ "qdrant_timeout": int(args.qdrant_timeout),
1191
+ "qdrant_retries": int(args.qdrant_retries),
1192
+ "qdrant_retry_sleep": float(args.qdrant_retry_sleep),
1193
+ "full_scan_threshold": int(args.full_scan_threshold),
1194
+ "no_eval": True,
1195
+ "eval_wall_time_s": None,
1196
+ "metrics": None,
1197
+ "metrics_by_dataset": {},
1198
+ },
1199
+ )
1200
+ print(f"Wrote index-only report: {out_path}")
1201
+ return
1202
+
1203
+ retriever = MultiVectorRetriever(
1204
+ collection_name=args.collection,
1205
+ embedder=embedder,
1206
+ qdrant_url=os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL"),
1207
+ qdrant_api_key=(
1208
+ os.getenv("SIGIR_QDRANT_KEY")
1209
+ or os.getenv("SIGIR_QDRANT_API_KEY")
1210
+ or os.getenv("DEST_QDRANT_API_KEY")
1211
+ or os.getenv("QDRANT_API_KEY")
1212
+ ),
1213
+ prefer_grpc=args.prefer_grpc,
1214
+ request_timeout=int(args.qdrant_timeout),
1215
+ max_retries=int(args.qdrant_retries),
1216
+ retry_sleep=float(args.qdrant_retry_sleep),
1217
+ )
1218
+
1219
+ metrics_by_dataset: Dict[str, Dict[str, float]] = {}
1220
+ dataset_errors: Dict[str, str] = {}
1221
+ dataset_counts: Dict[str, Dict[str, int]] = {}
1222
+ dataset_index_failures: Dict[str, Dict[str, Any]] = {}
1223
+ eval_started_at = time.time()
1224
+
1225
+ def _build_run_record() -> Dict[str, Any]:
1226
+ single_dataset = datasets[0] if len(datasets) == 1 else None
1227
+ single_metrics = metrics_by_dataset.get(single_dataset) if single_dataset else None
1228
+ run_cmd = " ".join(sys.argv)
1229
+ return {
1230
+ "command": run_cmd,
1231
+ "dataset": single_dataset,
1232
+ "datasets": datasets,
1233
+ "protocol": "beir",
1234
+ "collection": args.collection,
1235
+ "model": args.model,
1236
+ "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
1237
+ "qdrant_vector_dtype": args.qdrant_vector_dtype,
1238
+ "mode": args.mode,
1239
+ "stage1_mode": args.stage1_mode if args.mode == "two_stage" else None,
1240
+ "prefetch_k": args.prefetch_k if args.mode == "two_stage" else None,
1241
+ "stage1_k": int(args.stage1_k) if args.mode == "three_stage" else None,
1242
+ "stage2_k": int(args.stage2_k) if args.mode == "three_stage" else None,
1243
+ "top_k": args.top_k,
1244
+ "max_queries": args.max_queries,
1245
+ "max_corpus_docs": int(args.max_corpus_docs),
1246
+ "sample_corpus_docs": int(args.sample_corpus_docs),
1247
+ "sample_corpus_strategy": str(args.sample_corpus_strategy),
1248
+ "sample_seed": int(args.sample_seed),
1249
+ "sample_queries": int(args.sample_queries),
1250
+ "sample_query_strategy": str(args.sample_query_strategy),
1251
+ "sample_query_seed": int(args.sample_query_seed),
1252
+ "index_from_queries": bool(args.index_from_queries),
1253
+ "drop_empty_queries": bool(args.drop_empty_queries),
1254
+ "evaluation_scope": str(args.evaluation_scope),
1255
+ "payload_indexes": payload_indexes,
1256
+ "dataset_counts": dataset_counts,
1257
+ "dataset_errors": dataset_errors,
1258
+ "dataset_index_failures": dataset_index_failures,
1259
+ "qdrant_timeout": int(args.qdrant_timeout),
1260
+ "qdrant_retries": int(args.qdrant_retries),
1261
+ "qdrant_retry_sleep": float(args.qdrant_retry_sleep),
1262
+ "full_scan_threshold": int(args.full_scan_threshold),
1263
+ "eval_wall_time_s": float(max(time.time() - eval_started_at, 0.0)),
1264
+ "metrics": single_metrics,
1265
+ "metrics_by_dataset": metrics_by_dataset,
1266
+ }
1267
+
1268
+ for ds_name, corpus, queries, qrels in selected:
1269
+ print(
1270
+ f"Evaluating dataset={ds_name} "
1271
+ f"(corpus_docs={len(corpus)}, queries={len(queries)}) "
1272
+ f"scope={args.evaluation_scope} "
1273
+ f"mode={args.mode}"
1274
+ + (f", stage1_mode={args.stage1_mode}, prefetch_k={int(args.prefetch_k)}" if args.mode == "two_stage" else "")
1275
+ + (f", stage1_k={int(args.stage1_k)}, stage2_k={int(args.stage2_k)}" if args.mode == "three_stage" else "")
1276
+ + f", top_k={int(args.top_k)}"
1277
+ )
1278
+ sys.stdout.flush()
1279
+
1280
+ dataset_counts[ds_name] = {"corpus_docs": int(len(corpus)), "queries": int(len(queries)), "queries_eval": 0}
1281
+ id_map: Dict[str, str] = {}
1282
+ for doc in corpus:
1283
+ source_doc_id = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
1284
+ id_map[str(doc.doc_id)] = _union_point_id(
1285
+ dataset_name=ds_name,
1286
+ source_doc_id=source_doc_id,
1287
+ union_namespace=args.collection,
1288
+ )
1289
+
1290
+ remapped_qrels: Dict[str, Dict[str, int]] = {}
1291
+ for qid, rels in qrels.items():
1292
+ out_rels: Dict[str, int] = {}
1293
+ for did, score in rels.items():
1294
+ mapped = id_map.get(str(did))
1295
+ if mapped:
1296
+ out_rels[mapped] = int(score)
1297
+ if out_rels:
1298
+ remapped_qrels[qid] = out_rels
1299
+
1300
+ failed_path = _failed_log_path(collection_name=args.collection, dataset_name=ds_name)
1301
+ failed_ids = _load_failed_union_ids(failed_path, dataset_name=ds_name, union_namespace=args.collection)
1302
+ remapped_qrels, removed_rels = _remove_failed_from_qrels(remapped_qrels, failed_ids)
1303
+ dataset_index_failures[ds_name] = {
1304
+ "failed_log_path": str(failed_path),
1305
+ "failed_ids_count": int(len(failed_ids)),
1306
+ "qrels_removed": int(removed_rels),
1307
+ }
1308
+
1309
+ filter_obj = None
1310
+ if args.evaluation_scope == "per_dataset":
1311
+ from qdrant_client.http import models as qmodels
1312
+
1313
+ filter_obj = qmodels.Filter(
1314
+ must=[qmodels.FieldCondition(key="dataset", match=qmodels.MatchValue(value=str(ds_name)))]
1315
+ )
1316
+
1317
+ try:
1318
+ metrics_by_dataset[ds_name] = _evaluate(
1319
+ queries=queries,
1320
+ qrels=remapped_qrels,
1321
+ retriever=retriever,
1322
+ embedder=embedder,
1323
+ top_k=args.top_k,
1324
+ prefetch_k=args.prefetch_k,
1325
+ mode=args.mode,
1326
+ stage1_mode=args.stage1_mode,
1327
+ stage1_k=int(args.stage1_k),
1328
+ stage2_k=int(args.stage2_k),
1329
+ max_queries=int(args.max_queries),
1330
+ drop_empty_queries=bool(args.drop_empty_queries),
1331
+ filter_obj=filter_obj,
1332
+ )
1333
+ dataset_counts[ds_name]["queries_eval"] = int(metrics_by_dataset[ds_name].get("num_queries_eval", 0))
1334
+ ds_only_out = {
1335
+ **_build_run_record(),
1336
+ "dataset": str(ds_name),
1337
+ "datasets": [str(ds_name)],
1338
+ "metrics": metrics_by_dataset[ds_name],
1339
+ "metrics_by_dataset": {str(ds_name): metrics_by_dataset[ds_name]},
1340
+ }
1341
+ per_ds_path = out_path.with_name(
1342
+ f"{out_path.stem}__{_safe_filename(ds_name)}{out_path.suffix}"
1343
+ )
1344
+ _write_json_atomic(per_ds_path, ds_only_out)
1345
+ print(f"Wrote dataset report: {per_ds_path}")
1346
+ print(json.dumps({str(ds_name): metrics_by_dataset[ds_name]}, indent=2))
1347
+ sys.stdout.flush()
1348
+ except Exception as e:
1349
+ dataset_errors[ds_name] = f"{type(e).__name__}: {e}"
1350
+ if not bool(args.continue_on_error):
1351
+ _write_json_atomic(out_path, _build_run_record())
1352
+ raise
1353
+ finally:
1354
+ _write_json_atomic(out_path, _build_run_record())
1355
+
1356
+ single_dataset = datasets[0] if len(datasets) == 1 else None
1357
+ single_metrics = metrics_by_dataset.get(single_dataset) if single_dataset else None
1358
+ if len(datasets) == 1 and single_metrics is not None:
1359
+ print(json.dumps(single_metrics, indent=2))
1360
+ else:
1361
+ print(json.dumps(metrics_by_dataset, indent=2))
1362
+
1363
+
1364
# Script entry point: run the indexing/evaluation CLI when executed directly.
if __name__ == "__main__":
    main()