visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
visual_rag/cli/main.py
ADDED
|
@@ -0,0 +1,629 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Visual RAG Toolkit CLI
|
|
4
|
+
|
|
5
|
+
Provides command-line interface for:
|
|
6
|
+
- Processing PDFs (embedding, Cloudinary upload, Qdrant indexing)
|
|
7
|
+
- Searching documents
|
|
8
|
+
- Managing collections
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
# Process PDFs (like process_pdfs_saliency_v2.py)
|
|
12
|
+
visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
|
|
13
|
+
|
|
14
|
+
# Search
|
|
15
|
+
visual-rag search --query "budget allocation" --collection my_docs
|
|
16
|
+
|
|
17
|
+
# Show collection info
|
|
18
|
+
visual-rag info --collection my_docs
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import argparse
|
|
22
|
+
import logging
|
|
23
|
+
import os
|
|
24
|
+
import sys
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from urllib.parse import urlparse
|
|
27
|
+
|
|
28
|
+
from dotenv import load_dotenv
|
|
29
|
+
|
|
30
|
+
# Module-level logger; its records propagate to the root logger, whose level
# and format are configured by setup_logging().
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def setup_logging(debug: bool = False):
    """Configure root logging for the CLI.

    Args:
        debug: Log at DEBUG level when True, otherwise at INFO level.
    """
    chosen_level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=chosen_level,
        # force=True removes any handlers already attached to the root logger,
        # so calling this more than once re-applies the configuration.
        force=True,
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _find_pdf_files(reports_dir: Path) -> list:
    """Return the PDF files directly inside ``reports_dir``, sorted by file name.

    The extension check is case-insensitive (``.pdf``, ``.PDF``, ``.Pdf``, ...)
    and every file appears exactly once. Globbing ``*.pdf`` and ``*.PDF``
    separately would list each file twice wherever glob matching is
    case-insensitive (e.g. Windows), and would miss mixed-case extensions.
    Subdirectories are not searched. Directories whose names end in ``.pdf``
    are skipped.
    """
    return sorted(
        (p for p in reports_dir.iterdir() if p.suffix.lower() == ".pdf" and p.is_file()),
        key=lambda p: p.name,
    )


def _infer_value_type(values) -> str:
    """Map the first bool/int/float in ``values`` to a payload index type name.

    Returns ``"bool"``, ``"integer"`` or ``"float"`` for the first value of one
    of those Python types. Values of any other type (str, None, lists, ...)
    are skipped. Falls back to ``"keyword"`` when no such value exists.
    """
    for value in values:
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return "bool"
        if isinstance(value, int):
            return "integer"
        if isinstance(value, float):
            return "float"
    return "keyword"


def _infer_payload_fields(metadata_mapping) -> list:
    """Build the field specs passed to ``indexer.create_payload_indexes``.

    The result always starts with ``filename`` (keyword), ``page_number``
    (integer) and ``has_text`` (bool). Every other key found in any
    dict-valued entry of ``metadata_mapping`` is then added in sorted order.
    Its type comes from ``_infer_value_type`` over that key's values, scanned
    in mapping order. Non-dict entries are ignored. An empty or None mapping
    yields only the three base fields.
    """
    fields = [
        {"field": "filename", "type": "keyword"},
        {"field": "page_number", "type": "integer"},
        {"field": "has_text", "type": "bool"},
    ]
    reserved = {spec["field"] for spec in fields}
    metas = [meta for meta in (metadata_mapping or {}).values() if isinstance(meta, dict)]
    extra_keys = sorted({key for meta in metas for key in meta} - reserved)
    for key in extra_keys:
        fields.append({"field": key, "type": _infer_value_type(meta.get(key) for meta in metas)})
    return fields


def cmd_process(args):
    """
    Process PDFs: convert → embed → upload to Cloudinary → index in Qdrant.

    Equivalent to process_pdfs_saliency_v2.py

    Qdrant connection settings are read from the environment (a ``.env`` file
    is loaded first): SIGIR_QDRANT_URL / DEST_QDRANT_URL / QDRANT_URL and the
    matching API-key variables. Exits with status 1 when the reports
    directory is missing, contains no PDFs, or no Qdrant URL is configured.
    """
    from visual_rag import CloudinaryUploader, QdrantIndexer, VisualEmbedder, load_config
    from visual_rag.indexing.pipeline import ProcessingPipeline

    # Load environment variables from .env (used for the Qdrant settings below)
    load_dotenv()

    # Load optional config file; a missing path is silently ignored.
    config = {}
    if args.config and Path(args.config).exists():
        config = load_config(args.config)

    # Get PDFs. is_dir() also rejects a path that points at a regular file.
    reports_dir = Path(args.reports_dir)
    if not reports_dir.is_dir():
        logger.error(f"❌ Reports directory not found: {reports_dir}")
        sys.exit(1)

    pdf_paths = _find_pdf_files(reports_dir)
    if not pdf_paths:
        logger.error(f"❌ No PDF files found in: {reports_dir}")
        sys.exit(1)

    logger.info(f"📁 Found {len(pdf_paths)} PDF files")

    # Load metadata mapping
    metadata_mapping = {}
    if args.metadata_file:
        metadata_mapping = ProcessingPipeline.load_metadata_mapping(Path(args.metadata_file))

    # Dry run - just show summary
    if args.dry_run:
        logger.info("🏃 DRY RUN MODE")
        logger.info(f" PDFs: {len(pdf_paths)}")
        logger.info(f" Metadata entries: {len(metadata_mapping)}")
        logger.info(f" Collection: {args.collection}")
        logger.info(f" Cloudinary: {'ENABLED' if not args.no_cloudinary else 'DISABLED'}")

        for pdf in pdf_paths[:10]:
            has_meta = "✓" if pdf.stem.lower() in metadata_mapping else "✗"
            logger.info(f" {has_meta} {pdf.name}")
        if len(pdf_paths) > 10:
            logger.info(f" ... and {len(pdf_paths) - 10} more")
        return

    # Get settings.
    # NOTE(review): the argparse definitions give --model and --collection
    # non-None defaults, so these config-file fallbacks are only reached if
    # those defaults are removed.
    model_name = args.model or config.get("model", {}).get("name", "vidore/colSmol-500M")
    collection_name = args.collection or config.get("qdrant", {}).get(
        "collection_name", "visual_documents"
    )

    # "auto" leaves torch_dtype as None so VisualEmbedder picks the dtype itself.
    torch_dtype = None
    if args.torch_dtype != "auto":
        import torch

        torch_dtype = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[args.torch_dtype]

    logger.info(f"🤖 Initializing embedder: {model_name}")
    embedder = VisualEmbedder(
        model_name=model_name,
        batch_size=args.batch_size,
        torch_dtype=torch_dtype,
        processor_speed=str(getattr(args, "processor_speed", "fast")),
    )

    # Initialize Qdrant indexer (first non-empty env var wins)
    qdrant_url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )

    if not qdrant_url:
        logger.error("❌ QDRANT_URL environment variable not set")
        sys.exit(1)

    logger.info(f"🔌 Connecting to Qdrant: {qdrant_url}")
    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=args.prefer_grpc,
        vector_datatype=args.qdrant_vector_dtype,
    )

    # Create collection if needed, then index the base fields plus every
    # metadata key (type inferred from the metadata values).
    indexer.create_collection(force_recreate=args.force_recreate)
    indexer.create_payload_indexes(fields=_infer_payload_fields(metadata_mapping))

    # Initialize Cloudinary uploader (optional)
    cloudinary_uploader = None
    if not args.no_cloudinary:
        try:
            project_name = config.get("project_name", "visual_docs")
            cloudinary_uploader = CloudinaryUploader(folder=project_name)
        except ValueError as e:
            logger.warning(f"⚠️ Cloudinary not configured: {e}")
            logger.warning(" Continuing without Cloudinary uploads")
    # Only request uploads when an uploader actually exists. If construction
    # failed above, asking the pipeline to upload would contradict the
    # "Continuing without Cloudinary uploads" message.
    upload_to_cloudinary = cloudinary_uploader is not None

    # Create pipeline
    pipeline = ProcessingPipeline(
        embedder=embedder,
        indexer=indexer,
        cloudinary_uploader=cloudinary_uploader,
        metadata_mapping=metadata_mapping,
        config=config,
        embedding_strategy=args.strategy,
        crop_empty=bool(getattr(args, "crop_empty", False)),
        crop_empty_percentage_to_remove=float(
            getattr(args, "crop_empty_percentage_to_remove", 0.9)
        ),
        crop_empty_remove_page_number=bool(getattr(args, "crop_empty_remove_page_number", False)),
    )

    # Process PDFs
    total_uploaded = 0
    total_skipped = 0
    total_failed = 0

    skip_existing = not args.no_skip_existing

    for pdf_idx, pdf_path in enumerate(pdf_paths, 1):
        logger.info(f"\n{'='*60}")
        logger.info(f"📄 [{pdf_idx}/{len(pdf_paths)}] {pdf_path.name}")
        logger.info(f"{'='*60}")

        result = pipeline.process_pdf(
            pdf_path,
            skip_existing=skip_existing,
            upload_to_cloudinary=upload_to_cloudinary,
            upload_to_qdrant=True,
        )

        total_uploaded += result["uploaded"]
        total_skipped += result["skipped"]
        total_failed += result["failed"]

    # Summary
    logger.info(f"\n{'='*60}")
    logger.info("📊 SUMMARY")
    logger.info(f"{'='*60}")
    logger.info(f" Total PDFs: {len(pdf_paths)}")
    logger.info(f" Uploaded: {total_uploaded}")
    logger.info(f" Skipped: {total_skipped}")
    logger.info(f" Failed: {total_failed}")

    info = indexer.get_collection_info()
    if info:
        logger.info(f" Collection points: {info.get('points_count', 'N/A')}")
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _result_score(result):
    """Return the score to display for one retriever result dict.

    Checks ``score_final``, then ``score_stage1``, then ``score``, and returns
    the first value that is not None. Returns 0 when none is present. The old
    lookup returned None (and crashed the ``:.4f`` formatting) when
    ``score_final`` was present but None.
    """
    # NOTE(review): the plain "score" key is a defensive fallback; confirm
    # which keys SingleStageRetriever results actually carry.
    for key in ("score_final", "score_stage1", "score"):
        value = result.get(key)
        if value is not None:
            return value
    return 0


def _text_snippet(text, limit: int = 200) -> str:
    """Return the first ``limit`` characters of ``text`` on a single line.

    Newlines become spaces. "..." is appended only when the text was actually
    truncated.
    """
    snippet = text[:limit].replace("\n", " ")
    return snippet + "..." if len(text) > limit else snippet


def cmd_search(args):
    """Search documents.

    Embeds ``args.query`` with ``args.model``, runs the retrieval strategy
    named by ``args.strategy`` against ``args.collection`` (optionally
    filtered by year/source/district), and logs the ranked results. Exits
    with status 1 if no Qdrant URL is configured in the environment.
    """
    from qdrant_client import QdrantClient

    from visual_rag import VisualEmbedder
    from visual_rag.retrieval import SingleStageRetriever, TwoStageRetriever

    load_dotenv()

    # First non-empty env var wins.
    qdrant_url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )

    if not qdrant_url:
        logger.error("❌ QDRANT_URL not set")
        sys.exit(1)

    # Initialize
    logger.info(f"🤖 Loading model: {args.model}")
    embedder = VisualEmbedder(
        model_name=args.model, processor_speed=str(getattr(args, "processor_speed", "fast"))
    )

    logger.info("🔌 Connecting to Qdrant")
    # Qdrant's default ports are 6333 (REST) and 6334 (gRPC). When gRPC is
    # preferred and the URL names the REST port, point gRPC at 6334.
    grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
    client = QdrantClient(
        url=qdrant_url,
        api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
        grpc_port=grpc_port,
        check_compatibility=False,
    )
    two_stage = TwoStageRetriever(client, args.collection)
    single_stage = SingleStageRetriever(client, args.collection)

    # Embed query
    logger.info(f"🔍 Query: {args.query}")
    query_embedding = embedder.embed_query(args.query)

    # Build filter (shared by both retrievers)
    filter_obj = None
    if args.year or args.source or args.district:
        filter_obj = two_stage.build_filter(
            year=args.year,
            source=args.source,
            district=args.district,
        )

    # Search: the single_* CLI strategies map onto SingleStageRetriever modes;
    # anything else ("two_stage") uses the two-stage retriever.
    single_stage_modes = {
        "single_full": "multi_vector",
        "single_tiles": "tiles_maxsim",
        "single_global": "pooled_global",
    }
    query_np = query_embedding.detach().cpu().numpy()
    if args.strategy in single_stage_modes:
        results = single_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            strategy=single_stage_modes[args.strategy],
            filter_obj=filter_obj,
        )
    else:
        results = two_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            prefetch_k=args.prefetch_k,
            filter_obj=filter_obj,
            stage1_mode=args.stage1_mode,
        )

    # Display results
    logger.info(f"\n📊 Results ({len(results)}):")
    for i, result in enumerate(results, 1):
        # "or {}" also covers an explicit None payload, not just a missing key.
        payload = result.get("payload") or {}
        score = _result_score(result)

        filename = payload.get("filename", "N/A")
        page_num = payload.get("page_number", "N/A")
        year = payload.get("year", "N/A")
        source = payload.get("source", "N/A")

        logger.info(f" {i}. {filename} p.{page_num}")
        logger.info(f" Score: {score:.4f} | Year: {year} | Source: {source}")

        # Text snippet
        text = payload.get("text", "")
        if text and args.show_text:
            logger.info(f" Text: {_text_snippet(text)}")
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _describe_vectors(vectors):
    """Return a short printable description of a collection's vector config.

    ``vectors`` may be a mapping of named vector configs, a single unnamed
    vector-config object (qdrant-client allows either form), or empty/None.
    A mapping is rendered as the list of its names. An unnamed config is
    rendered with its ``size`` when that attribute exists. Returns None when
    there is nothing to report.
    """
    if not vectors:
        return None
    if hasattr(vectors, "keys"):
        return str(list(vectors.keys()))
    size = getattr(vectors, "size", None)
    return f"<unnamed vector, size={size}>" if size is not None else "<unnamed vector>"


def cmd_info(args):
    """Show collection info.

    Logs status, point count, indexed-vector count and vector configuration
    for ``args.collection``. Exits with status 1 if no Qdrant URL is set or
    the collection cannot be read.
    """
    from qdrant_client import QdrantClient

    load_dotenv()

    # First non-empty env var wins.
    qdrant_url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )

    if not qdrant_url:
        logger.error("❌ QDRANT_URL not set")
        sys.exit(1)

    # Qdrant's default ports are 6333 (REST) and 6334 (gRPC). When gRPC is
    # preferred and the URL names the REST port, point gRPC at 6334.
    grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
    client = QdrantClient(
        url=qdrant_url,
        api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
        grpc_port=grpc_port,
        check_compatibility=False,
    )

    try:
        info = client.get_collection(args.collection)

        # Status may be an enum; show its plain value.
        status = info.status
        if hasattr(status, "value"):
            status = status.value

        # Handles both a plain count and a per-vector-name mapping.
        indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
        if isinstance(indexed_count, dict):
            indexed_count = sum(indexed_count.values())

        logger.info(f"📊 Collection: {args.collection}")
        logger.info(f" Status: {status}")
        logger.info(f" Points: {info.points_count}")
        logger.info(f" Indexed vectors: {indexed_count}")

        # Show vector config. A single unnamed vector config has no .keys(), so
        # calling it unconditionally used to abort here with a misleading error.
        if hasattr(info, "config") and hasattr(info.config, "params"):
            vectors_desc = _describe_vectors(getattr(info.config.params, "vectors", {}))
            if vectors_desc is not None:
                logger.info(f" Vectors: {vectors_desc}")

    except Exception as e:
        logger.error(f"❌ Could not get collection info: {e}")
        sys.exit(1)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def _add_collection_arg(subparser):
    """Add the ``--collection`` option shared by every subcommand."""
    subparser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )


def _add_processor_speed_arg(subparser):
    """Add the ``--processor-speed`` option shared by ``process`` and ``search``."""
    subparser.add_argument(
        "--processor-speed",
        type=str,
        default="fast",
        choices=["fast", "slow", "auto"],
        help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
    )


def _add_grpc_flags(subparser):
    """Add the mutually exclusive ``--prefer-grpc`` / ``--no-prefer-grpc`` pair.

    Both options write ``args.prefer_grpc``, which defaults to True. Passing
    both on one command line is rejected by argparse.
    """
    grpc_group = subparser.add_mutually_exclusive_group()
    grpc_group.add_argument(
        "--prefer-grpc",
        dest="prefer_grpc",
        action="store_true",
        default=True,
        help="Use gRPC for Qdrant client (recommended).",
    )
    grpc_group.add_argument(
        "--no-prefer-grpc",
        dest="prefer_grpc",
        action="store_false",
        help="Disable gRPC for Qdrant client.",
    )


def main():
    """Main CLI entry point.

    Builds the ``process``, ``search`` and ``info`` subcommands, configures
    logging, and dispatches to the handler each subcommand registers via
    ``set_defaults(func=...)``. Prints help and exits with status 0 when no
    subcommand is given.
    """
    parser = argparse.ArgumentParser(
        prog="visual-rag",
        description="Visual RAG Toolkit - Visual document retrieval with ColPali",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Process PDFs (like process_pdfs_saliency_v2.py)
visual-rag process --reports-dir ./pdfs --metadata-file metadata.json

# Process without Cloudinary
visual-rag process --reports-dir ./pdfs --no-cloudinary

# Search
visual-rag search --query "budget allocation" --collection my_docs

# Search with filters
visual-rag search --query "budget" --year 2023 --source "Local Government"

# Show collection info
visual-rag info --collection my_docs
""",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")

    subparsers = parser.add_subparsers(dest="command", help="Command")

    # =========================================================================
    # PROCESS command
    # =========================================================================
    process_parser = subparsers.add_parser(
        "process",
        help="Process PDFs: embed, upload to Cloudinary, index in Qdrant",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    process_parser.add_argument(
        "--reports-dir", type=str, required=True, help="Directory containing PDF files"
    )
    process_parser.add_argument(
        "--metadata-file",
        type=str,
        help="JSON file with filename → metadata mapping (like filename_metadata.json)",
    )
    _add_collection_arg(process_parser)
    process_parser.add_argument(
        "--model",
        type=str,
        default="vidore/colSmol-500M",
        help="Model name (vidore/colSmol-500M, vidore/colpali-v1.3, etc.)",
    )
    process_parser.add_argument("--batch-size", type=int, default=8, help="Embedding batch size")
    process_parser.add_argument("--config", type=str, help="Path to config.yaml file")
    process_parser.add_argument(
        "--no-cloudinary", action="store_true", help="Skip Cloudinary uploads"
    )
    process_parser.add_argument(
        "--crop-empty",
        action="store_true",
        help="Crop empty whitespace from page images before embedding (default: off).",
    )
    process_parser.add_argument(
        "--crop-empty-percentage-to-remove",
        type=float,
        default=0.9,
        help="Kept for traceability; currently does not affect cropping behavior (default: 0.9).",
    )
    process_parser.add_argument(
        "--crop-empty-remove-page-number",
        action="store_true",
        help="If set, attempts to crop away the bottom region that contains sparse page numbers (default: off).",
    )
    process_parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="Process all pages even if they exist in Qdrant",
    )
    process_parser.add_argument(
        "--force-recreate", action="store_true", help="Delete and recreate collection"
    )
    process_parser.add_argument(
        "--dry-run", action="store_true", help="Show what would be processed without doing it"
    )
    process_parser.add_argument(
        "--strategy",
        type=str,
        default="pooling",
        choices=["pooling", "standard", "all"],
        help="Embedding strategy: 'pooling' (NOVEL), 'standard' (BASELINE), "
        "'all' (embed once, store BOTH for comparison)",
    )
    process_parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Torch dtype for model weights (default: auto; CUDA->bfloat16, else float32).",
    )
    process_parser.add_argument(
        "--qdrant-vector-dtype",
        type=str,
        default="float16",
        choices=["float16", "float32"],
        help="Datatype for vectors stored in Qdrant (default: float16).",
    )
    _add_processor_speed_arg(process_parser)
    _add_grpc_flags(process_parser)
    process_parser.set_defaults(func=cmd_process)

    # =========================================================================
    # SEARCH command
    # =========================================================================
    search_parser = subparsers.add_parser(
        "search",
        help="Search documents",
    )
    search_parser.add_argument("--query", type=str, required=True, help="Search query")
    _add_collection_arg(search_parser)
    search_parser.add_argument(
        "--model", type=str, default="vidore/colSmol-500M", help="Model name"
    )
    _add_processor_speed_arg(search_parser)
    search_parser.add_argument("--top-k", type=int, default=10, help="Number of results")
    search_parser.add_argument(
        "--strategy",
        type=str,
        default="single_full",
        choices=["single_full", "single_tiles", "single_global", "two_stage"],
        help="Search strategy",
    )
    search_parser.add_argument(
        "--prefetch-k", type=int, default=200, help="Prefetch candidates for two-stage retrieval"
    )
    search_parser.add_argument(
        "--stage1-mode",
        type=str,
        default="pooled_query_vs_tiles",
        choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
        help="Stage 1 mode for two-stage retrieval",
    )
    search_parser.add_argument("--year", type=int, help="Filter by year")
    search_parser.add_argument("--source", type=str, help="Filter by source")
    search_parser.add_argument("--district", type=str, help="Filter by district")
    search_parser.add_argument(
        "--show-text", action="store_true", help="Show text snippets in results"
    )
    _add_grpc_flags(search_parser)
    search_parser.set_defaults(func=cmd_search)

    # =========================================================================
    # INFO command
    # =========================================================================
    info_parser = subparsers.add_parser(
        "info",
        help="Show collection info",
    )
    _add_collection_arg(info_parser)
    _add_grpc_flags(info_parser)
    info_parser.set_defaults(func=cmd_info)

    # Parse and execute
    args = parser.parse_args()

    setup_logging(args.debug)

    if not args.command:
        parser.print_help()
        sys.exit(0)

    args.func(args)
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
if __name__ == "__main__":
    # Run the CLI when this module is executed directly as a script.
    main()
|