visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
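
The hunk below is the new `benchmarks/vidore_beir_qdrant/run_qdrant_beir.py`. For orientation, here is a minimal sketch of how the toolkit pieces this script relies on fit together. The constructor and search arguments mirror the calls visible in the diff; the collection name, environment variables' values, and query text are illustrative assumptions, not part of the package.

```python
# Minimal sketch, assuming a reachable Qdrant instance and this wheel installed.
# Argument names follow the VisualEmbedder / MultiVectorRetriever calls in the diff below.
import os

import numpy as np

from visual_rag import VisualEmbedder
from visual_rag.retrieval import MultiVectorRetriever

embedder = VisualEmbedder(
    model_name="vidore/colSmol-500M",  # default --model in the benchmark script
    batch_size=4,
    torch_dtype=None,                  # the CLI's "auto" maps to None
    output_dtype=np.float16,           # matches --qdrant-vector-dtype float16
    processor_speed="fast",
)

retriever = MultiVectorRetriever(
    collection_name="vidore_beir_demo",  # illustrative name
    embedder=embedder,
    qdrant_url=os.getenv("QDRANT_URL"),
    qdrant_api_key=os.getenv("QDRANT_API_KEY"),
    prefer_grpc=True,
    request_timeout=120,
    max_retries=3,
    retry_sleep=0.5,
)

# Embed one query and search, as the _evaluate() loop in the script does.
query_embeddings = embedder.embed_queries(["total revenue in 2019"], show_progress=False)
qemb = query_embeddings[0]
qemb_np = qemb.detach().cpu().numpy() if hasattr(qemb, "detach") else np.asarray(qemb)
results = retriever.search_embedded(
    query_embedding=qemb_np,
    top_k=10,
    mode="single_full",
    prefetch_k=200,
    stage1_mode="tokens_vs_tiles",
    stage1_k=1000,
    stage2_k=300,
    filter_obj=None,
)
print([r["id"] for r in results])
```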
benchmarks/vidore_beir_qdrant/run_qdrant_beir.py
@@ -0,0 +1,1365 @@
import argparse
import json
import os
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
from visual_rag import VisualEmbedder
from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
from visual_rag.retrieval import MultiVectorRetriever


def _maybe_load_dotenv() -> None:
    try:
        from dotenv import load_dotenv
    except ImportError:
        return
    if Path(".env").exists():
        load_dotenv(".env")


def _torch_dtype_to_str(dtype) -> str:
    if dtype is None:
        return "auto"
    s = str(dtype)
    return s.replace("torch.", "")


def _parse_torch_dtype(dtype_str: str):
    if dtype_str == "auto":
        return None
    import torch

    mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return mapping[dtype_str]


def _stable_uuid(text: str) -> str:
    import hashlib

    hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
    return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"


def _sample_list(items: List[Any], *, k: int, strategy: str, seed: int) -> List[Any]:
    if not k or k <= 0:
        return items
    if k >= len(items):
        return items
    if strategy == "first":
        return items[:k]
    if strategy == "random":
        import random

        rng = random.Random(int(seed))
        indices = rng.sample(range(len(items)), k)
        return [items[i] for i in indices]
    raise ValueError("sample strategy must be 'first' or 'random'")


def _parse_payload_indexes(values: List[str]) -> List[Dict[str, str]]:
    indexes: List[Dict[str, str]] = []
    for raw in values or []:
        if ":" not in raw:
            raise ValueError("payload index must be in field:type format")
        field, type_str = raw.split(":", 1)
        field = field.strip()
        type_str = type_str.strip()
        if not field or not type_str:
            raise ValueError("payload index must be in field:type format")
        indexes.append({"field": field, "type": type_str})
    return indexes


def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
    ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
    return _stable_uuid(f"{ns}::{source_doc_id}")


def _filter_qrels(qrels: Dict[str, Dict[str, int]], query_ids: List[str]) -> Dict[str, Dict[str, int]]:
    keep = set(query_ids)
    return {qid: rels for qid, rels in qrels.items() if qid in keep}


def _failed_log_path(*, collection_name: str, dataset_name: str) -> Path:
    dir_name = _safe_filename(collection_name)
    return Path("results") / dir_name / f"index_failures__{_safe_filename(dataset_name)}.jsonl"


def _resolve_output_path(raw_output: str, *, collection_name: str) -> Path:
    """
    Default behavior:
    - If --output is a bare filename, write it to results/{collection_name}/{filename}
    - If --output points into the legacy results/reports/, rewrite into results/{collection_name}/
    - If --output includes any other directory (relative or absolute), respect it
    """
    p = Path(str(raw_output))
    dir_name = _safe_filename(collection_name)

    if str(p).startswith("results/reports/"):
        return Path("results") / dir_name / p.name
    if p.is_absolute():
        return p
    if p.parent == Path("."):
        return Path("results") / dir_name / p.name
    return p


def _default_output_filename(*, args, datasets: List[str]) -> str:
    model_tag = _safe_filename(str(args.model).split("/")[-1])
    scope_tag = _safe_filename(str(args.evaluation_scope))
    mode_tag = _safe_filename(str(args.mode))
    topk_tag = f"top{int(args.top_k)}"
    ds_tag = f"{len(datasets)}ds"

    parts = [model_tag, mode_tag]
    if str(args.mode) == "two_stage":
        parts.append(_safe_filename(str(args.stage1_mode)))
        parts.append(f"pk{int(args.prefetch_k)}")
    if str(args.mode) == "three_stage":
        parts.append("tokens_vs_global")
        parts.append(f"s1k{int(args.stage1_k)}")
        parts.append("tokens_vs_experimental")
        parts.append(f"s2k{int(args.stage2_k)}")
    parts.extend([topk_tag, scope_tag, ds_tag])

    if bool(args.crop_empty):
        pct = int(round(float(args.crop_empty_percentage_to_remove) * 100))
        parts.append(f"crop{pct}")

    name = "__".join([p for p in parts if p])
    return f"{name}.json"


def _append_jsonl(path: Path, obj: Dict[str, Any]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")


def _load_failed_ids(path: Path) -> set:
    if not path.exists():
        return set()
    ids = set()
    with path.open("r") as f:
        for line in f:
            line = (line or "").strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except Exception:
                continue
            for key in ("union_doc_id", "doc_id"):
                v = obj.get(key)
                if v:
                    ids.add(str(v))
    return ids


def _load_failed_union_ids(
    path: Path,
    *,
    dataset_name: str,
    union_namespace: Optional[str],
) -> set:
    """
    Load a set of union_doc_id values usable against Qdrant point IDs.

    Older logs may contain union_doc_id computed without union_namespace.
    We always recompute the current union_doc_id from source_doc_id to make retries consistent.
    """
    if not path.exists():
        return set()
    out = set()
    with path.open("r") as f:
        for line in f:
            s = (line or "").strip()
            if not s:
                continue
            try:
                obj = json.loads(s)
            except Exception:
                continue
            src = obj.get("source_doc_id")
            if src:
                out.add(
                    _union_point_id(
                        dataset_name=str(dataset_name),
                        source_doc_id=str(src),
                        union_namespace=union_namespace,
                    )
                )
            u = obj.get("union_doc_id")
            if u:
                out.add(str(u))
    return out


def _remove_failed_from_qrels(qrels: Dict[str, Dict[str, int]], failed_ids: set) -> Tuple[Dict[str, Dict[str, int]], int]:
    removed = 0
    if not failed_ids:
        return qrels, 0
    out: Dict[str, Dict[str, int]] = {}
    for qid, rels in (qrels or {}).items():
        new_rels: Dict[str, int] = {}
        for did, score in (rels or {}).items():
            if str(did) in failed_ids:
                removed += 1
                continue
            new_rels[str(did)] = int(score)
        out[str(qid)] = new_rels
    return out, removed


def _evaluate(
    *,
    queries,
    qrels: Dict[str, Dict[str, int]],
    retriever: MultiVectorRetriever,
    embedder: VisualEmbedder,
    top_k: int,
    prefetch_k: int,
    mode: str,
    stage1_mode: str,
    stage1_k: int,
    stage2_k: int,
    max_queries: int,
    drop_empty_queries: bool,
    filter_obj=None,
) -> Dict[str, float]:
    eval_started_at = time.time()
    if drop_empty_queries:
        queries = [q for q in queries if any(v > 0 for v in qrels.get(q.query_id, {}).values())]
    if max_queries and max_queries > 0:
        queries = queries[:max_queries]
    if not queries:
        return {
            "ndcg@1": 0.0,
            "ndcg@5": 0.0,
            "ndcg@10": 0.0,
            "ndcg@100": 0.0,
            "mrr@1": 0.0,
            "mrr@5": 0.0,
            "mrr@10": 0.0,
            "mrr@100": 0.0,
            "recall@1": 0.0,
            "recall@5": 0.0,
            "recall@10": 0.0,
            "recall@100": 0.0,
            "avg_latency_ms": 0.0,
            "p95_latency_ms": 0.0,
            "eval_wall_time_s": 0.0,
            "eval_search_time_s": 0.0,
            "qps": 0.0,
            "num_queries_eval": 0,
        }

    ndcg1: List[float] = []
    ndcg5: List[float] = []
    ndcg10: List[float] = []
    ndcg100: List[float] = []
    mrr1: List[float] = []
    mrr5: List[float] = []
    mrr10: List[float] = []
    mrr100: List[float] = []
    recall1: List[float] = []
    recall5: List[float] = []
    recall10: List[float] = []
    recall100: List[float] = []
    latencies_ms: List[float] = []

    retrieve_k = max(100, top_k)

    query_texts = [q.text for q in queries]
    query_embeddings = embedder.embed_queries(
        query_texts,
        batch_size=getattr(embedder, "batch_size", None),
        show_progress=False,
    )

    iterator = queries
    try:
        from tqdm import tqdm

        iterator = tqdm(queries, desc="Searching", unit="q")
    except ImportError:
        pass

    for q, qemb in zip(iterator, query_embeddings):
        start = time.time()
        try:
            import torch
        except ImportError:
            torch = None
        if torch is not None and isinstance(qemb, torch.Tensor):
            qemb_np = qemb.detach().cpu().numpy()
        else:
            qemb_np = qemb.numpy()

        results = retriever.search_embedded(
            query_embedding=qemb_np,
            top_k=retrieve_k,
            mode=mode,
            prefetch_k=prefetch_k,
            stage1_mode=stage1_mode,
            stage1_k=int(stage1_k),
            stage2_k=int(stage2_k),
            filter_obj=filter_obj,
        )
        latencies_ms.append((time.time() - start) * 1000.0)

        ranking = [str(r["id"]) for r in results]
        rels = qrels.get(q.query_id, {})

        ndcg1.append(ndcg_at_k(ranking, rels, k=1))
        ndcg5.append(ndcg_at_k(ranking, rels, k=5))
        ndcg10.append(ndcg_at_k(ranking, rels, k=10))
        ndcg100.append(ndcg_at_k(ranking, rels, k=100))
        mrr1.append(mrr_at_k(ranking, rels, k=1))
        mrr5.append(mrr_at_k(ranking, rels, k=5))
        mrr10.append(mrr_at_k(ranking, rels, k=10))
        mrr100.append(mrr_at_k(ranking, rels, k=100))
        recall1.append(recall_at_k(ranking, rels, k=1))
        recall5.append(recall_at_k(ranking, rels, k=5))
        recall10.append(recall_at_k(ranking, rels, k=10))
        recall100.append(recall_at_k(ranking, rels, k=100))

    eval_wall_time_s = float(max(time.time() - eval_started_at, 0.0))
    eval_search_time_s = float(np.sum(latencies_ms) / 1000.0) if latencies_ms else 0.0
    qps = float(len(queries) / eval_wall_time_s) if eval_wall_time_s > 0 else 0.0
    return {
        "ndcg@1": float(np.mean(ndcg1)),
        "ndcg@5": float(np.mean(ndcg5)),
        "ndcg@10": float(np.mean(ndcg10)),
        "ndcg@100": float(np.mean(ndcg100)),
        "mrr@1": float(np.mean(mrr1)),
        "mrr@5": float(np.mean(mrr5)),
        "mrr@10": float(np.mean(mrr10)),
        "mrr@100": float(np.mean(mrr100)),
        "recall@1": float(np.mean(recall1)),
        "recall@5": float(np.mean(recall5)),
        "recall@10": float(np.mean(recall10)),
        "recall@100": float(np.mean(recall100)),
        "avg_latency_ms": float(np.mean(latencies_ms)),
        "p95_latency_ms": float(np.percentile(latencies_ms, 95)),
        "eval_wall_time_s": eval_wall_time_s,
        "eval_search_time_s": eval_search_time_s,
        "qps": qps,
        "num_queries_eval": int(len(queries)),
    }


def _write_json_atomic(path: Path, data: Dict[str, Any]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(prefix=path.name + ".", dir=str(path.parent))
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, path)
    finally:
        try:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
        except Exception:
            pass


def _safe_filename(text: str) -> str:
    out = []
    for ch in str(text):
        if ch.isalnum() or ch in ("-", "_", "."):
            out.append(ch)
        else:
            out.append("_")
    return "".join(out).strip("_")


def _index_beir_corpus(
    *,
    dataset_name: str,
    corpus,
    embedder: VisualEmbedder,
    collection_name: str,
    prefer_grpc: bool,
    qdrant_vector_dtype: str,
    recreate: bool,
    indexing_threshold: int,
    batch_size: int,
    upload_batch_size: int,
    upload_workers: int,
    upsert_wait: bool,
    max_corpus_docs: int,
    sample_corpus_docs: int,
    sample_corpus_strategy: str,
    sample_seed: int,
    payload_indexes: List[Dict[str, str]],
    union_namespace: Optional[str],
    model_name: str,
    resume: bool,
    qdrant_timeout: int,
    full_scan_threshold: int,
    crop_empty: bool,
    crop_empty_percentage_to_remove: float,
    crop_empty_remove_page_number: bool,
    crop_empty_preserve_border_px: int,
    crop_empty_uniform_std_threshold: float,
    no_cloudinary: bool,
    cloudinary_folder: str,
    retry_failures: bool,
    only_failures: bool,
) -> None:
    qdrant_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    if not qdrant_url:
        raise ValueError("QDRANT_URL not set")
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )

    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=prefer_grpc,
        vector_datatype=qdrant_vector_dtype,
        timeout=int(qdrant_timeout),
    )
    indexer.create_collection(
        force_recreate=recreate,
        indexing_threshold=indexing_threshold,
        full_scan_threshold=int(full_scan_threshold),
    )
    indexer.create_payload_indexes(fields=payload_indexes)

    cloudinary_uploader: Optional[CloudinaryUploader] = None
    if not bool(no_cloudinary):
        try:
            cloudinary_uploader = CloudinaryUploader(folder=str(cloudinary_folder))
        except Exception:
            cloudinary_uploader = None

    failure_log = _failed_log_path(collection_name=collection_name, dataset_name=dataset_name)
    failed_ids = _load_failed_union_ids(failure_log, dataset_name=dataset_name, union_namespace=union_namespace)
    previously_failed_ids = set(failed_ids)

    existing_ids = set()
    if resume:
        offset = None
        while True:
            points, next_offset = indexer.client.scroll(
                collection_name=collection_name,
                limit=1000,
                offset=offset,
                with_payload=False,
                with_vectors=False,
            )
            for p in points:
                existing_ids.add(str(p.id))
            if not next_offset or not points:
                break
            offset = next_offset
        if not bool(retry_failures) and not bool(only_failures):
            existing_ids |= failed_ids

    target_ids = None
    if bool(only_failures):
        target_ids = set(str(x) for x in failed_ids)
        if not target_ids:
            print(f"No failed ids found for dataset={dataset_name}; nothing to retry.")
            return

    if sample_corpus_docs and sample_corpus_docs > 0:
        corpus = _sample_list(
            list(corpus),
            k=int(sample_corpus_docs),
            strategy=str(sample_corpus_strategy),
            seed=int(sample_seed),
        )
    elif max_corpus_docs and max_corpus_docs > 0:
        corpus = corpus[:max_corpus_docs]
    total_docs = len(corpus)
    points_buffer: List[Dict[str, Any]] = []

    def _safe_public_id(s: str) -> str:
        out = _safe_filename(str(s))
        return out[:180] if len(out) > 180 else out

    def _ensure_pil(img):
        try:
            from PIL import Image
        except Exception:
            return img
        if img is None:
            return None
        if isinstance(img, Image.Image):
            return img
        try:
            return img.convert("RGB")
        except Exception:
            return img

    def _resized_for_display(img):
        from PIL import Image

        if img is None or not isinstance(img, Image.Image):
            return None
        out = img.copy()
        out.thumbnail((1024, 1024), Image.BICUBIC)
        return out

    uploaded_docs = 0
    skipped_docs = 0
    start_time = time.time()
    last_tick_time = start_time
    last_tick_docs = 0
    last_s_per_doc = 0.0

    pbar = None
    try:
        from tqdm import tqdm

        pbar = tqdm(total=total_docs, desc="📦 Indexing corpus", unit="doc")
    except ImportError:
        pass

    import threading
    from concurrent.futures import ThreadPoolExecutor, wait as futures_wait, FIRST_EXCEPTION

    stop_event = threading.Event()
    executor = ThreadPoolExecutor(max_workers=int(upload_workers)) if upload_workers and upload_workers > 0 else None
    futures = []

    def _upload(points: List[Dict[str, Any]]) -> int:
        uploaded = int(indexer.upload_batch(points, delay_between_batches=0.0, wait=upsert_wait, stop_event=stop_event) or 0)
        if uploaded <= 0 and points:
            for p in points:
                pid = str(p.get("id") or "")
                if pid and pid not in failed_ids:
                    _append_jsonl(
                        failure_log,
                        {
                            "dataset": dataset_name,
                            "collection": collection_name,
                            "model": model_name,
                            "source_doc_id": str((p.get("metadata") or {}).get("source_doc_id") or ""),
                            "doc_id": str((p.get("metadata") or {}).get("doc_id") or ""),
                            "union_doc_id": pid,
                            "error": "Qdrant upsert failed (all retries exhausted)",
                        },
                    )
                    failed_ids.add(pid)
        return uploaded

    def _drain(block: bool) -> None:
        nonlocal uploaded_docs
        if not futures:
            return
        done, _ = futures_wait(futures, return_when=FIRST_EXCEPTION, timeout=None if block else 0)
        for d in list(done):
            futures.remove(d)
            uploaded_docs += int(d.result() or 0)

    try:
        for start in range(0, total_docs, batch_size):
            batch = corpus[start : start + batch_size]
            batch_total = len(batch)
            if target_ids is not None:
                filtered = []
                for d in batch:
                    source_doc_id = str((d.payload or {}).get("source_doc_id") or d.doc_id)
                    union_doc_id = _union_point_id(
                        dataset_name=dataset_name,
                        source_doc_id=source_doc_id,
                        union_namespace=union_namespace,
                    )
                    if union_doc_id in existing_ids:
                        skipped_docs += 1
                        continue
                    if union_doc_id in target_ids:
                        filtered.append(d)
                    else:
                        skipped_docs += 1
                batch = filtered
            elif existing_ids:
                filtered = []
                for d in batch:
                    source_doc_id = str((d.payload or {}).get("source_doc_id") or d.doc_id)
                    union_doc_id = _union_point_id(
                        dataset_name=dataset_name,
                        source_doc_id=source_doc_id,
                        union_namespace=union_namespace,
                    )
                    if union_doc_id in existing_ids:
                        skipped_docs += 1
                    else:
                        filtered.append(d)
                batch = filtered

            if not batch:
                if pbar is not None:
                    pbar.update(batch_total)
                    now = time.time()
                    done_docs = int(pbar.n)
                    elapsed = max(now - start_time, 1e-9)
                    avg_s_per_doc = elapsed / max(done_docs, 1)
                    delta_docs = done_docs - last_tick_docs
                    delta_t = max(now - last_tick_time, 1e-9)
                    if delta_docs > 0:
                        last_s_per_doc = delta_t / delta_docs
                        last_tick_time = now
                        last_tick_docs = done_docs
                    pbar.set_postfix(
                        {
                            "avg_s/doc": f"{avg_s_per_doc:.2f}",
                            "last_s/doc": f"{last_s_per_doc:.2f}",
                            "upl": uploaded_docs,
                            "skip": skipped_docs,
                        }
                    )
                continue

            if crop_empty:
                from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty as _crop_empty

                crop_cfg = CropEmptyConfig(
                    percentage_to_remove=float(crop_empty_percentage_to_remove),
                    remove_page_number=bool(crop_empty_remove_page_number),
                    preserve_border_px=int(crop_empty_preserve_border_px),
                    uniform_rowcol_std_threshold=float(crop_empty_uniform_std_threshold),
                )
                crop_metas = []
                images = []
                original_images = []
                for d in batch:
                    original_img = _ensure_pil(d.image)
                    original_images.append(original_img)
                    cropped, meta = _crop_empty(original_img, config=crop_cfg)
                    images.append(cropped)
                    crop_metas.append(meta)
            else:
                original_images = [_ensure_pil(d.image) for d in batch]
                images = original_images
                crop_metas = [None for _ in batch]
            try:
                embeddings, token_infos = embedder.embed_images(
                    images,
                    batch_size=batch_size,
                    return_token_info=True,
                    show_progress=False,
                )
            except Exception as e:
                # Retry per-doc to isolate flaky backend / corrupted sample issues.
                embeddings = []
                token_infos = []
                for doc_i, img_i, crop_meta_i in zip(batch, images, crop_metas):
                    try:
                        e1, t1 = embedder.embed_images(
                            [img_i],
                            batch_size=1,
                            return_token_info=True,
                            show_progress=False,
                        )
                        embeddings.append(e1[0])
                        token_infos.append(t1[0])
                    except Exception as e_single:
                        source_doc_id_i = str((doc_i.payload or {}).get("source_doc_id") or doc_i.doc_id)
                        union_doc_id_i = _union_point_id(
                            dataset_name=dataset_name,
                            source_doc_id=source_doc_id_i,
                            union_namespace=union_namespace,
                        )
                        if str(union_doc_id_i) not in failed_ids:
                            _append_jsonl(
                                failure_log,
                                {
                                    "dataset": dataset_name,
                                    "collection": collection_name,
                                    "model": model_name,
                                    "source_doc_id": str(source_doc_id_i),
                                    "doc_id": str(getattr(doc_i, "doc_id", "")),
                                    "union_doc_id": str(union_doc_id_i),
                                    "error": str(e_single),
                                },
                            )
                            failed_ids.add(str(union_doc_id_i))
                        existing_ids.add(str(union_doc_id_i))
                        skipped_docs += 1
                if pbar is not None:
                    pbar.update(batch_total)
                continue

            for doc, emb, token_info, crop_meta, original_img, embed_img in zip(
                batch, embeddings, token_infos, crop_metas, original_images, images
            ):
                try:
                    emb_np = emb.cpu().float().numpy() if hasattr(emb, "cpu") else np.array(emb, dtype=np.float32)
                    visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
                    visual_embedding = emb_np[visual_indices].astype(np.float32)
                    tile_pooled = embedder.mean_pool_visual_embedding(visual_embedding, token_info, target_vectors=32)
                    experimental_pooled = embedder.experimental_pool_visual_embedding(
                        visual_embedding, token_info, target_vectors=32, mean_pool=tile_pooled
                    )
                    global_pooled = embedder.global_pool_from_mean_pool(tile_pooled)
                except Exception as e_single:
                    source_doc_id_i = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
                    union_doc_id_i = _union_point_id(
                        dataset_name=dataset_name,
                        source_doc_id=source_doc_id_i,
                        union_namespace=union_namespace,
                    )
                    if str(union_doc_id_i) not in failed_ids:
                        _append_jsonl(
                            failure_log,
                            {
                                "dataset": dataset_name,
                                "collection": collection_name,
                                "model": model_name,
                                "source_doc_id": str(source_doc_id_i),
                                "doc_id": str(getattr(doc, "doc_id", "")),
                                "union_doc_id": str(union_doc_id_i),
                                "error": str(e_single),
                            },
                        )
                        failed_ids.add(str(union_doc_id_i))
                    existing_ids.add(str(union_doc_id_i))
                    skipped_docs += 1
                    continue

                num_tiles = int(tile_pooled.shape[0])
                patches_per_tile = int(visual_embedding.shape[0] // max(num_tiles, 1)) if num_tiles else 0

                source_doc_id = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
                union_doc_id = _union_point_id(
                    dataset_name=dataset_name,
                    source_doc_id=source_doc_id,
                    union_namespace=union_namespace,
                )

                resized_img = _resized_for_display(embed_img) or embed_img
                original_url = ""
                cropped_url = ""
                resized_url = ""
                if cloudinary_uploader is not None and original_img is not None and resized_img is not None:
                    base_public_id = _safe_public_id(f"{dataset_name}__{union_doc_id}")
                    try:
                        if crop_empty:
                            o_url, c_url, r_url = cloudinary_uploader.upload_original_cropped_and_resized(
                                original_img,
                                embed_img if embed_img is not None and embed_img is not original_img else None,
                                resized_img,
                                base_public_id,
                            )
                            original_url = o_url or ""
                            cropped_url = c_url or ""
                            resized_url = r_url or ""
                        else:
                            o_url, r_url = cloudinary_uploader.upload_original_and_resized(
                                original_img,
                                resized_img,
                                base_public_id,
                            )
                            original_url = o_url or ""
                            resized_url = r_url or ""
                    except Exception:
                        pass

                payload = {
                    "dataset": dataset_name,
                    "doc_id": doc.doc_id,
                    "union_doc_id": union_doc_id,
                    "page": resized_url or original_url or "",
                    "original_url": original_url,
                    "cropped_url": cropped_url,
                    "resized_url": resized_url,
                    "original_width": int(original_img.width) if original_img is not None else None,
                    "original_height": int(original_img.height) if original_img is not None else None,
                    "cropped_width": int(embed_img.width) if embed_img is not None else None,
                    "cropped_height": int(embed_img.height) if embed_img is not None else None,
                    "resized_width": int(resized_img.width) if resized_img is not None else None,
                    "resized_height": int(resized_img.height) if resized_img is not None else None,
                    "num_tiles": int(num_tiles),
                    "patches_per_tile": int(patches_per_tile),
                    "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                    "model_name": model_name,
                    "crop_empty_enabled": bool(crop_empty),
                    "crop_empty_crop_box": (crop_meta or {}).get("crop_box") if crop_empty else None,
                    "crop_empty_remove_page_number": bool(crop_empty_remove_page_number) if crop_empty else None,
                    "crop_empty_percentage_to_remove": float(crop_empty_percentage_to_remove) if crop_empty else None,
                    "index_recovery_previously_failed": bool(union_doc_id in previously_failed_ids),
                    "index_recovery_mode": (
                        "only_failures" if bool(only_failures) else ("retry_failures" if bool(retry_failures) else None)
                    ),
                    "index_recovery_pooling_inferred_tiles": bool(
                        (token_info or {}).get("num_tiles") is None and (token_info or {}).get("n_rows") is None and (token_info or {}).get("n_cols") is None
                    ),
                    "index_recovery_num_visual_tokens": int(visual_embedding.shape[0]),
                    **(doc.payload or {}),
                }

                points_buffer.append(
                    {
                        "id": union_doc_id,
                        "visual_embedding": visual_embedding,
                        "tile_pooled_embedding": tile_pooled,
                        "experimental_pooled_embedding": experimental_pooled,
                        "global_pooled_embedding": global_pooled,
                        "metadata": payload,
                    }
                )

                if len(points_buffer) >= upload_batch_size:
                    chunk = points_buffer
                    points_buffer = []
                    if executor is None:
                        uploaded_docs += int(_upload(chunk) or 0)
                    else:
                        futures.append(executor.submit(_upload, chunk))
                        _drain(block=len(futures) >= int(upload_workers) * 2)

            if pbar is not None:
                pbar.update(len(batch))
                now = time.time()
                done_docs = int(pbar.n)
                elapsed = max(now - start_time, 1e-9)
                avg_s_per_doc = elapsed / max(done_docs, 1)
                delta_docs = done_docs - last_tick_docs
                delta_t = max(now - last_tick_time, 1e-9)
                if delta_docs > 0:
                    last_s_per_doc = delta_t / delta_docs
                    last_tick_time = now
                    last_tick_docs = done_docs
                pbar.set_postfix(
                    {
                        "avg_s/doc": f"{avg_s_per_doc:.2f}",
                        "last_s/doc": f"{last_s_per_doc:.2f}",
                        "upl": uploaded_docs,
                        "skip": skipped_docs,
                    }
                )

            if executor is not None:
                _drain(block=False)
    except KeyboardInterrupt:
        stop_event.set()
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        raise

    if points_buffer:
        if executor is None:
            uploaded_docs += int(_upload(points_buffer) or 0)
        else:
            futures.append(executor.submit(_upload, points_buffer))

    if executor is not None:
        _drain(block=True)
        executor.shutdown(wait=True)

    if pbar is not None:
        pbar.close()


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--datasets", type=str, nargs="+", default=None)
    parser.add_argument("--collection", type=str, required=True)
    parser.add_argument("--model", type=str, default="vidore/colSmol-500M")
    parser.add_argument(
        "--processor-speed",
        type=str,
        default="fast",
        choices=["fast", "slow", "auto"],
        help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
    )
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--qdrant-vector-dtype",
        type=str,
        default="float16",
        choices=["float16", "float32"],
    )
    grpc_group = parser.add_mutually_exclusive_group()
    grpc_group.add_argument("--prefer-grpc", dest="prefer_grpc", action="store_true", default=True)
    grpc_group.add_argument("--no-prefer-grpc", dest="prefer_grpc", action="store_false")
    parser.add_argument("--index", action="store_true")
    parser.add_argument("--recreate", action="store_true")
    parser.add_argument("--indexing-threshold", type=int, default=0)
    parser.add_argument("--full-scan-threshold", type=int, default=0)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--upload-batch-size", type=int, default=8)
    parser.add_argument("--upload-workers", type=int, default=0)
    parser.add_argument("--upsert-wait", action="store_true")
    parser.add_argument("--max-corpus-docs", type=int, default=0)
    parser.add_argument("--sample-corpus-docs", type=int, default=0)
    parser.add_argument("--sample-corpus-strategy", type=str, default="first", choices=["first", "random"])
    parser.add_argument("--sample-seed", type=int, default=0)
    parser.add_argument("--sample-queries", type=int, default=0)
    parser.add_argument("--sample-query-strategy", type=str, default="first", choices=["first", "random"])
    parser.add_argument("--sample-query-seed", type=int, default=0)
    parser.add_argument("--index-from-queries", action="store_true", default=False)
    parser.add_argument("--resume", action="store_true", default=False)
    parser.add_argument("--qdrant-timeout", type=int, default=120)
    parser.add_argument("--qdrant-retries", type=int, default=3)
    parser.add_argument("--qdrant-retry-sleep", type=float, default=0.5)
    parser.add_argument("--crop-empty", action="store_true", default=False)
    parser.add_argument("--crop-empty-percentage-to-remove", type=float, default=0.9)
    parser.add_argument("--crop-empty-remove-page-number", action="store_true", default=False)
    parser.add_argument("--crop-empty-preserve-border-px", type=int, default=1)
    parser.add_argument("--crop-empty-uniform-std-threshold", type=float, default=0.0)
    payload_group = parser.add_mutually_exclusive_group()
    payload_group.add_argument("--index-common-metadata", action="store_true", default=True)
    payload_group.add_argument("--no-index-common-metadata", dest="index_common_metadata", action="store_false")
    parser.add_argument("--payload-index", action="append", default=[])
    parser.add_argument(
        "--no-cloudinary",
        action="store_true",
        help="Disable Cloudinary uploads during indexing (default: enabled).",
    )
    parser.add_argument(
        "--cloudinary-folder",
        type=str,
        default="vidore-beir",
        help="Cloudinary base folder for uploads (default: vidore-beir).",
    )
    parser.add_argument(
        "--retry-failures",
        action="store_true",
        default=False,
        help="On --resume, retry documents listed in index_failures__<collection>__<dataset>.jsonl (default: skip them).",
    )
    parser.add_argument(
        "--only-failures",
        action="store_true",
        default=False,
        help="Index only documents listed in index_failures__<collection>__<dataset>.jsonl.",
    )

    parser.add_argument("--top-k", type=int, default=100, help="Retrieve top-k results (default: 100 to calculate metrics at all cutoffs)")
    parser.add_argument("--prefetch-k", type=int, default=200, help="Prefetch candidates for two-stage (default: 200)")
    parser.add_argument(
        "--no-eval",
        action="store_true",
        default=False,
        help="If set, run indexing only and skip evaluation.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="single_full",
        choices=["single_full", "single_tiles", "single_global", "two_stage", "three_stage"],
    )
    parser.add_argument(
        "--stage1-mode",
        type=str,
        default="tokens_vs_tiles",
        choices=[
            "pooled_query_vs_tiles",
            "tokens_vs_tiles",
            "pooled_query_vs_experimental",
            "tokens_vs_experimental",
            "pooled_query_vs_global",
        ],
    )
    parser.add_argument("--stage1-k", type=int, default=1000, help="Three-stage stage1 top_k (default: 1000)")
    parser.add_argument("--stage2-k", type=int, default=300, help="Three-stage stage2 top_k (default: 300)")
    parser.add_argument("--max-queries", type=int, default=0)
    drop_group = parser.add_mutually_exclusive_group()
    drop_group.add_argument("--drop-empty-queries", dest="drop_empty_queries", action="store_true", default=True)
    drop_group.add_argument("--no-drop-empty-queries", dest="drop_empty_queries", action="store_false")
    parser.add_argument(
        "--evaluation-scope",
        type=str,
        default="union",
        choices=["union", "per_dataset"],
        help="Evaluation scope: 'union' searches over the whole collection (cross-dataset distractors). "
        "'per_dataset' applies a Qdrant filter so each dataset's queries search only its own subset (leaderboard-comparable).",
    )
    cont_group = parser.add_mutually_exclusive_group()
    cont_group.add_argument(
        "--continue-on-error",
        dest="continue_on_error",
        action="store_true",
        default=True,
        help="Continue evaluating remaining datasets if one dataset fails (default: true).",
    )
    cont_group.add_argument(
        "--no-continue-on-error",
        dest="continue_on_error",
        action="store_false",
        help="Stop the run immediately on the first dataset evaluation failure.",
    )
    parser.add_argument("--output", type=str, default="auto")

    args = parser.parse_args()

    _maybe_load_dotenv()

    if args.recreate:
        args.index = True

    if args.sample_corpus_docs and int(args.sample_corpus_docs) > 0 and args.max_corpus_docs and int(args.max_corpus_docs) > 0:
        raise ValueError("Use only one of --sample-corpus-docs or --max-corpus-docs (not both).")
    if args.sample_queries and int(args.sample_queries) > 0 and args.index_from_queries:
        if (args.sample_corpus_docs and int(args.sample_corpus_docs) > 0) or (args.max_corpus_docs and int(args.max_corpus_docs) > 0):
            raise ValueError("Use --index-from-queries with --sample-queries only (do not combine with corpus sampling).")

    if args.upsert_wait:
        print("Qdrant upserts wait for completion (wait=True).")
    else:
        print("Qdrant upserts are async (wait=False).")
    print(f"Qdrant request timeout: {int(args.qdrant_timeout)}s, retries: {int(args.qdrant_retries)}.")

    datasets: List[str] = []
    if args.datasets:
        datasets = list(args.datasets)
    elif args.dataset:
        datasets = [args.dataset]
    else:
        raise ValueError("Provide --dataset (single) or --datasets (one or more)")

    if str(args.output).strip().lower() == "auto":
        args.output = _default_output_filename(args=args, datasets=datasets)

    loaded: List[Tuple[str, Any, Any, Dict[str, Dict[str, int]]]] = []
    for ds_name in datasets:
        corpus, queries, qrels = load_vidore_beir_dataset(ds_name)
        loaded.append((ds_name, corpus, queries, qrels))

    output_dtype = np.float16 if args.qdrant_vector_dtype == "float16" else np.float32
    embedder = VisualEmbedder(
        model_name=args.model,
        batch_size=args.batch_size,
        torch_dtype=_parse_torch_dtype(args.torch_dtype),
        output_dtype=output_dtype,
        processor_speed=str(args.processor_speed),
    )

    selected: List[Tuple[str, Any, Any, Dict[str, Dict[str, int]]]] = []
    for ds_name, corpus, queries, qrels in loaded:
        if args.sample_queries and int(args.sample_queries) > 0:
            queries = _sample_list(
                list(queries),
                k=int(args.sample_queries),
                strategy=str(args.sample_query_strategy),
                seed=int(args.sample_query_seed),
            )
            qrels = _filter_qrels(qrels, [q.query_id for q in queries])

        if args.index_from_queries and args.sample_queries and int(args.sample_queries) > 0:
            rel_doc_ids = set()
            for q in queries:
                for did, score in qrels.get(q.query_id, {}).items():
                    if score > 0:
                        rel_doc_ids.add(str(did))
            corpus = [d for d in corpus if str(d.doc_id) in rel_doc_ids]
        else:
            if args.sample_corpus_docs and int(args.sample_corpus_docs) > 0:
                corpus = _sample_list(
                    list(corpus),
                    k=int(args.sample_corpus_docs),
                    strategy=str(args.sample_corpus_strategy),
                    seed=int(args.sample_seed),
                )
            elif args.max_corpus_docs and int(args.max_corpus_docs) > 0:
                corpus = corpus[: int(args.max_corpus_docs)]

        selected.append((ds_name, corpus, queries, qrels))

    payload_indexes: List[Dict[str, str]] = []
    if args.index:
        payload_indexes = _parse_payload_indexes(args.payload_index)
        if args.index_common_metadata:
            payload_indexes.extend(
                [
                    {"field": "dataset", "type": "keyword"},
                    {"field": "source_doc_id", "type": "keyword"},
                    {"field": "doc_id", "type": "keyword"},
                    {"field": "filename", "type": "keyword"},
                    {"field": "page", "type": "keyword"},
                    {"field": "page_number", "type": "integer"},
                    {"field": "total_pages", "type": "integer"},
                    {"field": "has_text", "type": "bool"},
                    {"field": "text", "type": "text"},
                    {"field": "original_url", "type": "keyword"},
                    {"field": "resized_url", "type": "keyword"},
                    {"field": "original_width", "type": "integer"},
                    {"field": "original_height", "type": "integer"},
                    {"field": "resized_width", "type": "integer"},
                    {"field": "resized_height", "type": "integer"},
                    {"field": "crop_empty_enabled", "type": "bool"},
                    {"field": "crop_empty_remove_page_number", "type": "bool"},
                    {"field": "crop_empty_percentage_to_remove", "type": "float"},
                    {"field": "num_tiles", "type": "integer"},
                    {"field": "tile_rows", "type": "integer"},
                    {"field": "tile_cols", "type": "integer"},
                    {"field": "patches_per_tile", "type": "integer"},
                    {"field": "num_visual_tokens", "type": "integer"},
                    {"field": "processor_version", "type": "keyword"},
                    {"field": "year", "type": "integer"},
                    {"field": "source", "type": "keyword"},
                ]
            )
        for i, (ds_name, corpus, queries, _qrels) in enumerate(selected):
            print(f"Indexing {ds_name}: corpus_docs={len(corpus)} queries={len(queries)}")
            _index_beir_corpus(
                dataset_name=ds_name,
                corpus=corpus,
                embedder=embedder,
                collection_name=args.collection,
                prefer_grpc=args.prefer_grpc,
                qdrant_vector_dtype=args.qdrant_vector_dtype,
                recreate=bool(args.recreate and i == 0),
                indexing_threshold=args.indexing_threshold,
                batch_size=args.batch_size,
                upload_batch_size=args.upload_batch_size,
                upload_workers=args.upload_workers,
                upsert_wait=bool(args.upsert_wait),
                max_corpus_docs=0,
                sample_corpus_docs=0,
                sample_corpus_strategy=str(args.sample_corpus_strategy),
                sample_seed=int(args.sample_seed),
                payload_indexes=payload_indexes,
                union_namespace=args.collection,
                model_name=args.model,
                resume=bool(args.resume),
                qdrant_timeout=int(args.qdrant_timeout),
                full_scan_threshold=int(args.full_scan_threshold),
                crop_empty=bool(args.crop_empty),
                crop_empty_percentage_to_remove=float(args.crop_empty_percentage_to_remove),
                crop_empty_remove_page_number=bool(args.crop_empty_remove_page_number),
                crop_empty_preserve_border_px=int(args.crop_empty_preserve_border_px),
                crop_empty_uniform_std_threshold=float(args.crop_empty_uniform_std_threshold),
                no_cloudinary=bool(args.no_cloudinary),
                cloudinary_folder=str(args.cloudinary_folder),
                retry_failures=bool(args.retry_failures),
                only_failures=bool(args.only_failures),
            )

    out_path = _resolve_output_path(args.output, collection_name=str(args.collection))

    if bool(args.no_eval):
        dataset_index_failures: Dict[str, Dict[str, Any]] = {}
        dataset_counts: Dict[str, Dict[str, int]] = {}
        for ds_name, corpus, queries, _qrels in selected:
            dataset_counts[ds_name] = {"corpus_docs": int(len(corpus)), "queries": int(len(queries)), "queries_eval": 0}
            failed_path = _failed_log_path(collection_name=args.collection, dataset_name=ds_name)
            failed_ids = _load_failed_union_ids(failed_path, dataset_name=ds_name, union_namespace=args.collection)
            dataset_index_failures[ds_name] = {
                "failed_log_path": str(failed_path),
                "failed_ids_count": int(len(failed_ids)),
                "qrels_removed": None,
            }
        _write_json_atomic(
            out_path,
            {
                "command": " ".join(sys.argv),
                "dataset": datasets[0] if len(datasets) == 1 else None,
                "datasets": datasets,
                "protocol": "beir",
                "collection": args.collection,
                "model": args.model,
                "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                "qdrant_vector_dtype": args.qdrant_vector_dtype,
                "mode": args.mode,
                "stage1_mode": args.stage1_mode if args.mode == "two_stage" else None,
                "prefetch_k": args.prefetch_k if args.mode == "two_stage" else None,
                "top_k": args.top_k,
                "evaluation_scope": str(args.evaluation_scope),
                "dataset_counts": dataset_counts,
                "dataset_errors": {},
                "dataset_index_failures": dataset_index_failures,
                "qdrant_timeout": int(args.qdrant_timeout),
                "qdrant_retries": int(args.qdrant_retries),
                "qdrant_retry_sleep": float(args.qdrant_retry_sleep),
                "full_scan_threshold": int(args.full_scan_threshold),
                "no_eval": True,
                "eval_wall_time_s": None,
                "metrics": None,
                "metrics_by_dataset": {},
            },
        )
        print(f"Wrote index-only report: {out_path}")
        return

    retriever = MultiVectorRetriever(
        collection_name=args.collection,
        embedder=embedder,
        qdrant_url=os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL"),
        qdrant_api_key=(
            os.getenv("SIGIR_QDRANT_KEY")
            or os.getenv("SIGIR_QDRANT_API_KEY")
            or os.getenv("DEST_QDRANT_API_KEY")
            or os.getenv("QDRANT_API_KEY")
        ),
        prefer_grpc=args.prefer_grpc,
        request_timeout=int(args.qdrant_timeout),
        max_retries=int(args.qdrant_retries),
        retry_sleep=float(args.qdrant_retry_sleep),
    )

    metrics_by_dataset: Dict[str, Dict[str, float]] = {}
    dataset_errors: Dict[str, str] = {}
    dataset_counts: Dict[str, Dict[str, int]] = {}
    dataset_index_failures: Dict[str, Dict[str, Any]] = {}
    eval_started_at = time.time()

    def _build_run_record() -> Dict[str, Any]:
        single_dataset = datasets[0] if len(datasets) == 1 else None
        single_metrics = metrics_by_dataset.get(single_dataset) if single_dataset else None
        run_cmd = " ".join(sys.argv)
        return {
            "command": run_cmd,
            "dataset": single_dataset,
            "datasets": datasets,
            "protocol": "beir",
            "collection": args.collection,
            "model": args.model,
            "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
            "qdrant_vector_dtype": args.qdrant_vector_dtype,
            "mode": args.mode,
            "stage1_mode": args.stage1_mode if args.mode == "two_stage" else None,
            "prefetch_k": args.prefetch_k if args.mode == "two_stage" else None,
            "stage1_k": int(args.stage1_k) if args.mode == "three_stage" else None,
            "stage2_k": int(args.stage2_k) if args.mode == "three_stage" else None,
            "top_k": args.top_k,
            "max_queries": args.max_queries,
            "max_corpus_docs": int(args.max_corpus_docs),
            "sample_corpus_docs": int(args.sample_corpus_docs),
            "sample_corpus_strategy": str(args.sample_corpus_strategy),
            "sample_seed": int(args.sample_seed),
            "sample_queries": int(args.sample_queries),
            "sample_query_strategy": str(args.sample_query_strategy),
            "sample_query_seed": int(args.sample_query_seed),
            "index_from_queries": bool(args.index_from_queries),
            "drop_empty_queries": bool(args.drop_empty_queries),
            "evaluation_scope": str(args.evaluation_scope),
            "payload_indexes": payload_indexes,
            "dataset_counts": dataset_counts,
            "dataset_errors": dataset_errors,
            "dataset_index_failures": dataset_index_failures,
            "qdrant_timeout": int(args.qdrant_timeout),
            "qdrant_retries": int(args.qdrant_retries),
            "qdrant_retry_sleep": float(args.qdrant_retry_sleep),
            "full_scan_threshold": int(args.full_scan_threshold),
            "eval_wall_time_s": float(max(time.time() - eval_started_at, 0.0)),
            "metrics": single_metrics,
            "metrics_by_dataset": metrics_by_dataset,
        }

    for ds_name, corpus, queries, qrels in selected:
        print(
            f"Evaluating dataset={ds_name} "
            f"(corpus_docs={len(corpus)}, queries={len(queries)}) "
            f"scope={args.evaluation_scope} "
            f"mode={args.mode}"
            + (f", stage1_mode={args.stage1_mode}, prefetch_k={int(args.prefetch_k)}" if args.mode == "two_stage" else "")
            + (f", stage1_k={int(args.stage1_k)}, stage2_k={int(args.stage2_k)}" if args.mode == "three_stage" else "")
            + f", top_k={int(args.top_k)}"
        )
        sys.stdout.flush()

        dataset_counts[ds_name] = {"corpus_docs": int(len(corpus)), "queries": int(len(queries)), "queries_eval": 0}
        id_map: Dict[str, str] = {}
        for doc in corpus:
            source_doc_id = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
            id_map[str(doc.doc_id)] = _union_point_id(
                dataset_name=ds_name,
                source_doc_id=source_doc_id,
                union_namespace=args.collection,
            )

        remapped_qrels: Dict[str, Dict[str, int]] = {}
        for qid, rels in qrels.items():
            out_rels: Dict[str, int] = {}
            for did, score in rels.items():
                mapped = id_map.get(str(did))
                if mapped:
                    out_rels[mapped] = int(score)
            if out_rels:
                remapped_qrels[qid] = out_rels

        failed_path = _failed_log_path(collection_name=args.collection, dataset_name=ds_name)
        failed_ids = _load_failed_union_ids(failed_path, dataset_name=ds_name, union_namespace=args.collection)
        remapped_qrels, removed_rels = _remove_failed_from_qrels(remapped_qrels, failed_ids)
        dataset_index_failures[ds_name] = {
            "failed_log_path": str(failed_path),
            "failed_ids_count": int(len(failed_ids)),
            "qrels_removed": int(removed_rels),
        }

        filter_obj = None
        if args.evaluation_scope == "per_dataset":
            from qdrant_client.http import models as qmodels

            filter_obj = qmodels.Filter(
                must=[qmodels.FieldCondition(key="dataset", match=qmodels.MatchValue(value=str(ds_name)))]
            )

        try:
            metrics_by_dataset[ds_name] = _evaluate(
                queries=queries,
                qrels=remapped_qrels,
                retriever=retriever,
                embedder=embedder,
                top_k=args.top_k,
                prefetch_k=args.prefetch_k,
                mode=args.mode,
                stage1_mode=args.stage1_mode,
                stage1_k=int(args.stage1_k),
                stage2_k=int(args.stage2_k),
                max_queries=int(args.max_queries),
                drop_empty_queries=bool(args.drop_empty_queries),
                filter_obj=filter_obj,
            )
            dataset_counts[ds_name]["queries_eval"] = int(metrics_by_dataset[ds_name].get("num_queries_eval", 0))
            ds_only_out = {
                **_build_run_record(),
                "dataset": str(ds_name),
                "datasets": [str(ds_name)],
                "metrics": metrics_by_dataset[ds_name],
                "metrics_by_dataset": {str(ds_name): metrics_by_dataset[ds_name]},
            }
            per_ds_path = out_path.with_name(
                f"{out_path.stem}__{_safe_filename(ds_name)}{out_path.suffix}"
            )
            _write_json_atomic(per_ds_path, ds_only_out)
            print(f"Wrote dataset report: {per_ds_path}")
            print(json.dumps({str(ds_name): metrics_by_dataset[ds_name]}, indent=2))
            sys.stdout.flush()
        except Exception as e:
            dataset_errors[ds_name] = f"{type(e).__name__}: {e}"
            if not bool(args.continue_on_error):
                _write_json_atomic(out_path, _build_run_record())
                raise
        finally:
            _write_json_atomic(out_path, _build_run_record())

    single_dataset = datasets[0] if len(datasets) == 1 else None
    single_metrics = metrics_by_dataset.get(single_dataset) if single_dataset else None
    if len(datasets) == 1 and single_metrics is not None:
        print(json.dumps(single_metrics, indent=2))
    else:
        print(json.dumps(metrics_by_dataset, indent=2))


if __name__ == "__main__":
    main()
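
One detail worth calling out from the script: Qdrant point IDs are derived deterministically from the collection name (used as the union namespace), the dataset name, and the source document ID, which is what lets --resume skip already-indexed documents and lets the evaluator remap BEIR qrels onto point IDs. Below is a minimal, self-contained restatement of that ID scheme; the example dataset, document, and collection names are invented for illustration.

```python
# Standalone illustration of the point-ID scheme used by run_qdrant_beir.py:
# SHA-256 of "<namespace>::<source_doc_id>" truncated to 32 hex characters and
# formatted as a UUID, where <namespace> is "<collection>::<dataset>" when a
# union namespace (the collection name) is supplied.
import hashlib
from typing import Optional


def stable_uuid(text: str) -> str:
    hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
    return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"


def union_point_id(dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
    ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
    return stable_uuid(f"{ns}::{source_doc_id}")


# Example values are illustrative only; the same inputs always yield the same ID.
print(union_point_id("vidore_tatdqa_test", "doc_0001_page_3", "my_collection"))
```

Because the collection name is baked into the namespace, re-indexing the same corpus into a differently named collection produces different point IDs, and failure logs written under one collection cannot be matched against another.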