visual-rag-toolkit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
visual_rag/cli/main.py ADDED
@@ -0,0 +1,629 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Visual RAG Toolkit CLI
4
+
5
+ Provides command-line interface for:
6
+ - Processing PDFs (embedding, Cloudinary upload, Qdrant indexing)
7
+ - Searching documents
8
+ - Managing collections
9
+
10
+ Usage:
11
+ # Process PDFs (like process_pdfs_saliency_v2.py)
12
+ visual-rag process --reports-dir ./pdfs --metadata-file metadata.json
13
+
14
+ # Search
15
+ visual-rag search --query "budget allocation" --collection my_docs
16
+
17
+ # Show collection info
18
+ visual-rag info --collection my_docs
19
+ """
20
+
21
+ import argparse
22
+ import logging
23
+ import os
24
+ import sys
25
+ from pathlib import Path
26
+ from urllib.parse import urlparse
27
+
28
+ from dotenv import load_dotenv
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
def setup_logging(debug: bool = False):
    """Install the root logging configuration used by the CLI.

    Args:
        debug: When true the root level is DEBUG; otherwise it is INFO.
    """
    if debug:
        verbosity = logging.DEBUG
    else:
        verbosity = logging.INFO

    # force=True makes basicConfig replace handlers already attached to the
    # root logger instead of silently doing nothing.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=verbosity,
        force=True,
    )
41
+
42
+
43
def _find_pdf_files(reports_dir):
    """Return the PDF files directly inside *reports_dir*, sorted by path.

    The extension is compared case-insensitively in a single pass, so a file
    is listed exactly once whatever the platform's glob casing rules are, and
    mixed-case extensions such as ``.Pdf`` are included.
    """
    return sorted(
        entry
        for entry in reports_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() == ".pdf"
    )


def _infer_payload_index_fields(metadata_mapping):
    """Build the payload index specs for the built-in and metadata fields.

    The three built-in fields always come first. Every other key found in the
    dict-valued metadata entries is added in sorted order; its type is taken
    from the first entry whose value is a bool, int or float (bool is tested
    first because ``bool`` is a subclass of ``int``), defaulting to "keyword".
    """
    fields = [
        {"field": "filename", "type": "keyword"},
        {"field": "page_number", "type": "integer"},
        {"field": "has_text", "type": "bool"},
    ]
    reserved = {spec["field"] for spec in fields}

    entries = [meta for meta in metadata_mapping.values() if isinstance(meta, dict)]
    keys = set()
    for meta in entries:
        keys.update(meta.keys())

    for key in sorted(keys - reserved):
        inferred_type = "keyword"
        for meta in entries:
            value = meta.get(key)
            if isinstance(value, bool):
                inferred_type = "bool"
                break
            if isinstance(value, int):
                inferred_type = "integer"
                break
            if isinstance(value, float):
                inferred_type = "float"
                break
        fields.append({"field": key, "type": inferred_type})

    return fields


def cmd_process(args):
    """
    Process PDFs: convert → embed → upload to Cloudinary → index in Qdrant.

    Equivalent to process_pdfs_saliency_v2.py

    Args:
        args: Parsed arguments of the ``process`` sub-command.

    Exits with status 1 when the reports directory is missing, contains no
    PDF files, or no Qdrant URL is configured in the environment.
    """
    from visual_rag import CloudinaryUploader, QdrantIndexer, VisualEmbedder, load_config
    from visual_rag.indexing.pipeline import ProcessingPipeline

    # Load environment
    load_dotenv()

    # Load config
    config = {}
    if args.config and Path(args.config).exists():
        config = load_config(args.config)

    # Get PDFs
    reports_dir = Path(args.reports_dir)
    if not reports_dir.is_dir():
        logger.error(f"❌ Reports directory not found: {reports_dir}")
        sys.exit(1)

    pdf_paths = _find_pdf_files(reports_dir)
    if not pdf_paths:
        logger.error(f"❌ No PDF files found in: {reports_dir}")
        sys.exit(1)

    logger.info(f"📁 Found {len(pdf_paths)} PDF files")

    # Load metadata mapping
    metadata_mapping = {}
    if args.metadata_file:
        metadata_mapping = ProcessingPipeline.load_metadata_mapping(Path(args.metadata_file))

    # Dry run - just show summary
    if args.dry_run:
        logger.info("🏃 DRY RUN MODE")
        logger.info(f" PDFs: {len(pdf_paths)}")
        logger.info(f" Metadata entries: {len(metadata_mapping)}")
        logger.info(f" Collection: {args.collection}")
        logger.info(f" Cloudinary: {'ENABLED' if not args.no_cloudinary else 'DISABLED'}")

        for pdf in pdf_paths[:10]:
            has_meta = "✓" if pdf.stem.lower() in metadata_mapping else "✗"
            logger.info(f" {has_meta} {pdf.name}")
        if len(pdf_paths) > 10:
            logger.info(f" ... and {len(pdf_paths) - 10} more")
        return

    # Get settings
    # NOTE(review): the `process` parser in this file gives --model and
    # --collection non-empty defaults, so the config fallbacks below only
    # apply when a caller passes an empty value — confirm this is intended.
    model_name = args.model or config.get("model", {}).get("name", "vidore/colSmol-500M")
    collection_name = args.collection or config.get("qdrant", {}).get(
        "collection_name", "visual_documents"
    )

    torch_dtype = None
    if args.torch_dtype != "auto":
        # torch is imported only when an explicit dtype was requested.
        import torch

        torch_dtype = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[args.torch_dtype]

    logger.info(f"🤖 Initializing embedder: {model_name}")
    embedder = VisualEmbedder(
        model_name=model_name,
        batch_size=args.batch_size,
        torch_dtype=torch_dtype,
        processor_speed=str(getattr(args, "processor_speed", "fast")),
    )

    # Initialize Qdrant indexer
    qdrant_url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )

    if not qdrant_url:
        logger.error("❌ QDRANT_URL environment variable not set")
        sys.exit(1)

    logger.info(f"🔌 Connecting to Qdrant: {qdrant_url}")
    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=args.prefer_grpc,
        vector_datatype=args.qdrant_vector_dtype,
    )

    # Create collection if needed
    indexer.create_collection(force_recreate=args.force_recreate)
    indexer.create_payload_indexes(fields=_infer_payload_index_fields(metadata_mapping))

    # Initialize Cloudinary uploader (optional)
    cloudinary_uploader = None
    if not args.no_cloudinary:
        try:
            project_name = config.get("project_name", "visual_docs")
            cloudinary_uploader = CloudinaryUploader(folder=project_name)
        except ValueError as e:
            logger.warning(f"⚠️ Cloudinary not configured: {e}")
            logger.warning(" Continuing without Cloudinary uploads")

    # Create pipeline
    pipeline = ProcessingPipeline(
        embedder=embedder,
        indexer=indexer,
        cloudinary_uploader=cloudinary_uploader,
        metadata_mapping=metadata_mapping,
        config=config,
        embedding_strategy=args.strategy,
        crop_empty=bool(getattr(args, "crop_empty", False)),
        crop_empty_percentage_to_remove=float(
            getattr(args, "crop_empty_percentage_to_remove", 0.9)
        ),
        crop_empty_remove_page_number=bool(getattr(args, "crop_empty_remove_page_number", False)),
    )

    # Process PDFs
    total_uploaded = 0
    total_skipped = 0
    total_failed = 0

    skip_existing = not args.no_skip_existing
    # Only request Cloudinary uploads when an uploader was actually created;
    # it is None both for --no-cloudinary and when its setup failed above.
    upload_to_cloudinary = cloudinary_uploader is not None

    for pdf_idx, pdf_path in enumerate(pdf_paths, 1):
        logger.info(f"\n{'='*60}")
        logger.info(f"📄 [{pdf_idx}/{len(pdf_paths)}] {pdf_path.name}")
        logger.info(f"{'='*60}")

        result = pipeline.process_pdf(
            pdf_path,
            skip_existing=skip_existing,
            upload_to_cloudinary=upload_to_cloudinary,
            upload_to_qdrant=True,
        )

        total_uploaded += result["uploaded"]
        total_skipped += result["skipped"]
        total_failed += result["failed"]

    # Summary
    logger.info(f"\n{'='*60}")
    logger.info("📊 SUMMARY")
    logger.info(f"{'='*60}")
    logger.info(f" Total PDFs: {len(pdf_paths)}")
    logger.info(f" Uploaded: {total_uploaded}")
    logger.info(f" Skipped: {total_skipped}")
    logger.info(f" Failed: {total_failed}")

    info = indexer.get_collection_info()
    if info:
        logger.info(f" Collection points: {info.get('points_count', 'N/A')}")
235
+
236
def cmd_search(args):
    """Embed a text query, search a Qdrant collection and log the hits.

    Args:
        args: Parsed arguments of the ``search`` sub-command.

    Exits with status 1 when no Qdrant URL is configured in the environment.
    """
    from qdrant_client import QdrantClient

    from visual_rag import VisualEmbedder
    from visual_rag.retrieval import SingleStageRetriever, TwoStageRetriever

    load_dotenv()

    # First non-empty variable wins; when none is set the value of the last
    # variable tried is kept, exactly like a chained `or`.
    qdrant_url = None
    for env_name in ("SIGIR_QDRANT_URL", "DEST_QDRANT_URL", "QDRANT_URL"):
        qdrant_url = os.getenv(env_name)
        if qdrant_url:
            break

    qdrant_api_key = None
    for env_name in (
        "SIGIR_QDRANT_KEY",
        "SIGIR_QDRANT_API_KEY",
        "DEST_QDRANT_API_KEY",
        "QDRANT_API_KEY",
    ):
        qdrant_api_key = os.getenv(env_name)
        if qdrant_api_key:
            break

    if not qdrant_url:
        logger.error("❌ QDRANT_URL not set")
        sys.exit(1)

    # Model
    logger.info(f"🤖 Loading model: {args.model}")
    speed = str(getattr(args, "processor_speed", "fast"))
    embedder = VisualEmbedder(model_name=args.model, processor_speed=speed)

    # Qdrant client and retrievers
    logger.info("🔌 Connecting to Qdrant")
    grpc_port = None
    if args.prefer_grpc and urlparse(qdrant_url).port == 6333:
        # URL points at port 6333, so pair it with gRPC port 6334.
        grpc_port = 6334
    client = QdrantClient(
        url=qdrant_url,
        api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
        grpc_port=grpc_port,
        check_compatibility=False,
    )
    two_stage = TwoStageRetriever(client, args.collection)
    single_stage = SingleStageRetriever(client, args.collection)

    # Query embedding
    logger.info(f"🔍 Query: {args.query}")
    query_embedding = embedder.embed_query(args.query)

    # Optional payload filter
    filter_obj = None
    if any((args.year, args.source, args.district)):
        filter_obj = two_stage.build_filter(
            year=args.year,
            source=args.source,
            district=args.district,
        )

    # NOTE(review): Tensor.numpy() rejects bfloat16 tensors — confirm that
    # embed_query() never returns that dtype.
    query_np = query_embedding.detach().cpu().numpy()

    # CLI strategy name -> single-stage retriever strategy. Anything not
    # listed here is served by the two-stage retriever.
    single_stage_strategies = {
        "single_full": "multi_vector",
        "single_tiles": "tiles_maxsim",
        "single_global": "pooled_global",
    }
    retriever_strategy = single_stage_strategies.get(args.strategy)
    if retriever_strategy is not None:
        results = single_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            strategy=retriever_strategy,
            filter_obj=filter_obj,
        )
    else:
        results = two_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            prefetch_k=args.prefetch_k,
            filter_obj=filter_obj,
            stage1_mode=args.stage1_mode,
        )

    # Report
    logger.info(f"\n📊 Results ({len(results)}):")
    for rank, hit in enumerate(results, 1):
        payload = hit.get("payload", {})
        stage1_score = hit.get("score_stage1", 0)
        score = hit.get("score_final", stage1_score)

        filename = payload.get("filename", "N/A")
        page_num = payload.get("page_number", "N/A")
        year = payload.get("year", "N/A")
        source = payload.get("source", "N/A")

        logger.info(f" {rank}. {filename} p.{page_num}")
        logger.info(f" Score: {score:.4f} | Year: {year} | Source: {source}")

        # Text snippet (first 200 characters, newlines flattened)
        text = payload.get("text", "")
        if text and args.show_text:
            snippet = text[:200].replace("\n", " ")
            logger.info(f" Text: {snippet}...")
341
+
342
+
343
def cmd_info(args):
    """Show collection info.

    Logs the status, point count, indexed vector count and vector names of
    ``args.collection``.

    Args:
        args: Parsed arguments of the ``info`` sub-command.

    Exits with status 1 when no Qdrant URL is configured or the collection
    information cannot be retrieved.
    """
    from qdrant_client import QdrantClient

    load_dotenv()

    qdrant_url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )

    if not qdrant_url:
        logger.error("❌ QDRANT_URL not set")
        sys.exit(1)

    grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
    client = QdrantClient(
        url=qdrant_url,
        api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
        grpc_port=grpc_port,
        check_compatibility=False,
    )

    try:
        info = client.get_collection(args.collection)

        # Status may be an enum-like object; log its plain value when it has one.
        status = info.status
        if hasattr(status, "value"):
            status = status.value

        # The count may be missing, None, a number, or a per-vector dict.
        indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
        if isinstance(indexed_count, dict):
            indexed_count = sum(indexed_count.values())

        logger.info(f"📊 Collection: {args.collection}")
        logger.info(f" Status: {status}")
        logger.info(f" Points: {info.points_count}")
        logger.info(f" Indexed vectors: {indexed_count}")

        # Show vector config
        if hasattr(info, "config") and hasattr(info.config, "params"):
            vectors = getattr(info.config.params, "vectors", {})
            if hasattr(vectors, "keys"):
                # Named vectors: a mapping of vector name -> parameters.
                if vectors:
                    logger.info(f" Vectors: {list(vectors.keys())}")
            elif vectors is not None:
                # A single unnamed vector is a bare params object with no
                # keys(); report it instead of failing the whole command.
                logger.info(" Vectors: <single unnamed vector>")

    except Exception as e:
        logger.error(f"❌ Could not get collection info: {e}")
        sys.exit(1)
397
+
398
+
399
def _add_grpc_flags(subparser):
    """Add the mutually exclusive --prefer-grpc / --no-prefer-grpc flags.

    Both flags write to ``prefer_grpc``, which defaults to True.
    """
    grpc_group = subparser.add_mutually_exclusive_group()
    grpc_group.add_argument(
        "--prefer-grpc",
        dest="prefer_grpc",
        action="store_true",
        default=True,
        help="Use gRPC for Qdrant client (recommended).",
    )
    grpc_group.add_argument(
        "--no-prefer-grpc",
        dest="prefer_grpc",
        action="store_false",
        help="Disable gRPC for Qdrant client.",
    )


def main(argv=None):
    """Main CLI entry point.

    Builds the ``process``, ``search`` and ``info`` sub-commands, configures
    logging and dispatches to the selected command handler.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``.

    Prints the help text and exits with status 0 when no sub-command is given.
    """
    parser = argparse.ArgumentParser(
        prog="visual-rag",
        description="Visual RAG Toolkit - Visual document retrieval with ColPali",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Process PDFs (like process_pdfs_saliency_v2.py)
visual-rag process --reports-dir ./pdfs --metadata-file metadata.json

# Process without Cloudinary
visual-rag process --reports-dir ./pdfs --no-cloudinary

# Search
visual-rag search --query "budget allocation" --collection my_docs

# Search with filters
visual-rag search --query "budget" --year 2023 --source "Local Government"

# Show collection info
visual-rag info --collection my_docs
""",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")

    subparsers = parser.add_subparsers(dest="command", help="Command")

    # =========================================================================
    # PROCESS command
    # =========================================================================
    process_parser = subparsers.add_parser(
        "process",
        help="Process PDFs: embed, upload to Cloudinary, index in Qdrant",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    process_parser.add_argument(
        "--reports-dir", type=str, required=True, help="Directory containing PDF files"
    )
    process_parser.add_argument(
        "--metadata-file",
        type=str,
        help="JSON file with filename → metadata mapping (like filename_metadata.json)",
    )
    process_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    process_parser.add_argument(
        "--model",
        type=str,
        default="vidore/colSmol-500M",
        help="Model name (vidore/colSmol-500M, vidore/colpali-v1.3, etc.)",
    )
    process_parser.add_argument("--batch-size", type=int, default=8, help="Embedding batch size")
    process_parser.add_argument("--config", type=str, help="Path to config.yaml file")
    process_parser.add_argument(
        "--no-cloudinary", action="store_true", help="Skip Cloudinary uploads"
    )
    process_parser.add_argument(
        "--crop-empty",
        action="store_true",
        help="Crop empty whitespace from page images before embedding (default: off).",
    )
    process_parser.add_argument(
        "--crop-empty-percentage-to-remove",
        type=float,
        default=0.9,
        help="Kept for traceability; currently does not affect cropping behavior (default: 0.9).",
    )
    process_parser.add_argument(
        "--crop-empty-remove-page-number",
        action="store_true",
        help="If set, attempts to crop away the bottom region that contains sparse page numbers (default: off).",
    )
    process_parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="Process all pages even if they exist in Qdrant",
    )
    process_parser.add_argument(
        "--force-recreate", action="store_true", help="Delete and recreate collection"
    )
    process_parser.add_argument(
        "--dry-run", action="store_true", help="Show what would be processed without doing it"
    )
    process_parser.add_argument(
        "--strategy",
        type=str,
        default="pooling",
        choices=["pooling", "standard", "all"],
        help="Embedding strategy: 'pooling' (NOVEL), 'standard' (BASELINE), "
        "'all' (embed once, store BOTH for comparison)",
    )
    process_parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Torch dtype for model weights (default: auto; CUDA->bfloat16, else float32).",
    )
    process_parser.add_argument(
        "--qdrant-vector-dtype",
        type=str,
        default="float16",
        choices=["float16", "float32"],
        help="Datatype for vectors stored in Qdrant (default: float16).",
    )
    process_parser.add_argument(
        "--processor-speed",
        type=str,
        default="fast",
        choices=["fast", "slow", "auto"],
        help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
    )
    _add_grpc_flags(process_parser)
    process_parser.set_defaults(func=cmd_process)

    # =========================================================================
    # SEARCH command
    # =========================================================================
    search_parser = subparsers.add_parser(
        "search",
        help="Search documents",
    )
    search_parser.add_argument("--query", type=str, required=True, help="Search query")
    search_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    search_parser.add_argument(
        "--model", type=str, default="vidore/colSmol-500M", help="Model name"
    )
    search_parser.add_argument(
        "--processor-speed",
        type=str,
        default="fast",
        choices=["fast", "slow", "auto"],
        help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
    )
    search_parser.add_argument("--top-k", type=int, default=10, help="Number of results")
    search_parser.add_argument(
        "--strategy",
        type=str,
        default="single_full",
        choices=["single_full", "single_tiles", "single_global", "two_stage"],
        help="Search strategy",
    )
    search_parser.add_argument(
        "--prefetch-k", type=int, default=200, help="Prefetch candidates for two-stage retrieval"
    )
    search_parser.add_argument(
        "--stage1-mode",
        type=str,
        default="pooled_query_vs_tiles",
        choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
        help="Stage 1 mode for two-stage retrieval",
    )
    search_parser.add_argument("--year", type=int, help="Filter by year")
    search_parser.add_argument("--source", type=str, help="Filter by source")
    search_parser.add_argument("--district", type=str, help="Filter by district")
    search_parser.add_argument(
        "--show-text", action="store_true", help="Show text snippets in results"
    )
    _add_grpc_flags(search_parser)
    search_parser.set_defaults(func=cmd_search)

    # =========================================================================
    # INFO command
    # =========================================================================
    info_parser = subparsers.add_parser(
        "info",
        help="Show collection info",
    )
    info_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    _add_grpc_flags(info_parser)
    info_parser.set_defaults(func=cmd_info)

    # Parse and execute
    args = parser.parse_args(argv)

    setup_logging(args.debug)

    # No sub-command given: `func` was never set, so show usage instead.
    if not args.command:
        parser.print_help()
        sys.exit(0)

    args.func(args)
626
+
627
+
628
+ if __name__ == "__main__":
629
+ main()