visual_rag_toolkit-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
benchmarks/vidore_tatdqa_test/run_qdrant.py
@@ -0,0 +1,799 @@
import argparse
import json
import os
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np

from visual_rag import VisualEmbedder
from visual_rag.embedding.pooling import tile_level_mean_pooling
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
from visual_rag.retrieval import MultiVectorRetriever

from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_dataset_auto, paired_doc_id, paired_payload
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k

def _torch_dtype_to_str(dtype) -> str:
    if dtype is None:
        return "auto"
    s = str(dtype)
    return s.replace("torch.", "")


def _parse_torch_dtype(dtype_str: str):
    if dtype_str == "auto":
        return None
    import torch

    mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return mapping[dtype_str]

def _paired_collate(batch):
    # Identity collate: keep PIL images in a Python list instead of letting the
    # default collate_fn try to stack them into tensors.
    idxs = [b[0] for b in batch]
    images = [b[1] for b in batch]
    metas = [b[2] for b in batch]
    return idxs, images, metas

class _PairedHFDataset:
    """Map-style wrapper over a Hugging Face dataset split, loaded lazily so
    each DataLoader worker process opens its own handle."""

    def __init__(self, *, dataset_name: str, split: str, total_docs: int, image_col: str):
        self.dataset_name = dataset_name
        self.split = split
        self.total_docs = int(total_docs)
        self.image_col = image_col
        self._ds = None

    def __len__(self) -> int:
        return self.total_docs

    def _ensure_loaded(self):
        if self._ds is not None:
            return
        from datasets import load_dataset

        self._ds = load_dataset(self.dataset_name, split=self.split)

    def __getitem__(self, idx: int):
        self._ensure_loaded()
        row = self._ds[int(idx)]
        image = row[self.image_col]
        meta = {k: v for k, v in row.items() if k != self.image_col}
        return int(idx), image, meta

def _ensure_env(name: str) -> str:
    value = os.getenv(name)
    if not value:
        raise ValueError(f"Missing env var: {name}")
    return value


def _maybe_load_dotenv() -> None:
    try:
        from dotenv import load_dotenv
    except ImportError:
        return
    if Path(".env").exists():
        load_dotenv(".env")

def _index_corpus(
    *,
    dataset_name: str,
    collection_name: str,
    corpus: List[Any],
    embedder: VisualEmbedder,
    qdrant_url: str,
    qdrant_api_key: Optional[str],
    prefer_grpc: bool,
    qdrant_vector_dtype: str,
    recreate: bool,
    batch_size: int,
    upload_batch_size: int,
    upload_workers: int,
    upsert_wait: bool,
    indexing_threshold: int,
    full_scan_threshold: int,
) -> None:
    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=prefer_grpc,
        vector_datatype=qdrant_vector_dtype,
    )
    indexer.create_collection(
        force_recreate=recreate,
        indexing_threshold=indexing_threshold,
        full_scan_threshold=int(full_scan_threshold),
    )
    indexer.create_payload_indexes(
        fields=[
            {"field": "dataset", "type": "keyword"},
            {"field": "doc_id", "type": "keyword"},
            {"field": "torch_dtype", "type": "keyword"},
            {"field": "source", "type": "keyword"},
            {"field": "image_filename", "type": "keyword"},
            {"field": "page", "type": "keyword"},
            {"field": "source_doc_id", "type": "keyword"},
        ]
    )

    total_docs = len(corpus)
    embedded_docs = 0
    enqueued_docs = 0
    uploaded_docs = 0
    start_time = time.time()
    last_tick_time = start_time
    last_tick_docs = 0

    pbar = None
    try:
        from tqdm import tqdm

        pbar = tqdm(total=total_docs, desc="📦 Indexing", unit="doc")
    except ImportError:
        pass

    stop_event = threading.Event()

    def _upload(points: List[Dict[str, Any]]) -> int:
        return indexer.upload_batch(points, delay_between_batches=0.0, wait=upsert_wait, stop_event=stop_event)

    executor = None
    futures = []
    if upload_workers and upload_workers > 0:
        from concurrent.futures import ThreadPoolExecutor, wait as futures_wait, FIRST_EXCEPTION

        executor = ThreadPoolExecutor(max_workers=upload_workers)

    # _drain only touches futures_wait/FIRST_EXCEPTION when futures is non-empty,
    # which can only happen when the executor (and the import above) exists.
    def _drain(block: bool = False) -> None:
        nonlocal uploaded_docs
        nonlocal last_tick_time
        nonlocal last_tick_docs
        if not futures:
            return
        if block:
            done, _ = futures_wait(futures, return_when=FIRST_EXCEPTION)
        else:
            done, _ = futures_wait(futures, timeout=0, return_when=FIRST_EXCEPTION)
        for d in list(done):
            futures.remove(d)
            # .result() re-raises any exception from the upload thread.
            uploaded_docs += int(d.result() or 0)
        if pbar is not None:
            now = time.time()
            done_docs = int(pbar.n)
            elapsed = max(now - start_time, 1e-9)
            avg_s_per_doc = elapsed / max(done_docs, 1)
            delta_docs = done_docs - last_tick_docs
            delta_t = max(now - last_tick_time, 1e-9)
            last_s_per_doc = delta_t / max(delta_docs, 1)
            last_tick_time = now
            last_tick_docs = done_docs
            pbar.set_postfix(
                {
                    "avg_s/doc": f"{avg_s_per_doc:.2f}",
                    "last_s/doc": f"{last_s_per_doc:.2f}",
                    "buffer": len(points_buffer),
                    "enq": enqueued_docs,
                    "upl": uploaded_docs,
                    "pending": len(futures),
                }
            )

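    # Backpressure note: callers pass block=True once 2 * upload_workers futures
    # are in flight, at which point _drain waits for the pending uploads to
    # finish (or the first one to fail) before more batches are enqueued, so
    # embedding cannot run arbitrarily far ahead of Qdrant.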
    points_buffer: List[Dict[str, Any]] = []
    try:
        for start in range(0, len(corpus), batch_size):
            batch = corpus[start : start + batch_size]
            images = [d.image for d in batch]
            embeddings, token_infos = embedder.embed_images(
                images,
                batch_size=batch_size,
                return_token_info=True,
                show_progress=False,
            )
            embedded_docs += len(batch)
            if pbar is not None:
                pbar.update(len(batch))
                now = time.time()
                done_docs = int(pbar.n)
                elapsed = max(now - start_time, 1e-9)
                avg_s_per_doc = elapsed / max(done_docs, 1)
                delta_docs = done_docs - last_tick_docs
                delta_t = max(now - last_tick_time, 1e-9)
                last_s_per_doc = delta_t / max(delta_docs, 1)
                last_tick_time = now
                last_tick_docs = done_docs
                pbar.set_postfix(
                    {
                        "avg_s/doc": f"{avg_s_per_doc:.2f}",
                        "last_s/doc": f"{last_s_per_doc:.2f}",
                        "buffer": len(points_buffer),
                        "enq": enqueued_docs,
                        "upl": uploaded_docs,
                        "pending": len(futures),
                    }
                )

            for doc, emb, token_info in zip(batch, embeddings, token_infos):
                if doc.image is None:
                    raise ValueError("CorpusDoc.image is None. For paired datasets, use _index_paired_dataset().")
                emb_np = emb.cpu().float().numpy() if hasattr(emb, "cpu") else np.array(emb, dtype=np.float32)
                visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
                visual_embedding = emb_np[visual_indices].astype(np.float32)

                n_rows = token_info.get("n_rows")
                n_cols = token_info.get("n_cols")
                if n_rows and n_cols:
                    num_tiles = int(n_rows) * int(n_cols) + 1
                else:
                    num_tiles = 13

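                # Shape sketch (illustrative numbers, not from the source): with a
                # 3x4 tile grid, num_tiles = 3 * 4 + 1 = 13 and patches_per_tile = 64,
                # so visual_embedding has 13 * 64 = 832 rows; tile-level mean pooling
                # averages each run of 64 patch vectors into one row, giving
                # tile_pooled shape (13, dim) and global_pooled shape (dim,) below.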
                tile_pooled = tile_level_mean_pooling(visual_embedding, num_tiles=num_tiles, patches_per_tile=64)
                global_pooled = tile_pooled.mean(axis=0).astype(np.float32)

                payload = {
                    "dataset": dataset_name,
                    "doc_id": doc.doc_id,
                    "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                    **(doc.payload or {}),
                }

                points_buffer.append(
                    {
                        "id": doc.doc_id,
                        "visual_embedding": visual_embedding,
                        "tile_pooled_embedding": tile_pooled,
                        "global_pooled_embedding": global_pooled,
                        "metadata": payload,
                    }
                )

            if len(points_buffer) >= upload_batch_size:
                chunk = points_buffer
                points_buffer = []
                enqueued_docs += len(chunk)
                if executor is None:
                    uploaded_docs += int(_upload(chunk) or 0)
                else:
                    futures.append(executor.submit(_upload, chunk))
                    _drain(block=len(futures) >= upload_workers * 2)
                if pbar is not None:
                    pbar.set_postfix(
                        {
                            "avg_s/doc": f"{avg_s_per_doc:.2f}",
                            "last_s/doc": f"{last_s_per_doc:.2f}",
                            "buffer": len(points_buffer),
                            "enq": enqueued_docs,
                            "upl": uploaded_docs,
                            "pending": len(futures),
                        }
                    )
            if executor is not None:
                _drain(block=False)
    except KeyboardInterrupt:
        stop_event.set()
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        raise

    if points_buffer:
        enqueued_docs += len(points_buffer)
        if executor is None:
            uploaded_docs += int(_upload(points_buffer) or 0)
        else:
            futures.append(executor.submit(_upload, points_buffer))

    if executor is not None:
        _drain(block=True)
        executor.shutdown(wait=True)

    if pbar is not None:
        pbar.set_postfix(
            {
                "avg_s/doc": f"{(max(time.time() - start_time, 1e-9) / max(int(pbar.n), 1)):.2f}",
                "last_s/doc": "n/a",
                "buffer": 0,
                "enq": enqueued_docs,
                "upl": uploaded_docs,
                "pending": 0,
            }
        )
        pbar.close()

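# _index_corpus embeds an already-materialized corpus (datasets with a separate
# corpus/query split); _index_paired_dataset below streams a "paired" dataset,
# where each row carries a page image alongside its metadata (see the paired_*
# helpers), directly from the Hugging Face hub, optionally through a torch
# DataLoader so image decoding in worker processes overlaps with embedding.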
def _index_paired_dataset(
    *,
    dataset_name: str,
    collection_name: str,
    total_docs: int,
    embedder: VisualEmbedder,
    qdrant_url: str,
    qdrant_api_key: Optional[str],
    prefer_grpc: bool,
    qdrant_vector_dtype: str,
    recreate: bool,
    batch_size: int,
    upload_batch_size: int,
    upload_workers: int,
    upsert_wait: bool,
    loader_workers: int,
    prefetch_factor: int,
    persistent_workers: bool,
    pin_memory: bool,
    use_dataloader: bool,
    indexing_threshold: int,
    full_scan_threshold: int,
) -> None:
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("datasets is required. Install with: pip install datasets") from e

    try:
        import torch
        from torch.utils.data import DataLoader
    except ImportError as e:
        raise ImportError("torch is required. Install with: pip install visual-rag-toolkit[embedding]") from e

    # Probe the split once to find the image column, then drop the handle.
    ds0 = load_dataset(dataset_name, split="test")
    cols = set(ds0.column_names)
    image_col = "image" if "image" in cols else "page_image"
    del ds0

    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=prefer_grpc,
        vector_datatype=qdrant_vector_dtype,
    )
    indexer.create_collection(
        force_recreate=recreate,
        indexing_threshold=indexing_threshold,
        full_scan_threshold=int(full_scan_threshold),
    )
    indexer.create_payload_indexes(
        fields=[
            {"field": "dataset", "type": "keyword"},
            {"field": "doc_id", "type": "keyword"},
            {"field": "torch_dtype", "type": "keyword"},
            {"field": "source", "type": "keyword"},
            {"field": "image_filename", "type": "keyword"},
            {"field": "page", "type": "keyword"},
            {"field": "source_doc_id", "type": "keyword"},
        ]
    )

    enqueued_docs = 0
    uploaded_docs = 0
    start_time = time.time()
    last_tick_time = start_time
    last_tick_docs = 0

    pbar = None
    try:
        from tqdm import tqdm

        pbar = tqdm(total=total_docs, desc="📦 Indexing", unit="doc")
    except ImportError:
        pass

    # Same upload/backpressure machinery as in _index_corpus above.
    stop_event = threading.Event()

    def _upload(points: List[Dict[str, Any]]) -> int:
        return indexer.upload_batch(points, delay_between_batches=0.0, wait=upsert_wait, stop_event=stop_event)

    executor = None
    futures = []
    if upload_workers and upload_workers > 0:
        from concurrent.futures import ThreadPoolExecutor, wait as futures_wait, FIRST_EXCEPTION

        executor = ThreadPoolExecutor(max_workers=upload_workers)

    def _drain(block: bool = False) -> None:
        nonlocal uploaded_docs
        nonlocal last_tick_time
        nonlocal last_tick_docs
        if not futures:
            return
        if block:
            done, _ = futures_wait(futures, return_when=FIRST_EXCEPTION)
        else:
            done, _ = futures_wait(futures, timeout=0, return_when=FIRST_EXCEPTION)
        for d in list(done):
            futures.remove(d)
            uploaded_docs += int(d.result() or 0)
        if pbar is not None:
            now = time.time()
            done_docs = int(pbar.n)
            elapsed = max(now - start_time, 1e-9)
            avg_s_per_doc = elapsed / max(done_docs, 1)
            delta_docs = done_docs - last_tick_docs
            delta_t = max(now - last_tick_time, 1e-9)
            last_s_per_doc = delta_t / max(delta_docs, 1)
            last_tick_time = now
            last_tick_docs = done_docs
            pbar.set_postfix(
                {
                    "avg_s/doc": f"{avg_s_per_doc:.2f}",
                    "last_s/doc": f"{last_s_per_doc:.2f}",
                    "buffer": len(points_buffer),
                    "enq": enqueued_docs,
                    "upl": uploaded_docs,
                    "pending": len(futures),
                }
            )

    points_buffer: List[Dict[str, Any]] = []
    try:
        if use_dataloader or (loader_workers and loader_workers > 0):
            dl_kwargs = {"batch_size": batch_size, "shuffle": False, "collate_fn": _paired_collate}
            if loader_workers and loader_workers > 0:
                dl_kwargs["num_workers"] = int(loader_workers)
                dl_kwargs["prefetch_factor"] = int(prefetch_factor)
                dl_kwargs["persistent_workers"] = bool(persistent_workers)
                dl_kwargs["pin_memory"] = bool(pin_memory and torch.cuda.is_available())

            data_loader = DataLoader(
                _PairedHFDataset(dataset_name=dataset_name, split="test", total_docs=total_docs, image_col=image_col),
                **dl_kwargs,
            )
            iterable = data_loader
        else:
            ds = load_dataset(dataset_name, split="test")

            def _iter_batches():
                for start in range(0, total_docs, batch_size):
                    batch = ds[start : start + batch_size]
                    images = batch[image_col]
                    metas = [{k: batch[k][i] for k in batch.keys() if k != image_col} for i in range(len(images))]
                    idxs = list(range(start, start + len(images)))
                    yield idxs, images, metas

            iterable = _iter_batches()

        for idxs, images, metas in iterable:
            embeddings, token_infos = embedder.embed_images(
                images,
                batch_size=batch_size,
                return_token_info=True,
                show_progress=False,
            )
            if pbar is not None:
                pbar.update(len(images))
                now = time.time()
                done_docs = int(pbar.n)
                elapsed = max(now - start_time, 1e-9)
                avg_s_per_doc = elapsed / max(done_docs, 1)
                delta_docs = done_docs - last_tick_docs
                delta_t = max(now - last_tick_time, 1e-9)
                last_s_per_doc = delta_t / max(delta_docs, 1)
                last_tick_time = now
                last_tick_docs = done_docs
                pbar.set_postfix(
                    {
                        "avg_s/doc": f"{avg_s_per_doc:.2f}",
                        "last_s/doc": f"{last_s_per_doc:.2f}",
                        "buffer": len(points_buffer),
                        "enq": enqueued_docs,
                        "upl": uploaded_docs,
                        "pending": len(futures),
                    }
                )

            for idx, meta, emb, token_info in zip(idxs, metas, embeddings, token_infos):
                doc_id = paired_doc_id(meta, int(idx))
                payload = {
                    "dataset": dataset_name,
                    "doc_id": doc_id,
                    "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                    **paired_payload(meta, int(idx)),
                }

                emb_np = emb.cpu().float().numpy() if hasattr(emb, "cpu") else np.array(emb, dtype=np.float32)
                visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
                visual_embedding = emb_np[visual_indices].astype(np.float32)

                n_rows = token_info.get("n_rows")
                n_cols = token_info.get("n_cols")
                num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13

                tile_pooled = tile_level_mean_pooling(visual_embedding, num_tiles=num_tiles, patches_per_tile=64)
                global_pooled = tile_pooled.mean(axis=0).astype(np.float32)

                points_buffer.append(
                    {
                        "id": doc_id,
                        "visual_embedding": visual_embedding,
                        "tile_pooled_embedding": tile_pooled,
                        "global_pooled_embedding": global_pooled,
                        "metadata": payload,
                    }
                )

            if len(points_buffer) >= upload_batch_size:
                chunk = points_buffer
                points_buffer = []
                enqueued_docs += len(chunk)
                if executor is None:
                    uploaded_docs += int(_upload(chunk) or 0)
                else:
                    futures.append(executor.submit(_upload, chunk))
                    _drain(block=len(futures) >= upload_workers * 2)
            if executor is not None:
                _drain(block=False)
    except KeyboardInterrupt:
        stop_event.set()
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        raise

    if points_buffer:
        enqueued_docs += len(points_buffer)
        if executor is None:
            uploaded_docs += int(_upload(points_buffer) or 0)
        else:
            futures.append(executor.submit(_upload, points_buffer))

    if executor is not None:
        _drain(block=True)
        executor.shutdown(wait=True)

    if pbar is not None:
        pbar.set_postfix(
            {
                "avg_s/doc": f"{(max(time.time() - start_time, 1e-9) / max(int(pbar.n), 1)):.2f}",
                "last_s/doc": "n/a",
                "buffer": 0,
                "enq": enqueued_docs,
                "upl": uploaded_docs,
                "pending": 0,
            }
        )
        pbar.close()

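# Worked example for the binary-relevance metrics below (illustrative, assuming
# the standard formulations in metrics.py): if a query's single relevant page is
# ranked 3rd, then mrr@10 contributes 1/3, recall@10 contributes 1, and ndcg@10
# contributes 1/log2(3 + 1) = 0.5 for that query.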
def _evaluate(
    *,
    queries: List[Any],
    qrels: Dict[str, Dict[str, int]],
    retriever: MultiVectorRetriever,
    top_k: int,
    prefetch_k: int,
    mode: str,
    stage1_mode: str,
) -> Dict[str, float]:
    ndcg10: List[float] = []
    mrr10: List[float] = []
    recall10: List[float] = []
    recall5: List[float] = []
    latencies_ms: List[float] = []

    for q in queries:
        start = time.time()
        results = retriever.search(
            query=q.text,
            top_k=top_k,
            mode=mode,
            prefetch_k=prefetch_k,
            stage1_mode=stage1_mode,
        )
        latencies_ms.append((time.time() - start) * 1000.0)

        ranking = [str(r["id"]) for r in results]
        rels = qrels.get(q.query_id, {})

        ndcg10.append(ndcg_at_k(ranking, rels, k=10))
        mrr10.append(mrr_at_k(ranking, rels, k=10))
        recall5.append(recall_at_k(ranking, rels, k=5))
        recall10.append(recall_at_k(ranking, rels, k=10))

    return {
        "ndcg@10": float(np.mean(ndcg10)),
        "mrr@10": float(np.mean(mrr10)),
        "recall@5": float(np.mean(recall5)),
        "recall@10": float(np.mean(recall10)),
        "avg_latency_ms": float(np.mean(latencies_ms)),
        "p95_latency_ms": float(np.percentile(latencies_ms, 95)),
    }

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="vidore/tatdqa_test")
    parser.add_argument("--collection", type=str, default="vidore_tatdqa_test")
    parser.add_argument("--model", type=str, default="vidore/colSmol-500M")
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Torch dtype for model weights (default: auto; CUDA -> bfloat16, else float32).",
    )
    parser.add_argument(
        "--qdrant-vector-dtype",
        type=str,
        default="float16",
        choices=["float16", "float32"],
        help="Datatype for vectors stored in Qdrant (default: float16).",
    )
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--upload-batch-size", type=int, default=None)
    parser.add_argument("--upload-workers", type=int, default=0)
    wait_group = parser.add_mutually_exclusive_group()
    wait_group.add_argument(
        "--upsert-wait",
        action="store_true",
        help="Wait for Qdrant upserts to complete before continuing (default: false).",
    )
    wait_group.add_argument(
        "--no-upsert-wait",
        action="store_true",
        help="Deprecated (default is already no-wait). Kept for backwards compatibility.",
    )
    parser.add_argument("--loader-workers", type=int, default=0)
    parser.add_argument("--prefetch-factor", type=int, default=2)
    parser.add_argument("--persistent-workers", action="store_true")
    parser.add_argument("--pin-memory", action="store_true")
    parser.add_argument(
        "--use-dataloader",
        action="store_true",
        help="Use torch DataLoader even with --loader-workers 0 (default: false).",
    )
    grpc_group = parser.add_mutually_exclusive_group()
    grpc_group.add_argument("--prefer-grpc", dest="prefer_grpc", action="store_true", default=True)
    grpc_group.add_argument("--no-prefer-grpc", dest="prefer_grpc", action="store_false")
    parser.add_argument("--index", action="store_true", help="Index corpus into Qdrant before evaluating")
    parser.add_argument("--recreate", action="store_true", help="Delete and recreate the collection (implies --index)")
    parser.add_argument(
        "--indexing-threshold",
        type=int,
        default=0,
        help="Qdrant optimizer indexing threshold (0 = always build indexes).",
    )
    parser.add_argument(
        "--full-scan-threshold",
        type=int,
        default=0,
        help="Qdrant HNSW full_scan_threshold (0 = always use HNSW).",
    )
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--prefetch-k", type=int, default=200)
    parser.add_argument(
        "--mode",
        type=str,
        default="single_full",
        choices=["single_full", "single_tiles", "single_global", "two_stage"],
    )
    parser.add_argument(
        "--stage1-mode",
        type=str,
        default="tokens_vs_tiles",
        choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
    )
    parser.add_argument("--output", type=str, default="results/qdrant_vidore_tatdqa_test.json")

    args = parser.parse_args()
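    # Retrieval modes (a hedged reading of visual_rag/retrieval, not stated in
    # this file): the single_* modes search one vector field directly (full
    # multi-vector, tile-pooled, or global), while two_stage presumably
    # prefetches --prefetch-k candidates with the cheaper --stage1-mode scorer
    # and reranks them with full multi-vector similarity.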

    _maybe_load_dotenv()

    qdrant_url = _ensure_env("QDRANT_URL")
    qdrant_api_key = os.getenv("QDRANT_API_KEY")
    upload_batch_size = args.upload_batch_size or args.batch_size
    upsert_wait = bool(args.upsert_wait)

    if upsert_wait:
        print("Qdrant upserts wait for completion (wait=True).")
    else:
        print("Qdrant upserts are async (wait=False).")

    corpus, queries, qrels, protocol = load_vidore_dataset_auto(args.dataset)

    embedder = VisualEmbedder(
        model_name=args.model,
        batch_size=args.batch_size,
        torch_dtype=_parse_torch_dtype(args.torch_dtype),
    )

    if args.recreate:
        args.index = True

    if args.index:
        if protocol == "paired":
            _index_paired_dataset(
                dataset_name=args.dataset,
                collection_name=args.collection,
                total_docs=len(corpus),
                embedder=embedder,
                qdrant_url=qdrant_url,
                qdrant_api_key=qdrant_api_key,
                prefer_grpc=args.prefer_grpc,
                qdrant_vector_dtype=args.qdrant_vector_dtype,
                recreate=args.recreate,
                batch_size=args.batch_size,
                upload_batch_size=upload_batch_size,
                upload_workers=args.upload_workers,
                upsert_wait=upsert_wait,
                loader_workers=args.loader_workers,
                prefetch_factor=args.prefetch_factor,
                persistent_workers=args.persistent_workers,
                pin_memory=args.pin_memory,
                use_dataloader=args.use_dataloader,
                indexing_threshold=args.indexing_threshold,
                full_scan_threshold=args.full_scan_threshold,
            )
        else:
            _index_corpus(
                dataset_name=args.dataset,
                collection_name=args.collection,
                corpus=corpus,
                embedder=embedder,
                qdrant_url=qdrant_url,
                qdrant_api_key=qdrant_api_key,
                prefer_grpc=args.prefer_grpc,
                qdrant_vector_dtype=args.qdrant_vector_dtype,
                recreate=args.recreate,
                batch_size=args.batch_size,
                upload_batch_size=upload_batch_size,
                upload_workers=args.upload_workers,
                upsert_wait=upsert_wait,
                indexing_threshold=args.indexing_threshold,
                full_scan_threshold=args.full_scan_threshold,
            )

    retriever = MultiVectorRetriever(
        collection_name=args.collection,
        embedder=embedder,
        qdrant_url=qdrant_url,
        qdrant_api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
    )

    metrics = _evaluate(
        queries=queries,
        qrels=qrels,
        retriever=retriever,
        top_k=args.top_k,
        prefetch_k=args.prefetch_k,
        mode=args.mode,
        stage1_mode=args.stage1_mode,
    )

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(
            {
                "dataset": args.dataset,
                "protocol": protocol,
                "collection": args.collection,
                "model": args.model,
                "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                "qdrant_vector_dtype": args.qdrant_vector_dtype,
                "mode": args.mode,
                "stage1_mode": args.stage1_mode if args.mode == "two_stage" else None,
                "prefetch_k": args.prefetch_k if args.mode == "two_stage" else None,
                "metrics": metrics,
            },
            f,
            indent=2,
        )

    print(json.dumps(metrics, indent=2))


if __name__ == "__main__":
    main()
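# Example invocation (illustrative; assumes QDRANT_URL, and optionally
# QDRANT_API_KEY, are set in the environment or in .env, and that the command
# is run from the repository root so the benchmarks package is importable):
#
#   python -m benchmarks.vidore_tatdqa_test.run_qdrant \
#       --dataset vidore/tatdqa_test --collection vidore_tatdqa_test \
#       --index --recreate --batch-size 4 --upload-workers 2 \
#       --mode two_stage --stage1-mode tokens_vs_tiles --prefetch-k 200 \
#       --output results/qdrant_vidore_tatdqa_test.json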