visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
demo/commands.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
"""Command builders and code generators."""
|
|
2
|
+
|
|
3
|
+
import shlex
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def build_index_command(config: Dict[str, Any]) -> str:
    """Build the shell command that indexes datasets with ``run_qdrant_beir``.

    Args:
        config: Indexing options. ``datasets``, ``collection`` and ``model``
            are required; every other key is optional and falls back to the
            default visible in the body.

    Returns:
        The command as one string, with each argument on its own
        backslash-continued line.

    Raises:
        KeyError: If ``datasets``, ``collection`` or ``model`` is missing.
    """

    def quoted(value: Any) -> str:
        # Values come from user-editable config; quote them so spaces or shell
        # metacharacters cannot split or alter the copied command.
        return shlex.quote(str(value))

    cmd_parts = ["python -m benchmarks.vidore_beir_qdrant.run_qdrant_beir"]
    cmd_parts.append("--datasets " + " ".join(quoted(ds) for ds in config["datasets"]))
    cmd_parts.append(f"--collection {quoted(config['collection'])}")
    cmd_parts.append(f"--model {quoted(config['model'])}")
    cmd_parts.append("--index")

    if config.get("recreate"):
        cmd_parts.append("--recreate")
    if config.get("resume"):
        cmd_parts.append("--resume")

    # The gRPC preference is always emitted explicitly, one way or the other.
    cmd_parts.append("--prefer-grpc" if config.get("prefer_grpc") else "--no-prefer-grpc")

    cmd_parts.append(f"--torch-dtype {quoted(config.get('torch_dtype', 'float16'))}")
    cmd_parts.append(
        f"--qdrant-vector-dtype {quoted(config.get('qdrant_vector_dtype', 'float16'))}"
    )
    cmd_parts.append(f"--batch-size {quoted(config.get('batch_size', 4))}")
    cmd_parts.append(f"--upload-batch-size {quoted(config.get('upload_batch_size', 8))}")
    cmd_parts.append(f"--qdrant-timeout {quoted(config.get('qdrant_timeout', 180))}")
    cmd_parts.append(f"--qdrant-retries {quoted(config.get('qdrant_retries', 5))}")

    if config.get("crop_empty"):
        cmd_parts.append("--crop-empty")
        cmd_parts.append(
            f"--crop-empty-percentage-to-remove {quoted(config.get('crop_percentage', 0.99))}"
        )
    if config.get("no_cloudinary"):
        cmd_parts.append("--no-cloudinary")

    max_docs = config.get("max_docs")
    if max_docs and max_docs > 0:
        cmd_parts.append(f"--max-corpus-docs {quoted(max_docs)}")

    # This builder is for indexing only, so evaluation is switched off.
    cmd_parts.append("--no-eval")
    return " \\\n ".join(cmd_parts)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def generate_python_index_code(config: Dict[str, Any]) -> str:
    """Render a Python script (as text) that indexes the configured datasets.

    The returned snippet embeds corpus images with ``VisualEmbedder`` and
    uploads them batch by batch through ``QdrantIndexer``. Nested blocks in
    the snippet use four-space indentation so the text parses as Python.

    Args:
        config: Indexing options; every key is optional and falls back to the
            default visible in the body.

    Returns:
        The generated source code, lines joined with ``"\\n"``.
    """
    datasets_str = ", ".join(f'"{ds}"' for ds in config.get("datasets", []))
    model = config.get("model", "vidore/colpali-v1.3")
    collection = config.get("collection", "")
    batch_size = config.get("batch_size", 4)
    prefer_grpc = config.get("prefer_grpc", True)
    crop_empty = config.get("crop_empty", False)
    max_docs = config.get("max_docs")
    # Same rule as build_index_command: only a positive limit is emitted.
    limit_docs = bool(max_docs) and max_docs > 0

    torch_dtype = config.get("torch_dtype", "float16")
    qdrant_dtype = config.get("qdrant_vector_dtype", "float16")

    torch_dtype_map = {
        "float16": "torch.float16",
        "float32": "torch.float32",
        "bfloat16": "torch.bfloat16",
    }
    # Unknown dtype names fall back to float16.
    torch_dtype_val = torch_dtype_map.get(torch_dtype, "torch.float16")

    code_lines = [
        "import os",
        "import torch",
        "from visual_rag import VisualEmbedder",
    ]
    if crop_empty:
        code_lines.append(
            "from visual_rag.preprocessing.crop_empty import crop_empty, CropEmptyConfig"
        )
    code_lines.extend([
        "from visual_rag.indexing import QdrantIndexer",
        "from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset",
        "",
        "# Configuration",
        f'COLLECTION = "{collection}"',
        f'MODEL = "{model}"',
        f"BATCH_SIZE = {batch_size}",
        f"DATASETS = [{datasets_str}]",
        f"TORCH_DTYPE = {torch_dtype_val}",
        f'QDRANT_DTYPE = "{qdrant_dtype}"',
    ])

    if limit_docs:
        code_lines.append(f"MAX_DOCS = {max_docs} # Limit docs per dataset")

    code_lines.extend([
        "",
        "# Initialize embedder",
        "embedder = VisualEmbedder(",
        "    model_name=MODEL,",
        "    torch_dtype=TORCH_DTYPE,",
        ")",
        "",
        "# Initialize indexer",
        "indexer = QdrantIndexer(",
        '    url=os.getenv("QDRANT_URL"),',
        '    api_key=os.getenv("QDRANT_API_KEY"),',
        "    collection_name=COLLECTION,",
        f"    prefer_grpc={prefer_grpc},",
        "    vector_datatype=QDRANT_DTYPE,",
        ")",
        "",
        "# Create collection",
        f"indexer.create_collection(force_recreate={config.get('recreate', False)})",
        "indexer.create_payload_indexes(fields=[",
        '    {"field": "dataset", "type": "keyword"},',
        '    {"field": "doc_id", "type": "keyword"},',
        '    {"field": "source_doc_id", "type": "keyword"},',
        "])",
        "",
        "# Index each dataset",
        "for ds_name in DATASETS:",
        "    print(f'Loading {ds_name}...')",
        "    corpus, queries, qrels = load_vidore_beir_dataset(ds_name)",
    ])

    if limit_docs:
        code_lines.append("    corpus = corpus[:MAX_DOCS] # Limit")

    code_lines.extend([
        "    print(f'Indexing {len(corpus)} documents...')",
        "",
        "    for i in range(0, len(corpus), BATCH_SIZE):",
        "        batch = corpus[i:i + BATCH_SIZE]",
        "        images = [doc.image for doc in batch]",
        "",
    ])

    if crop_empty:
        # Placed explicitly ahead of the embedding call; a fixed offset from
        # the end of the list used to land this note inside a later call.
        code_lines.append("        # Note: Add crop_empty() preprocessing before embedding")

    code_lines.extend([
        "        # Embed images",
        "        embeddings, token_infos = embedder.embed_images(",
        "            images, return_token_info=True",
        "        )",
        "",
        "        # Build points with multi-vector representations",
        "        points = []",
        "        for doc, emb, info in zip(batch, embeddings, token_infos):",
        "            emb_np = emb.cpu().numpy()",
        "            visual_idx = info.get('visual_token_indices', range(len(emb_np)))",
        "            visual_emb = emb_np[visual_idx]",
        "",
        "            tile_pooled = embedder.mean_pool_visual_embedding(visual_emb, info)",
        "            experimental = embedder.experimental_pool_visual_embedding(",
        "                visual_emb, info, mean_pool=tile_pooled",
        "            )",
        "            global_pooled = embedder.global_pool_from_mean_pool(tile_pooled)",
        "",
        "            points.append({",
        '                "id": f"{ds_name}_{doc.doc_id}",',
        '                "visual_embedding": visual_emb,',
        '                "tile_pooled_embedding": tile_pooled,',
        '                "experimental_pooled_embedding": experimental,',
        '                "global_pooled_embedding": global_pooled,',
        '                "metadata": {',
        '                    "dataset": ds_name,',
        '                    "doc_id": doc.doc_id,',
        '                    "source_doc_id": doc.payload.get("source_doc_id"),',
        "                },",
        "            })",
        "",
        "        indexer.upload_batch(points)",
        "        print(f' Batch {i//BATCH_SIZE + 1}: {len(points)} uploaded')",
        "",
        '    print(f"Done: {ds_name}")',
    ])

    return "\n".join(code_lines)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def build_eval_command(config: Dict[str, Any]) -> str:
    """Build the shell command that evaluates retrieval with ``run_qdrant_beir``.

    Args:
        config: Evaluation options. ``datasets``, ``collection``, ``model``
            and ``mode`` are required; every other key is optional and falls
            back to the default visible in the body.

    Returns:
        The command as one string, with each argument on its own
        backslash-continued line.

    Raises:
        KeyError: If one of the required keys is missing.
    """

    def quoted(value: Any) -> str:
        # Values come from user-editable config; quote them so spaces or shell
        # metacharacters cannot split or alter the copied command.
        return shlex.quote(str(value))

    cmd_parts = ["python -m benchmarks.vidore_beir_qdrant.run_qdrant_beir"]
    cmd_parts.append("--datasets " + " ".join(quoted(ds) for ds in config["datasets"]))
    cmd_parts.append(f"--collection {quoted(config['collection'])}")
    cmd_parts.append(f"--model {quoted(config['model'])}")

    mode = config["mode"]
    cmd_parts.append(f"--mode {quoted(mode)}")

    # Only the multi-stage modes take extra stage-sizing arguments.
    if mode == "two_stage":
        cmd_parts.append(f"--stage1-mode {quoted(config.get('stage1_mode', 'tokens_vs_tiles'))}")
        cmd_parts.append(f"--prefetch-k {quoted(config.get('prefetch_k', 256))}")
    elif mode == "three_stage":
        cmd_parts.append(f"--stage1-k {quoted(config.get('stage1_k', 1000))}")
        cmd_parts.append(f"--stage2-k {quoted(config.get('stage2_k', 300))}")

    cmd_parts.append(f"--top-k {quoted(config.get('top_k', 100))}")
    cmd_parts.append(f"--evaluation-scope {quoted(config.get('evaluation_scope', 'union'))}")

    # The gRPC preference is always emitted explicitly, one way or the other.
    cmd_parts.append("--prefer-grpc" if config.get("prefer_grpc") else "--no-prefer-grpc")

    cmd_parts.append(f"--torch-dtype {quoted(config.get('torch_dtype', 'float16'))}")
    cmd_parts.append(
        f"--qdrant-vector-dtype {quoted(config.get('qdrant_vector_dtype', 'float16'))}"
    )
    cmd_parts.append(f"--qdrant-timeout {quoted(config.get('qdrant_timeout', 180))}")

    if config.get("result_prefix"):
        cmd_parts.append(f"--output {quoted(config['result_prefix'])}")

    return " \\\n ".join(cmd_parts)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def generate_python_eval_code(config: Dict[str, Any]) -> str:
    """Render a Python script (as text) that runs retrieval for one mode.

    The snippet creates a Qdrant client, an embedder and a retriever, defines
    a ``search(query)`` helper for the selected mode, and ends with a short
    usage example. For a mode outside the five handled below no ``search``
    helper is emitted, although the usage example still refers to it.

    Args:
        config: Evaluation options; every key is optional and falls back to
            the default visible in the body.

    Returns:
        The generated source code, lines joined with ``"\\n"``. Function and
        loop bodies use four-space indentation so the text parses as Python.
    """
    datasets_str = ", ".join(f'"{ds}"' for ds in config.get("datasets", []))
    mode = config.get("mode", "single_full")
    model = config.get("model", "vidore/colpali-v1.3")
    collection = config.get("collection", "")
    top_k = config.get("top_k", 100)
    scope = config.get("evaluation_scope", "union")
    prefer_grpc = config.get("prefer_grpc", True)
    torch_dtype = config.get("torch_dtype", "float16")

    torch_dtype_map = {
        "float16": "torch.float16",
        "float32": "torch.float32",
        "bfloat16": "torch.bfloat16",
    }
    # Unknown dtype names fall back to float16.
    torch_dtype_val = torch_dtype_map.get(torch_dtype, "torch.float16")

    code_lines = [
        "import os",
        "import torch",
        "from qdrant_client import QdrantClient",
        "from visual_rag import VisualEmbedder",
        "from visual_rag.retrieval import MultiVectorRetriever",
        "",
        "# Configuration",
        f'COLLECTION = "{collection}"',
        f'MODEL = "{model}"',
        f"TOP_K = {top_k}",
        f"DATASETS = [{datasets_str}]",
        f"TORCH_DTYPE = {torch_dtype_val}",
        "",
        "# Initialize clients",
        "client = QdrantClient(",
        '    url=os.getenv("QDRANT_URL"),',
        '    api_key=os.getenv("QDRANT_API_KEY"),',
        f"    prefer_grpc={prefer_grpc},",
        ")",
        "",
        "embedder = VisualEmbedder(",
        "    model_name=MODEL,",
        "    torch_dtype=TORCH_DTYPE,",
        ")",
        "",
        "# Initialize retriever",
        "retriever = MultiVectorRetriever(",
        "    client=client,",
        "    collection_name=COLLECTION,",
        "    embedder=embedder,",
        ")",
        "",
    ]

    # The three single-stage modes differ only in their heading and in the
    # named vector they query: mode -> (heading word, vector name).
    single_stage_vectors = {
        "single_full": ("full", "initial"),
        "single_tiles": ("tiles", "mean_pooling"),
        "single_global": ("global", "global_pooling"),
    }

    if mode in single_stage_vectors:
        label, vector_name = single_stage_vectors[mode]
        code_lines.extend([
            f"# Single-stage {label} retrieval",
            "def search(query: str):",
            "    query_embedding = embedder.embed_query(query)",
            "    return retriever.search_single_stage(",
            "        query_embedding=query_embedding,",
            f"        limit={top_k},",
            f'        vector_name="{vector_name}",',
            "    )",
        ])
    elif mode == "two_stage":
        prefetch_k = config.get("prefetch_k", 256)
        stage1_mode = config.get("stage1_mode", "tokens_vs_tiles")
        code_lines.extend([
            "# Two-stage retrieval",
            "from visual_rag.retrieval import TwoStageRetriever",
            "",
            "two_stage = TwoStageRetriever(",
            "    client=client,",
            "    collection_name=COLLECTION,",
            "    embedder=embedder,",
            ")",
            "",
            "def search(query: str):",
            "    query_embedding = embedder.embed_query(query)",
            "    return two_stage.search(",
            "        query_embedding=query_embedding,",
            f"        prefetch_limit={prefetch_k},",
            f"        limit={top_k},",
            f'        stage1_mode="{stage1_mode}",',
            "    )",
        ])
    elif mode == "three_stage":
        stage1_k = config.get("stage1_k", 1000)
        stage2_k = config.get("stage2_k", 300)
        code_lines.extend([
            "# Three-stage retrieval",
            "from visual_rag.retrieval import ThreeStageRetriever",
            "",
            "three_stage = ThreeStageRetriever(",
            "    client=client,",
            "    collection_name=COLLECTION,",
            "    embedder=embedder,",
            ")",
            "",
            "def search(query: str):",
            "    query_embedding = embedder.embed_query(query)",
            "    return three_stage.search(",
            "        query_embedding=query_embedding,",
            f"        stage1_limit={stage1_k},",
            f"        stage2_limit={stage2_k},",
            f"        limit={top_k},",
            "    )",
        ])

    if scope == "per_dataset":
        # Emits a starting point only: the generated helper builds the filter
        # and leaves wiring it into a search call to the reader.
        code_lines.extend([
            "",
            "# Per-dataset filtering",
            "from qdrant_client.models import Filter, FieldCondition, MatchValue",
            "",
            'def search_dataset(query: str, dataset: str = "vidore/esg_reports_v2"):',
            "    query_embedding = embedder.embed_query(query)",
            "    dataset_filter = Filter(",
            "        must=[FieldCondition(",
            '            key="dataset",',
            "            match=MatchValue(value=dataset),",
            "        )]",
            "    )",
            "    # Add filter to your search call",
        ])

    code_lines.extend([
        "",
        "# Example usage",
        'results = search("What is the company revenue?")',
        "for r in results:",
        '    print(f"Score: {r.score:.4f}, Doc: {r.payload.get(\'doc_id\')}")',
    ])

    return "\n".join(code_lines)
|
demo/config.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Configuration constants for the demo app."""
|
|
2
|
+
|
|
3
|
+
# Model identifiers.
AVAILABLE_MODELS = ["vidore/colpali-v1.3", "vidore/colSmol-500M"]

# Document and query counts for each benchmark dataset.
DATASET_STATS = {
    "vidore/esg_reports_v2": dict(docs=1538, queries=228),
    "vidore/biomedical_lectures_v2": dict(docs=1016, queries=640),
    "vidore/economics_reports_v2": dict(docs=452, queries=232),
}

# Benchmark dataset identifiers: the DATASET_STATS keys, in the same order.
BENCHMARK_DATASETS = list(DATASET_STATS)

# Retrieval mode names.
RETRIEVAL_MODES = ["single_full", "single_tiles", "single_global"]
RETRIEVAL_MODES += ["two_stage", "three_stage"]

# Stage-1 mode names: per-token variants first, then pooled-query variants.
STAGE1_MODES = [f"tokens_vs_{target}" for target in ("tiles", "experimental")]
STAGE1_MODES += [
    f"pooled_query_vs_{target}" for target in ("tiles", "experimental", "global")
]
|
demo/download_models.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Pre-download HuggingFace models for Visual RAG Toolkit.
|
|
3
|
+
|
|
4
|
+
This script downloads models during Docker build to cache them in the image,
|
|
5
|
+
avoiding download delays during container startup.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
|
|
11
|
+
# Fill in the cache locations only when the environment has not already
# chosen them; both variables get the same directory.
if "HF_HOME" not in os.environ:
    os.environ["HF_HOME"] = "/app/.cache/huggingface"
if "TRANSFORMERS_CACHE" not in os.environ:
    os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"

# Model identifiers fetched by download_colpali_models().
MODELS_TO_DOWNLOAD = ["vidore/colpali-v1.3", "vidore/colSmol-500M"]
|
|
18
|
+
|
|
19
|
+
def _download_with_transformers(model_name):
    """Fetch *model_name* and its processor through the transformers Auto classes."""
    from transformers import AutoModel, AutoProcessor

    # NOTE(security): trust_remote_code=True lets the model repository run its
    # own Python code at load time; keep MODELS_TO_DOWNLOAD to trusted repos.
    AutoModel.from_pretrained(model_name, trust_remote_code=True)
    AutoProcessor.from_pretrained(model_name, trust_remote_code=True)


def _download_with_colpali_engine(model_name):
    """Fetch *model_name* and its processor through the matching colpali-engine classes."""
    if "colsmol" in model_name.lower():
        # colSmol checkpoints are Idefics3-based, so they load through the
        # ColIdefics3 classes rather than ColQwen2.
        from colpali_engine.models import ColIdefics3, ColIdefics3Processor

        model_cls, processor_cls = ColIdefics3, ColIdefics3Processor
    else:
        from colpali_engine.models import ColPali, ColPaliProcessor

        model_cls, processor_cls = ColPali, ColPaliProcessor

    model_cls.from_pretrained(model_name, trust_remote_code=True)
    processor_cls.from_pretrained(model_name, trust_remote_code=True)


def download_colpali_models():
    """Download every model in MODELS_TO_DOWNLOAD together with its processor.

    colpali-engine is used when it can be imported, with a per-model fallback
    to plain transformers. When colpali-engine is missing, only transformers
    is tried. Download errors are reported on stdout and never raised.

    Returns:
        list: Names of the models that could not be downloaded by any route
        (empty when everything succeeded).
    """
    print("=" * 60)
    print("Downloading ColPali models for Visual RAG Toolkit")
    print("=" * 60)

    failed = []

    try:
        from colpali_engine.models import ColPali, ColPaliProcessor  # noqa: F401
    except ImportError:
        print("[WARN] colpali-engine not installed, trying transformers directly")
        for model_name in MODELS_TO_DOWNLOAD:
            print(f"\n[INFO] Downloading model: {model_name}")
            try:
                _download_with_transformers(model_name)
            except Exception as e:  # best-effort: report and continue
                print(f"[WARN] Could not download {model_name}: {e}")
                failed.append(model_name)
            else:
                print(f"[OK] Downloaded: {model_name}")
        return failed

    for model_name in MODELS_TO_DOWNLOAD:
        print(f"\n[INFO] Downloading model: {model_name}")
        try:
            _download_with_colpali_engine(model_name)
        except Exception as e:  # best-effort: fall back to transformers
            print(f"[WARN] Could not download {model_name} with colpali-engine: {e}")
        else:
            print(f"[OK] Downloaded: {model_name}")
            continue

        try:
            _download_with_transformers(model_name)
        except Exception as e2:  # best-effort: report and continue
            print(f"[ERROR] Failed to download {model_name}: {e2}")
            failed.append(model_name)
        else:
            print(f"[OK] Downloaded via transformers: {model_name}")

    return failed


def main():
    """Print the cache settings, download the models and summarise the outcome.

    The process still exits normally when downloads fail; the final banner
    names the models that are missing from the cache.
    """
    print(f"[INFO] HF_HOME: {os.environ.get('HF_HOME', 'not set')}")
    print(f"[INFO] Cache dir: {os.environ.get('TRANSFORMERS_CACHE', 'not set')}")

    failed = download_colpali_models() or []

    print("\n" + "=" * 60)
    if failed:
        print(f"Model download finished with failures: {', '.join(failed)}")
    else:
        print("Model download complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()
|