visual-rag-toolkit 0.1.1__py3-none-any.whl
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
benchmarks/vidore_tatdqa_test/run_qdrant.py
@@ -0,0 +1,799 @@
import argparse
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np

from visual_rag import VisualEmbedder
from visual_rag.embedding.pooling import tile_level_mean_pooling
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
from visual_rag.retrieval import MultiVectorRetriever

from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_dataset_auto, paired_doc_id, paired_payload
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k

def _torch_dtype_to_str(dtype) -> str:
    if dtype is None:
        return "auto"
    s = str(dtype)
    return s.replace("torch.", "")

def _parse_torch_dtype(dtype_str: str):
    if dtype_str == "auto":
        return None
    import torch

    mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return mapping[dtype_str]

def _paired_collate(batch):
    # Keep PIL images as a plain Python list; torch's default collate would try to stack them.
    idxs = [b[0] for b in batch]
    images = [b[1] for b in batch]
    metas = [b[2] for b in batch]
    return idxs, images, metas

class _PairedHFDataset:
    """Map-style view over a HF image dataset, loaded lazily so each DataLoader worker opens its own handle."""

    def __init__(self, *, dataset_name: str, split: str, total_docs: int, image_col: str):
        self.dataset_name = dataset_name
        self.split = split
        self.total_docs = int(total_docs)
        self.image_col = image_col
        self._ds = None

    def __len__(self) -> int:
        return self.total_docs

    def _ensure_loaded(self):
        if self._ds is not None:
            return
        from datasets import load_dataset

        self._ds = load_dataset(self.dataset_name, split=self.split)

    def __getitem__(self, idx: int):
        self._ensure_loaded()
        row = self._ds[int(idx)]
        image = row[self.image_col]
        meta = {k: v for k, v in row.items() if k != self.image_col}
        return int(idx), image, meta
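
The class above pairs with _paired_collate: the custom collate keeps PIL images as a plain list, and the lazy _ensure_loaded() means each DataLoader worker process opens its own HF dataset handle instead of pickling one across processes. A minimal sketch of how _index_paired_dataset (further down) wires these together; the dataset name and batch size below are this file's CLI defaults, while total_docs and num_workers are placeholder values:

# --- illustration only, not part of run_qdrant.py ---
from torch.utils.data import DataLoader

loader = DataLoader(
    _PairedHFDataset(dataset_name="vidore/tatdqa_test", split="test", total_docs=100, image_col="image"),
    batch_size=4,
    shuffle=False,
    collate_fn=_paired_collate,  # keep PIL images as a list; no tensor stacking
    num_workers=2,               # each worker lazily loads its own dataset handle
)
for idxs, images, metas in loader:
    pass  # embed `images`, build Qdrant points from `metas`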

def _ensure_env(name: str) -> str:
    value = os.getenv(name)
    if not value:
        raise ValueError(f"Missing env var: {name}")
    return value

def _maybe_load_dotenv() -> None:
    try:
        from dotenv import load_dotenv
    except ImportError:
        return
    if Path(".env").exists():
        load_dotenv(".env")

def _index_corpus(
    *,
    dataset_name: str,
    collection_name: str,
    corpus: List[Any],
    embedder: VisualEmbedder,
    qdrant_url: str,
    qdrant_api_key: Optional[str],
    prefer_grpc: bool,
    qdrant_vector_dtype: str,
    recreate: bool,
    batch_size: int,
    upload_batch_size: int,
    upload_workers: int,
    upsert_wait: bool,
    indexing_threshold: int,
    full_scan_threshold: int,
) -> None:
    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=prefer_grpc,
        vector_datatype=qdrant_vector_dtype,
    )
    indexer.create_collection(
        force_recreate=recreate,
        indexing_threshold=indexing_threshold,
        full_scan_threshold=int(full_scan_threshold),
    )
    indexer.create_payload_indexes(
        fields=[
            {"field": "dataset", "type": "keyword"},
            {"field": "doc_id", "type": "keyword"},
            {"field": "torch_dtype", "type": "keyword"},
            {"field": "source", "type": "keyword"},
            {"field": "image_filename", "type": "keyword"},
            {"field": "page", "type": "keyword"},
            {"field": "source_doc_id", "type": "keyword"},
        ]
    )

    total_docs = len(corpus)
    embedded_docs = 0
    enqueued_docs = 0
    uploaded_docs = 0
    start_time = time.time()
    last_tick_time = start_time
    last_tick_docs = 0

    pbar = None
    try:
        from tqdm import tqdm

        pbar = tqdm(total=total_docs, desc="📦 Indexing", unit="doc")
    except ImportError:
        pass

    def _upload(points: List[Dict[str, Any]]) -> int:
        return indexer.upload_batch(points, delay_between_batches=0.0, wait=upsert_wait, stop_event=stop_event)

    executor = None
    futures = []
    import threading

    stop_event = threading.Event()
    if upload_workers and upload_workers > 0:
        from concurrent.futures import ThreadPoolExecutor, wait as futures_wait, FIRST_EXCEPTION

        executor = ThreadPoolExecutor(max_workers=upload_workers)

    # Harvest finished upload futures (block=True waits for all) and refresh progress stats.
    def _drain(block: bool = False) -> None:
        nonlocal uploaded_docs
        nonlocal last_tick_time
        nonlocal last_tick_docs
        if not futures:
            return
        if block:
            done, _ = futures_wait(futures, return_when=FIRST_EXCEPTION)
        else:
            done, _ = futures_wait(futures, timeout=0, return_when=FIRST_EXCEPTION)
        for d in list(done):
            futures.remove(d)
            uploaded_docs += int(d.result() or 0)
        if pbar is not None:
            now = time.time()
            done_docs = int(pbar.n)
            elapsed = max(now - start_time, 1e-9)
            avg_s_per_doc = elapsed / max(done_docs, 1)
            delta_docs = done_docs - last_tick_docs
            delta_t = max(now - last_tick_time, 1e-9)
            last_s_per_doc = delta_t / max(delta_docs, 1)
            last_tick_time = now
            last_tick_docs = done_docs
            pbar.set_postfix(
                {
                    "avg_s/doc": f"{avg_s_per_doc:.2f}",
                    "last_s/doc": f"{last_s_per_doc:.2f}",
                    "buffer": len(points_buffer),
                    "enq": enqueued_docs,
                    "upl": uploaded_docs,
                    "pending": len(futures),
                }
            )

    # Embed batch_size pages at a time; flush points to Qdrant whenever upload_batch_size accumulate.
    points_buffer: List[Dict[str, Any]] = []
    try:
        for start in range(0, len(corpus), batch_size):
            batch = corpus[start : start + batch_size]
            images = [d.image for d in batch]
            embeddings, token_infos = embedder.embed_images(
                images,
                batch_size=batch_size,
                return_token_info=True,
                show_progress=False,
            )
            embedded_docs += len(batch)
            if pbar is not None:
                pbar.update(len(batch))
                now = time.time()
                done_docs = int(pbar.n)
                elapsed = max(now - start_time, 1e-9)
                avg_s_per_doc = elapsed / max(done_docs, 1)
                delta_docs = done_docs - last_tick_docs
                delta_t = max(now - last_tick_time, 1e-9)
                last_s_per_doc = delta_t / max(delta_docs, 1)
                last_tick_time = now
                last_tick_docs = done_docs
                pbar.set_postfix(
                    {
                        "avg_s/doc": f"{avg_s_per_doc:.2f}",
                        "last_s/doc": f"{last_s_per_doc:.2f}",
                        "buffer": len(points_buffer),
                        "enq": enqueued_docs,
                        "upl": uploaded_docs,
                        "pending": len(futures),
                    }
                )

            for doc, emb, token_info in zip(batch, embeddings, token_infos):
                if doc.image is None:
                    raise ValueError("CorpusDoc.image is None. For paired datasets, use _index_paired_dataset().")
                emb_np = emb.cpu().float().numpy() if hasattr(emb, "cpu") else np.array(emb, dtype=np.float32)
                visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
                visual_embedding = emb_np[visual_indices].astype(np.float32)

                n_rows = token_info.get("n_rows")
                n_cols = token_info.get("n_cols")
                if n_rows and n_cols:
                    num_tiles = int(n_rows) * int(n_cols) + 1
                else:
                    num_tiles = 13

                tile_pooled = tile_level_mean_pooling(visual_embedding, num_tiles=num_tiles, patches_per_tile=64)
                global_pooled = tile_pooled.mean(axis=0).astype(np.float32)

                payload = {
                    "dataset": dataset_name,
                    "doc_id": doc.doc_id,
                    "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                    **(doc.payload or {}),
                }

                points_buffer.append(
                    {
                        "id": doc.doc_id,
                        "visual_embedding": visual_embedding,
                        "tile_pooled_embedding": tile_pooled,
                        "global_pooled_embedding": global_pooled,
                        "metadata": payload,
                    }
                )

                if len(points_buffer) >= upload_batch_size:
                    chunk = points_buffer
                    points_buffer = []
                    enqueued_docs += len(chunk)
                    if executor is None:
                        uploaded_docs += int(_upload(chunk) or 0)
                    else:
                        futures.append(executor.submit(_upload, chunk))
                        _drain(block=len(futures) >= upload_workers * 2)
                    if pbar is not None:
                        pbar.set_postfix(
                            {
                                "avg_s/doc": f"{avg_s_per_doc:.2f}",
                                "last_s/doc": f"{last_s_per_doc:.2f}",
                                "buffer": len(points_buffer),
                                "enq": enqueued_docs,
                                "upl": uploaded_docs,
                                "pending": len(futures),
                            }
                        )
            if executor is not None:
                _drain(block=False)
    except KeyboardInterrupt:
        stop_event.set()
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        raise

    if points_buffer:
        enqueued_docs += len(points_buffer)
        if executor is None:
            uploaded_docs += int(_upload(points_buffer) or 0)
        else:
            futures.append(executor.submit(_upload, points_buffer))

    if executor is not None:
        _drain(block=True)
        executor.shutdown(wait=True)

    if pbar is not None:
        pbar.set_postfix(
            {
                "avg_s/doc": f"{(max(time.time() - start_time, 1e-9) / max(int(pbar.n), 1)):.2f}",
                "last_s/doc": "n/a",
                "buffer": 0,
                "enq": enqueued_docs,
                "upl": uploaded_docs,
                "pending": 0,
            }
        )
        pbar.close()
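
For intuition about the pooling step above: each embedded page is a (tokens, dim) matrix; its visual tokens are grouped into num_tiles = n_rows * n_cols + 1 tiles (the +1 presumably a whole-page/global tile) of 64 patches each, mean-pooled once per tile, then once more into a single page vector. A shape-level sketch, assuming tile_level_mean_pooling averages each consecutive 64-patch block (an assumption about its internals; this file only fixes the shapes):

# --- illustration only, not part of run_qdrant.py ---
import numpy as np

dim = 128
num_tiles = 3 * 4 + 1  # hypothetical grid: n_rows=3, n_cols=4, plus one global tile
visual_embedding = np.random.rand(num_tiles * 64, dim).astype(np.float32)

# Assumed semantics of tile_level_mean_pooling: mean over each 64-patch tile.
tile_pooled = visual_embedding.reshape(num_tiles, 64, dim).mean(axis=1)  # (13, dim)
global_pooled = tile_pooled.mean(axis=0)                                 # (dim,)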

def _index_paired_dataset(
    *,
    dataset_name: str,
    collection_name: str,
    total_docs: int,
    embedder: VisualEmbedder,
    qdrant_url: str,
    qdrant_api_key: Optional[str],
    prefer_grpc: bool,
    qdrant_vector_dtype: str,
    recreate: bool,
    batch_size: int,
    upload_batch_size: int,
    upload_workers: int,
    upsert_wait: bool,
    loader_workers: int,
    prefetch_factor: int,
    persistent_workers: bool,
    pin_memory: bool,
    use_dataloader: bool,
    indexing_threshold: int,
    full_scan_threshold: int,
) -> None:
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("datasets is required. Install with: pip install datasets") from e

    try:
        import torch
        from torch.utils.data import DataLoader
    except ImportError as e:
        raise ImportError("torch is required. Install with: pip install visual-rag-toolkit[embedding]") from e

    # Peek at the schema once to find the image column, then drop the handle.
    ds0 = load_dataset(dataset_name, split="test")
    cols = set(ds0.column_names)
    image_col = "image" if "image" in cols else "page_image"
    del ds0

    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=prefer_grpc,
        vector_datatype=qdrant_vector_dtype,
    )
    indexer.create_collection(
        force_recreate=recreate,
        indexing_threshold=indexing_threshold,
        full_scan_threshold=int(full_scan_threshold),
    )
    indexer.create_payload_indexes(
        fields=[
            {"field": "dataset", "type": "keyword"},
            {"field": "doc_id", "type": "keyword"},
            {"field": "torch_dtype", "type": "keyword"},
            {"field": "source", "type": "keyword"},
            {"field": "image_filename", "type": "keyword"},
            {"field": "page", "type": "keyword"},
            {"field": "source_doc_id", "type": "keyword"},
        ]
    )

    enqueued_docs = 0
    uploaded_docs = 0
    start_time = time.time()
    last_tick_time = start_time
    last_tick_docs = 0

    pbar = None
    try:
        from tqdm import tqdm

        pbar = tqdm(total=total_docs, desc="📦 Indexing", unit="doc")
    except ImportError:
        pass

    def _upload(points: List[Dict[str, Any]]) -> int:
        return indexer.upload_batch(points, delay_between_batches=0.0, wait=upsert_wait, stop_event=stop_event)

    executor = None
    futures = []
    import threading

    stop_event = threading.Event()
    if upload_workers and upload_workers > 0:
        from concurrent.futures import ThreadPoolExecutor, wait as futures_wait, FIRST_EXCEPTION

        executor = ThreadPoolExecutor(max_workers=upload_workers)

    # Harvest finished upload futures (block=True waits for all) and refresh progress stats.
    def _drain(block: bool = False) -> None:
        nonlocal uploaded_docs
        nonlocal last_tick_time
        nonlocal last_tick_docs
        if not futures:
            return
        if block:
            done, _ = futures_wait(futures, return_when=FIRST_EXCEPTION)
        else:
            done, _ = futures_wait(futures, timeout=0, return_when=FIRST_EXCEPTION)
        for d in list(done):
            futures.remove(d)
            uploaded_docs += int(d.result() or 0)
        if pbar is not None:
            now = time.time()
            done_docs = int(pbar.n)
            elapsed = max(now - start_time, 1e-9)
            avg_s_per_doc = elapsed / max(done_docs, 1)
            delta_docs = done_docs - last_tick_docs
            delta_t = max(now - last_tick_time, 1e-9)
            last_s_per_doc = delta_t / max(delta_docs, 1)
            last_tick_time = now
            last_tick_docs = done_docs
            pbar.set_postfix(
                {
                    "avg_s/doc": f"{avg_s_per_doc:.2f}",
                    "last_s/doc": f"{last_s_per_doc:.2f}",
                    "buffer": len(points_buffer),
                    "enq": enqueued_docs,
                    "upl": uploaded_docs,
                    "pending": len(futures),
                }
            )

    points_buffer: List[Dict[str, Any]] = []
    try:
        # With loader workers (or --use-dataloader), decode images in background workers;
        # otherwise slice the HF dataset directly in-process.
        if use_dataloader or (loader_workers and loader_workers > 0):
            dl_kwargs = {"batch_size": batch_size, "shuffle": False, "collate_fn": _paired_collate}
            if loader_workers and loader_workers > 0:
                dl_kwargs["num_workers"] = int(loader_workers)
                dl_kwargs["prefetch_factor"] = int(prefetch_factor)
                dl_kwargs["persistent_workers"] = bool(persistent_workers)
                dl_kwargs["pin_memory"] = bool(pin_memory and torch.cuda.is_available())

            data_loader = DataLoader(
                _PairedHFDataset(dataset_name=dataset_name, split="test", total_docs=total_docs, image_col=image_col),
                **dl_kwargs,
            )
            iterable = ((idxs, images, metas) for (idxs, images, metas) in data_loader)
        else:
            ds = load_dataset(dataset_name, split="test")

            def _iter_batches():
                for start in range(0, total_docs, batch_size):
                    batch = ds[start : start + batch_size]
                    images = batch[image_col]
                    metas = [{k: batch[k][i] for k in batch.keys() if k != image_col} for i in range(len(images))]
                    idxs = list(range(start, start + len(images)))
                    yield idxs, images, metas

            iterable = _iter_batches()

        for idxs, images, metas in iterable:
            embeddings, token_infos = embedder.embed_images(
                images,
                batch_size=batch_size,
                return_token_info=True,
                show_progress=False,
            )
            if pbar is not None:
                pbar.update(len(images))
                now = time.time()
                done_docs = int(pbar.n)
                elapsed = max(now - start_time, 1e-9)
                avg_s_per_doc = elapsed / max(done_docs, 1)
                delta_docs = done_docs - last_tick_docs
                delta_t = max(now - last_tick_time, 1e-9)
                last_s_per_doc = delta_t / max(delta_docs, 1)
                last_tick_time = now
                last_tick_docs = done_docs
                pbar.set_postfix(
                    {
                        "avg_s/doc": f"{avg_s_per_doc:.2f}",
                        "last_s/doc": f"{last_s_per_doc:.2f}",
                        "buffer": len(points_buffer),
                        "enq": enqueued_docs,
                        "upl": uploaded_docs,
                        "pending": len(futures),
                    }
                )

            for idx, meta, emb, token_info in zip(idxs, metas, embeddings, token_infos):
                doc_id = paired_doc_id(meta, int(idx))
                payload = {
                    "dataset": dataset_name,
                    "doc_id": doc_id,
                    "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                    **paired_payload(meta, int(idx)),
                }

                emb_np = emb.cpu().float().numpy() if hasattr(emb, "cpu") else np.array(emb, dtype=np.float32)
                visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
                visual_embedding = emb_np[visual_indices].astype(np.float32)

                n_rows = token_info.get("n_rows")
                n_cols = token_info.get("n_cols")
                num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13

                tile_pooled = tile_level_mean_pooling(visual_embedding, num_tiles=num_tiles, patches_per_tile=64)
                global_pooled = tile_pooled.mean(axis=0).astype(np.float32)

                points_buffer.append(
                    {
                        "id": doc_id,
                        "visual_embedding": visual_embedding,
                        "tile_pooled_embedding": tile_pooled,
                        "global_pooled_embedding": global_pooled,
                        "metadata": payload,
                    }
                )

                if len(points_buffer) >= upload_batch_size:
                    chunk = points_buffer
                    points_buffer = []
                    enqueued_docs += len(chunk)
                    if executor is None:
                        uploaded_docs += int(_upload(chunk) or 0)
                    else:
                        futures.append(executor.submit(_upload, chunk))
                        _drain(block=len(futures) >= upload_workers * 2)
            if executor is not None:
                _drain(block=False)
    except KeyboardInterrupt:
        stop_event.set()
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        raise

    if points_buffer:
        enqueued_docs += len(points_buffer)
        if executor is None:
            uploaded_docs += int(_upload(points_buffer) or 0)
        else:
            futures.append(executor.submit(_upload, points_buffer))

    if executor is not None:
        _drain(block=True)
        executor.shutdown(wait=True)

    if pbar is not None:
        pbar.set_postfix(
            {
                "avg_s/doc": f"{(max(time.time() - start_time, 1e-9) / max(int(pbar.n), 1)):.2f}",
                "last_s/doc": "n/a",
                "buffer": 0,
                "enq": enqueued_docs,
                "upl": uploaded_docs,
                "pending": 0,
            }
        )
        pbar.close()

def _evaluate(
    *,
    queries: List[Any],
    qrels: Dict[str, Dict[str, int]],
    retriever: MultiVectorRetriever,
    top_k: int,
    prefetch_k: int,
    mode: str,
    stage1_mode: str,
) -> Dict[str, float]:
    ndcg10: List[float] = []
    mrr10: List[float] = []
    recall10: List[float] = []
    recall5: List[float] = []
    latencies_ms: List[float] = []

    for q in queries:
        start = time.time()
        results = retriever.search(
            query=q.text,
            top_k=top_k,
            mode=mode,
            prefetch_k=prefetch_k,
            stage1_mode=stage1_mode,
        )
        latencies_ms.append((time.time() - start) * 1000.0)

        ranking = [str(r["id"]) for r in results]
        rels = qrels.get(q.query_id, {})

        ndcg10.append(ndcg_at_k(ranking, rels, k=10))
        mrr10.append(mrr_at_k(ranking, rels, k=10))
        recall5.append(recall_at_k(ranking, rels, k=5))
        recall10.append(recall_at_k(ranking, rels, k=10))

    return {
        "ndcg@10": float(np.mean(ndcg10)),
        "mrr@10": float(np.mean(mrr10)),
        "recall@5": float(np.mean(recall5)),
        "recall@10": float(np.mean(recall10)),
        "avg_latency_ms": float(np.mean(latencies_ms)),
        "p95_latency_ms": float(np.percentile(latencies_ms, 95)),
    }
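
A toy walk-through of the metric calls above, assuming the standard definitions behind the metrics helpers (the values in comments hold under that assumption; the call shapes match how _evaluate uses them):

# --- illustration only, not part of run_qdrant.py ---
ranking = ["d3", "d1", "d7"]  # ids as returned by retriever.search, best first
rels = {"d1": 1}              # qrels row: d1 is the only relevant page

mrr_at_k(ranking, rels, k=10)     # 0.5   (first relevant hit at rank 2)
recall_at_k(ranking, rels, k=10)  # 1.0   (the relevant page is in the top 10)
ndcg_at_k(ranking, rels, k=10)    # ~0.63 (1 / log2(3) for a single hit at rank 2)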

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="vidore/tatdqa_test")
    parser.add_argument("--collection", type=str, default="vidore_tatdqa_test")
    parser.add_argument("--model", type=str, default="vidore/colSmol-500M")
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Torch dtype for model weights (default: auto; CUDA->bfloat16, else float32).",
    )
    parser.add_argument(
        "--qdrant-vector-dtype",
        type=str,
        default="float16",
        choices=["float16", "float32"],
        help="Datatype for vectors stored in Qdrant (default: float16).",
    )
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--upload-batch-size", type=int, default=None)
    parser.add_argument("--upload-workers", type=int, default=0)
    wait_group = parser.add_mutually_exclusive_group()
    wait_group.add_argument(
        "--upsert-wait",
        action="store_true",
        help="Wait for Qdrant upserts to complete before continuing (default: false).",
    )
    wait_group.add_argument(
        "--no-upsert-wait",
        action="store_true",
        help="Deprecated (default is already no-wait). Kept for backwards compatibility.",
    )
    parser.add_argument("--loader-workers", type=int, default=0)
    parser.add_argument("--prefetch-factor", type=int, default=2)
    parser.add_argument("--persistent-workers", action="store_true")
    parser.add_argument("--pin-memory", action="store_true")
    parser.add_argument(
        "--use-dataloader",
        action="store_true",
        help="Use torch DataLoader even with --loader-workers 0 (default: false).",
    )
    grpc_group = parser.add_mutually_exclusive_group()
    grpc_group.add_argument("--prefer-grpc", dest="prefer_grpc", action="store_true", default=True)
    grpc_group.add_argument("--no-prefer-grpc", dest="prefer_grpc", action="store_false")
    parser.add_argument("--index", action="store_true", help="Index corpus into Qdrant before evaluating")
    parser.add_argument("--recreate", action="store_true", help="Delete and recreate the collection (implies --index)")
    parser.add_argument(
        "--indexing-threshold",
        type=int,
        default=0,
        help="Qdrant optimizer indexing threshold (0 = always build indexes).",
    )
    parser.add_argument(
        "--full-scan-threshold",
        type=int,
        default=0,
        help="Qdrant HNSW full_scan_threshold (0 = always use HNSW).",
    )
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--prefetch-k", type=int, default=200)
    parser.add_argument(
        "--mode",
        type=str,
        default="single_full",
        choices=["single_full", "single_tiles", "single_global", "two_stage"],
    )
    parser.add_argument(
        "--stage1-mode",
        type=str,
        default="tokens_vs_tiles",
        choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
    )
    parser.add_argument("--output", type=str, default="results/qdrant_vidore_tatdqa_test.json")

    args = parser.parse_args()

    _maybe_load_dotenv()

    qdrant_url = _ensure_env("QDRANT_URL")
    qdrant_api_key = os.getenv("QDRANT_API_KEY")
    upload_batch_size = args.upload_batch_size or args.batch_size
    upsert_wait = bool(args.upsert_wait)

    if upsert_wait:
        print("Qdrant upserts wait for completion (wait=True).")
    else:
        print("Qdrant upserts are async (wait=False).")

    corpus, queries, qrels, protocol = load_vidore_dataset_auto(args.dataset)

    embedder = VisualEmbedder(
        model_name=args.model,
        batch_size=args.batch_size,
        torch_dtype=_parse_torch_dtype(args.torch_dtype),
    )

    if args.recreate:
        args.index = True

    if args.index:
        if protocol == "paired":
            _index_paired_dataset(
                dataset_name=args.dataset,
                collection_name=args.collection,
                total_docs=len(corpus),
                embedder=embedder,
                qdrant_url=qdrant_url,
                qdrant_api_key=qdrant_api_key,
                prefer_grpc=args.prefer_grpc,
                qdrant_vector_dtype=args.qdrant_vector_dtype,
                recreate=args.recreate,
                batch_size=args.batch_size,
                upload_batch_size=upload_batch_size,
                upload_workers=args.upload_workers,
                upsert_wait=upsert_wait,
                loader_workers=args.loader_workers,
                prefetch_factor=args.prefetch_factor,
                persistent_workers=args.persistent_workers,
                pin_memory=args.pin_memory,
                use_dataloader=args.use_dataloader,
                indexing_threshold=args.indexing_threshold,
                full_scan_threshold=args.full_scan_threshold,
            )
        else:
            _index_corpus(
                dataset_name=args.dataset,
                collection_name=args.collection,
                corpus=corpus,
                embedder=embedder,
                qdrant_url=qdrant_url,
                qdrant_api_key=qdrant_api_key,
                prefer_grpc=args.prefer_grpc,
                qdrant_vector_dtype=args.qdrant_vector_dtype,
                recreate=args.recreate,
                batch_size=args.batch_size,
                upload_batch_size=upload_batch_size,
                upload_workers=args.upload_workers,
                upsert_wait=upsert_wait,
                indexing_threshold=args.indexing_threshold,
                full_scan_threshold=args.full_scan_threshold,
            )

    retriever = MultiVectorRetriever(
        collection_name=args.collection,
        embedder=embedder,
        qdrant_url=qdrant_url,
        qdrant_api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
    )

    metrics = _evaluate(
        queries=queries,
        qrels=qrels,
        retriever=retriever,
        top_k=args.top_k,
        prefetch_k=args.prefetch_k,
        mode=args.mode,
        stage1_mode=args.stage1_mode,
    )

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(
            {
                "dataset": args.dataset,
                "protocol": protocol,
                "collection": args.collection,
                "model": args.model,
                "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                "qdrant_vector_dtype": args.qdrant_vector_dtype,
                "mode": args.mode,
                "stage1_mode": args.stage1_mode if args.mode == "two_stage" else None,
                "prefetch_k": args.prefetch_k if args.mode == "two_stage" else None,
                "metrics": metrics,
            },
            f,
            indent=2,
        )

    print(json.dumps(metrics, indent=2))


if __name__ == "__main__":
    main()
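
Putting the pieces together, one plausible invocation of this benchmark from a checkout root would look like the line below. This is illustrative only, assembled from the argparse flags defined in main(); QDRANT_URL (and optionally QDRANT_API_KEY) must be set in the environment or a .env file, and benchmarks/vidore_tatdqa_test/COMMANDS.md in this same package is its own command reference:

python -m benchmarks.vidore_tatdqa_test.run_qdrant --index --recreate --mode two_stage --stage1-mode tokens_vs_tiles --prefetch-k 200 --output results/qdrant_vidore_tatdqa_test.json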