visual-rag-toolkit 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- demo/app.py +20 -8
- demo/evaluation.py +5 -45
- demo/indexing.py +180 -192
- demo/qdrant_utils.py +12 -5
- demo/ui/playground.py +1 -1
- demo/ui/sidebar.py +4 -3
- demo/ui/upload.py +5 -4
- visual_rag/__init__.py +43 -1
- visual_rag/config.py +4 -7
- visual_rag/indexing/__init__.py +21 -4
- visual_rag/indexing/qdrant_indexer.py +92 -42
- visual_rag/retrieval/multi_vector.py +63 -65
- visual_rag/retrieval/single_stage.py +7 -0
- visual_rag/retrieval/two_stage.py +8 -10
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/METADATA +98 -17
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/RECORD +19 -20
- benchmarks/vidore_tatdqa_test/COMMANDS.md +0 -83
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/WHEEL +0 -0
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/entry_points.txt +0 -0
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/licenses/LICENSE +0 -0
demo/ui/upload.py
CHANGED
|
@@ -9,6 +9,7 @@ import inspect
|
|
|
9
9
|
from datetime import datetime
|
|
10
10
|
from pathlib import Path
|
|
11
11
|
|
|
12
|
+
import numpy as np
|
|
12
13
|
import streamlit as st
|
|
13
14
|
|
|
14
15
|
from demo.config import AVAILABLE_MODELS
|
|
@@ -17,6 +18,10 @@ from demo.qdrant_utils import (
|
|
|
17
18
|
get_collection_stats,
|
|
18
19
|
sample_points_cached,
|
|
19
20
|
)
|
|
21
|
+
from visual_rag.embedding.visual_embedder import VisualEmbedder
|
|
22
|
+
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
|
|
23
|
+
from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
|
|
24
|
+
from visual_rag.indexing.pipeline import ProcessingPipeline
|
|
20
25
|
|
|
21
26
|
|
|
22
27
|
VECTOR_TYPES = ["initial", "mean_pooling", "experimental_pooling", "global_pooling"]
|
|
@@ -251,10 +256,6 @@ def process_pdfs(uploaded_files, config):
|
|
|
251
256
|
model_short = model_name.split("/")[-1]
|
|
252
257
|
model_status.info(f"Loading `{model_short}`...")
|
|
253
258
|
|
|
254
|
-
import numpy as np
|
|
255
|
-
from visual_rag import VisualEmbedder
|
|
256
|
-
from visual_rag.indexing import QdrantIndexer, CloudinaryUploader, ProcessingPipeline
|
|
257
|
-
|
|
258
259
|
output_dtype = np.float16 if vector_dtype == "float16" else np.float32
|
|
259
260
|
embedder_key = f"{model_name}::{vector_dtype}"
|
|
260
261
|
embedder = None
|
visual_rag/__init__.py
CHANGED
|
@@ -31,7 +31,47 @@ Quick Start:
|
|
|
31
31
|
Each component works independently - use only what you need.
|
|
32
32
|
"""
|
|
33
33
|
|
|
34
|
-
|
|
34
|
+
import logging
|
|
35
|
+
|
|
36
|
+
__version__ = "0.1.3"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def setup_logging(level: str = "INFO", format: str = None) -> None:
|
|
40
|
+
"""
|
|
41
|
+
Configure logging for visual_rag package.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
level: Log level ("DEBUG", "INFO", "WARNING", "ERROR")
|
|
45
|
+
format: Custom format string. Default shows time, level, and message.
|
|
46
|
+
|
|
47
|
+
Example:
|
|
48
|
+
>>> import visual_rag
|
|
49
|
+
>>> visual_rag.setup_logging("INFO")
|
|
50
|
+
>>> # Now you'll see processing logs
|
|
51
|
+
"""
|
|
52
|
+
if format is None:
|
|
53
|
+
format = "[%(asctime)s] %(levelname)s - %(message)s"
|
|
54
|
+
|
|
55
|
+
logging.basicConfig(
|
|
56
|
+
level=getattr(logging, level.upper(), logging.INFO),
|
|
57
|
+
format=format,
|
|
58
|
+
datefmt="%H:%M:%S",
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
# Also set the visual_rag logger specifically
|
|
62
|
+
logger = logging.getLogger("visual_rag")
|
|
63
|
+
logger.setLevel(getattr(logging, level.upper(), logging.INFO))
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Enable INFO logging by default for visual_rag package and all submodules
|
|
67
|
+
# This ensures logs like "Processing PDF...", "Embedding pages..." are visible
|
|
68
|
+
_logger = logging.getLogger("visual_rag")
|
|
69
|
+
if not _logger.handlers:
|
|
70
|
+
_handler = logging.StreamHandler()
|
|
71
|
+
_handler.setFormatter(logging.Formatter("[%(asctime)s] %(message)s", datefmt="%H:%M:%S"))
|
|
72
|
+
_logger.addHandler(_handler)
|
|
73
|
+
_logger.setLevel(logging.INFO)
|
|
74
|
+
_logger.propagate = False # Don't duplicate to root logger
|
|
35
75
|
|
|
36
76
|
# Import main classes at package level for convenience
|
|
37
77
|
# These are optional - if dependencies aren't installed, we catch the error
|
|
@@ -95,4 +135,6 @@ __all__ = [
|
|
|
95
135
|
"load_config",
|
|
96
136
|
"get",
|
|
97
137
|
"get_section",
|
|
138
|
+
# Logging
|
|
139
|
+
"setup_logging",
|
|
98
140
|
]
|
visual_rag/config.py
CHANGED
|
@@ -21,16 +21,13 @@ _raw_config_cache_path: Optional[str] = None
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
def _env_qdrant_url() -> Optional[str]:
|
|
24
|
-
|
|
24
|
+
"""Get Qdrant URL from environment. Prefers QDRANT_URL."""
|
|
25
|
+
return os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") # legacy fallback
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def _env_qdrant_api_key() -> Optional[str]:
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
or os.getenv("SIGIR_QDRANT_API_KEY")
|
|
31
|
-
or os.getenv("DEST_QDRANT_API_KEY")
|
|
32
|
-
or os.getenv("QDRANT_API_KEY")
|
|
33
|
-
)
|
|
29
|
+
"""Get Qdrant API key from environment. Prefers QDRANT_API_KEY."""
|
|
30
|
+
return os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") # legacy fallback
|
|
34
31
|
|
|
35
32
|
|
|
36
33
|
def load_config(
|
visual_rag/indexing/__init__.py
CHANGED
|
@@ -8,10 +8,27 @@ Components:
|
|
|
8
8
|
- ProcessingPipeline: End-to-end PDF → Qdrant pipeline
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
from visual_rag.indexing.
|
|
11
|
+
# Lazy imports to avoid failures when optional dependencies aren't installed
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from visual_rag.indexing.pdf_processor import PDFProcessor
|
|
15
|
+
except ImportError:
|
|
16
|
+
PDFProcessor = None
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
|
|
20
|
+
except ImportError:
|
|
21
|
+
QdrantIndexer = None
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
|
|
25
|
+
except ImportError:
|
|
26
|
+
CloudinaryUploader = None
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from visual_rag.indexing.pipeline import ProcessingPipeline
|
|
30
|
+
except ImportError:
|
|
31
|
+
ProcessingPipeline = None
|
|
15
32
|
|
|
16
33
|
__all__ = [
|
|
17
34
|
"PDFProcessor",
|
|
@@ -19,6 +19,23 @@ from urllib.parse import urlparse
|
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
22
|
+
try:
|
|
23
|
+
from qdrant_client import QdrantClient
|
|
24
|
+
from qdrant_client.http import models as qdrant_models
|
|
25
|
+
from qdrant_client.http.models import Distance, VectorParams
|
|
26
|
+
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
|
27
|
+
|
|
28
|
+
QDRANT_AVAILABLE = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
QDRANT_AVAILABLE = False
|
|
31
|
+
QdrantClient = None
|
|
32
|
+
qdrant_models = None
|
|
33
|
+
Distance = None
|
|
34
|
+
VectorParams = None
|
|
35
|
+
FieldCondition = None
|
|
36
|
+
Filter = None
|
|
37
|
+
MatchValue = None
|
|
38
|
+
|
|
22
39
|
logger = logging.getLogger(__name__)
|
|
23
40
|
|
|
24
41
|
|
|
@@ -58,9 +75,7 @@ class QdrantIndexer:
|
|
|
58
75
|
prefer_grpc: bool = False,
|
|
59
76
|
vector_datatype: str = "float32",
|
|
60
77
|
):
|
|
61
|
-
|
|
62
|
-
from qdrant_client import QdrantClient
|
|
63
|
-
except ImportError:
|
|
78
|
+
if not QDRANT_AVAILABLE:
|
|
64
79
|
raise ImportError(
|
|
65
80
|
"Qdrant client not installed. "
|
|
66
81
|
"Install with: pip install visual-rag-toolkit[qdrant]"
|
|
@@ -139,9 +154,6 @@ class QdrantIndexer:
|
|
|
139
154
|
Returns:
|
|
140
155
|
True if created, False if already existed
|
|
141
156
|
"""
|
|
142
|
-
from qdrant_client.http import models
|
|
143
|
-
from qdrant_client.http.models import Distance, VectorParams
|
|
144
|
-
|
|
145
157
|
if self.collection_exists():
|
|
146
158
|
if force_recreate:
|
|
147
159
|
logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
|
|
@@ -153,15 +165,15 @@ class QdrantIndexer:
|
|
|
153
165
|
logger.info(f"📦 Creating collection: {self.collection_name}")
|
|
154
166
|
|
|
155
167
|
# Multi-vector config for ColBERT-style MaxSim
|
|
156
|
-
multivector_config =
|
|
157
|
-
comparator=
|
|
168
|
+
multivector_config = qdrant_models.MultiVectorConfig(
|
|
169
|
+
comparator=qdrant_models.MultiVectorComparator.MAX_SIM
|
|
158
170
|
)
|
|
159
171
|
|
|
160
172
|
# Vector configs - simplified for compatibility
|
|
161
173
|
datatype = (
|
|
162
|
-
|
|
174
|
+
qdrant_models.Datatype.FLOAT16
|
|
163
175
|
if self.vector_datatype == "float16"
|
|
164
|
-
else
|
|
176
|
+
else qdrant_models.Datatype.FLOAT32
|
|
165
177
|
)
|
|
166
178
|
vectors_config = {
|
|
167
179
|
"initial": VectorParams(
|
|
@@ -198,6 +210,18 @@ class QdrantIndexer:
|
|
|
198
210
|
vectors_config=vectors_config,
|
|
199
211
|
)
|
|
200
212
|
|
|
213
|
+
# Create required payload index for skip_existing functionality
|
|
214
|
+
# This index is needed for filtering by filename when checking existing docs
|
|
215
|
+
try:
|
|
216
|
+
self.client.create_payload_index(
|
|
217
|
+
collection_name=self.collection_name,
|
|
218
|
+
field_name="filename",
|
|
219
|
+
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
|
|
220
|
+
)
|
|
221
|
+
logger.info(" 📇 Created payload index: filename")
|
|
222
|
+
except Exception as e:
|
|
223
|
+
logger.warning(f" ⚠️ Could not create filename index: {e}")
|
|
224
|
+
|
|
201
225
|
logger.info(f"✅ Collection created: {self.collection_name}")
|
|
202
226
|
return True
|
|
203
227
|
|
|
@@ -212,14 +236,12 @@ class QdrantIndexer:
|
|
|
212
236
|
fields: List of {field, type} dicts
|
|
213
237
|
type can be: integer, keyword, bool, float, text
|
|
214
238
|
"""
|
|
215
|
-
from qdrant_client.http import models
|
|
216
|
-
|
|
217
239
|
type_mapping = {
|
|
218
|
-
"integer":
|
|
219
|
-
"keyword":
|
|
220
|
-
"bool":
|
|
221
|
-
"float":
|
|
222
|
-
"text":
|
|
240
|
+
"integer": qdrant_models.PayloadSchemaType.INTEGER,
|
|
241
|
+
"keyword": qdrant_models.PayloadSchemaType.KEYWORD,
|
|
242
|
+
"bool": qdrant_models.PayloadSchemaType.BOOL,
|
|
243
|
+
"float": qdrant_models.PayloadSchemaType.FLOAT,
|
|
244
|
+
"text": qdrant_models.PayloadSchemaType.TEXT,
|
|
223
245
|
}
|
|
224
246
|
|
|
225
247
|
if not fields:
|
|
@@ -230,7 +252,7 @@ class QdrantIndexer:
|
|
|
230
252
|
for field_config in fields:
|
|
231
253
|
field_name = field_config["field"]
|
|
232
254
|
field_type_str = field_config.get("type", "keyword")
|
|
233
|
-
field_type = type_mapping.get(field_type_str,
|
|
255
|
+
field_type = type_mapping.get(field_type_str, qdrant_models.PayloadSchemaType.KEYWORD)
|
|
234
256
|
|
|
235
257
|
try:
|
|
236
258
|
self.client.create_payload_index(
|
|
@@ -271,8 +293,6 @@ class QdrantIndexer:
|
|
|
271
293
|
Returns:
|
|
272
294
|
Number of successfully uploaded points
|
|
273
295
|
"""
|
|
274
|
-
from qdrant_client.http import models
|
|
275
|
-
|
|
276
296
|
if not points:
|
|
277
297
|
return 0
|
|
278
298
|
|
|
@@ -315,8 +335,8 @@ class QdrantIndexer:
|
|
|
315
335
|
return val.tolist()
|
|
316
336
|
return val
|
|
317
337
|
|
|
318
|
-
def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[
|
|
319
|
-
qdrant_points: List[
|
|
338
|
+
def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[qdrant_models.PointStruct]:
|
|
339
|
+
qdrant_points: List[qdrant_models.PointStruct] = []
|
|
320
340
|
for p in batch_points:
|
|
321
341
|
global_pooled = p.get("global_pooled_embedding")
|
|
322
342
|
if global_pooled is None:
|
|
@@ -336,7 +356,7 @@ class QdrantIndexer:
|
|
|
336
356
|
global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)
|
|
337
357
|
|
|
338
358
|
qdrant_points.append(
|
|
339
|
-
|
|
359
|
+
qdrant_models.PointStruct(
|
|
340
360
|
id=p["id"],
|
|
341
361
|
vector={
|
|
342
362
|
"initial": _to_list(initial),
|
|
@@ -361,6 +381,8 @@ class QdrantIndexer:
|
|
|
361
381
|
wait=wait,
|
|
362
382
|
)
|
|
363
383
|
|
|
384
|
+
logger.info(f" ✅ Uploaded {len(points)} points to Qdrant")
|
|
385
|
+
|
|
364
386
|
if delay_between_batches > 0:
|
|
365
387
|
if _is_cancelled():
|
|
366
388
|
return 0
|
|
@@ -413,32 +435,60 @@ class QdrantIndexer:
|
|
|
413
435
|
return False
|
|
414
436
|
|
|
415
437
|
def get_existing_ids(self, filename: str) -> Set[str]:
|
|
416
|
-
"""Get all point IDs for a specific file.
|
|
417
|
-
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
|
438
|
+
"""Get all point IDs for a specific file.
|
|
418
439
|
|
|
440
|
+
Requires a payload index on 'filename' field. If the index doesn't exist,
|
|
441
|
+
this method will attempt to create it automatically.
|
|
442
|
+
"""
|
|
419
443
|
existing_ids = set()
|
|
420
444
|
offset = None
|
|
421
445
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
446
|
+
try:
|
|
447
|
+
while True:
|
|
448
|
+
results = self.client.scroll(
|
|
449
|
+
collection_name=self.collection_name,
|
|
450
|
+
scroll_filter=Filter(
|
|
451
|
+
must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
|
|
452
|
+
),
|
|
453
|
+
limit=100,
|
|
454
|
+
offset=offset,
|
|
455
|
+
with_payload=["page_number"],
|
|
456
|
+
with_vectors=False,
|
|
457
|
+
)
|
|
458
|
+
|
|
459
|
+
points, next_offset = results
|
|
433
460
|
|
|
434
|
-
|
|
461
|
+
for point in points:
|
|
462
|
+
existing_ids.add(str(point.id))
|
|
435
463
|
|
|
436
|
-
|
|
437
|
-
|
|
464
|
+
if next_offset is None or len(points) == 0:
|
|
465
|
+
break
|
|
466
|
+
offset = next_offset
|
|
438
467
|
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
468
|
+
except Exception as e:
|
|
469
|
+
error_msg = str(e).lower()
|
|
470
|
+
if "index required" in error_msg or "index" in error_msg and "filename" in error_msg:
|
|
471
|
+
# Missing payload index - try to create it
|
|
472
|
+
logger.warning(
|
|
473
|
+
"⚠️ Missing 'filename' payload index. Creating it now... "
|
|
474
|
+
"(skip_existing requires this index for filtering)"
|
|
475
|
+
)
|
|
476
|
+
try:
|
|
477
|
+
self.client.create_payload_index(
|
|
478
|
+
collection_name=self.collection_name,
|
|
479
|
+
field_name="filename",
|
|
480
|
+
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
|
|
481
|
+
)
|
|
482
|
+
logger.info(" ✅ Created 'filename' index. Retrying query...")
|
|
483
|
+
# Retry the query
|
|
484
|
+
return self.get_existing_ids(filename)
|
|
485
|
+
except Exception as idx_err:
|
|
486
|
+
logger.warning(f" ❌ Could not create index: {idx_err}")
|
|
487
|
+
logger.warning(" Returning empty set - all pages will be processed")
|
|
488
|
+
return set()
|
|
489
|
+
else:
|
|
490
|
+
logger.warning(f"⚠️ Error checking existing IDs: {e}")
|
|
491
|
+
return set()
|
|
442
492
|
|
|
443
493
|
return existing_ids
|
|
444
494
|
|
|
@@ -2,6 +2,25 @@ import os
|
|
|
2
2
|
from typing import Any, Dict, List, Optional
|
|
3
3
|
from urllib.parse import urlparse
|
|
4
4
|
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from dotenv import load_dotenv
|
|
10
|
+
|
|
11
|
+
DOTENV_AVAILABLE = True
|
|
12
|
+
except ImportError:
|
|
13
|
+
DOTENV_AVAILABLE = False
|
|
14
|
+
load_dotenv = None
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
from qdrant_client import QdrantClient
|
|
18
|
+
|
|
19
|
+
QDRANT_AVAILABLE = True
|
|
20
|
+
except ImportError:
|
|
21
|
+
QDRANT_AVAILABLE = False
|
|
22
|
+
QdrantClient = None
|
|
23
|
+
|
|
5
24
|
from visual_rag.embedding.visual_embedder import VisualEmbedder
|
|
6
25
|
from visual_rag.retrieval.single_stage import SingleStageRetriever
|
|
7
26
|
from visual_rag.retrieval.three_stage import ThreeStageRetriever
|
|
@@ -11,9 +30,7 @@ from visual_rag.retrieval.two_stage import TwoStageRetriever
|
|
|
11
30
|
class MultiVectorRetriever:
|
|
12
31
|
@staticmethod
|
|
13
32
|
def _maybe_load_dotenv() -> None:
|
|
14
|
-
|
|
15
|
-
from dotenv import load_dotenv
|
|
16
|
-
except ImportError:
|
|
33
|
+
if not DOTENV_AVAILABLE:
|
|
17
34
|
return
|
|
18
35
|
if os.path.exists(".env"):
|
|
19
36
|
load_dotenv(".env")
|
|
@@ -33,87 +50,84 @@ class MultiVectorRetriever:
|
|
|
33
50
|
):
|
|
34
51
|
if qdrant_client is None:
|
|
35
52
|
self._maybe_load_dotenv()
|
|
36
|
-
|
|
37
|
-
from qdrant_client import QdrantClient
|
|
38
|
-
except ImportError as e:
|
|
53
|
+
if not QDRANT_AVAILABLE:
|
|
39
54
|
raise ImportError(
|
|
40
55
|
"Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
|
|
41
|
-
)
|
|
56
|
+
)
|
|
42
57
|
|
|
43
58
|
qdrant_url = (
|
|
44
59
|
qdrant_url
|
|
45
|
-
or os.getenv("SIGIR_QDRANT_URL")
|
|
46
|
-
or os.getenv("DEST_QDRANT_URL")
|
|
47
60
|
or os.getenv("QDRANT_URL")
|
|
61
|
+
or os.getenv("SIGIR_QDRANT_URL") # legacy
|
|
48
62
|
)
|
|
49
63
|
if not qdrant_url:
|
|
50
64
|
raise ValueError(
|
|
51
|
-
"QDRANT_URL is required (pass qdrant_url or set env var).
|
|
52
|
-
"You can also set DEST_QDRANT_URL to override."
|
|
65
|
+
"QDRANT_URL is required (pass qdrant_url or set env var)."
|
|
53
66
|
)
|
|
54
67
|
|
|
55
68
|
qdrant_api_key = (
|
|
56
69
|
qdrant_api_key
|
|
57
|
-
or os.getenv("SIGIR_QDRANT_KEY")
|
|
58
|
-
or os.getenv("SIGIR_QDRANT_API_KEY")
|
|
59
|
-
or os.getenv("DEST_QDRANT_API_KEY")
|
|
60
70
|
or os.getenv("QDRANT_API_KEY")
|
|
71
|
+
or os.getenv("SIGIR_QDRANT_KEY") # legacy
|
|
61
72
|
)
|
|
62
73
|
|
|
63
74
|
grpc_port = None
|
|
64
75
|
if prefer_grpc:
|
|
65
76
|
try:
|
|
66
|
-
|
|
77
|
+
parsed = urlparse(qdrant_url)
|
|
78
|
+
port = parsed.port
|
|
79
|
+
if port == 6333:
|
|
67
80
|
grpc_port = 6334
|
|
68
81
|
except Exception:
|
|
69
|
-
|
|
82
|
+
pass
|
|
70
83
|
|
|
71
84
|
def _make_client(use_grpc: bool):
|
|
72
85
|
return QdrantClient(
|
|
73
86
|
url=qdrant_url,
|
|
74
87
|
api_key=qdrant_api_key,
|
|
88
|
+
timeout=request_timeout,
|
|
75
89
|
prefer_grpc=bool(use_grpc),
|
|
76
90
|
grpc_port=grpc_port,
|
|
77
|
-
timeout=int(request_timeout),
|
|
78
91
|
check_compatibility=False,
|
|
79
92
|
)
|
|
80
93
|
|
|
81
|
-
|
|
94
|
+
client = _make_client(prefer_grpc)
|
|
82
95
|
if prefer_grpc:
|
|
83
96
|
try:
|
|
84
|
-
_ =
|
|
97
|
+
_ = client.get_collections()
|
|
85
98
|
except Exception as e:
|
|
86
99
|
msg = str(e)
|
|
87
|
-
if
|
|
88
|
-
|
|
89
|
-
or "http2 header with status: 403" in msg
|
|
90
|
-
):
|
|
91
|
-
qdrant_client = _make_client(False)
|
|
100
|
+
if "StatusCode.PERMISSION_DENIED" in msg or "http2 header with status: 403" in msg:
|
|
101
|
+
client = _make_client(False)
|
|
92
102
|
else:
|
|
93
103
|
raise
|
|
104
|
+
qdrant_client = client
|
|
94
105
|
|
|
95
106
|
self.client = qdrant_client
|
|
96
107
|
self.collection_name = collection_name
|
|
108
|
+
|
|
97
109
|
self.embedder = embedder or VisualEmbedder(model_name=model_name)
|
|
98
110
|
|
|
99
111
|
self._two_stage = TwoStageRetriever(
|
|
100
|
-
|
|
101
|
-
collection_name=
|
|
102
|
-
request_timeout=
|
|
103
|
-
max_retries=
|
|
104
|
-
retry_sleep=
|
|
112
|
+
qdrant_client=qdrant_client,
|
|
113
|
+
collection_name=collection_name,
|
|
114
|
+
request_timeout=request_timeout,
|
|
115
|
+
max_retries=max_retries,
|
|
116
|
+
retry_sleep=retry_sleep,
|
|
105
117
|
)
|
|
106
118
|
self._three_stage = ThreeStageRetriever(
|
|
107
|
-
|
|
108
|
-
collection_name=
|
|
109
|
-
request_timeout=
|
|
110
|
-
max_retries=
|
|
111
|
-
retry_sleep=
|
|
119
|
+
qdrant_client=qdrant_client,
|
|
120
|
+
collection_name=collection_name,
|
|
121
|
+
request_timeout=request_timeout,
|
|
122
|
+
max_retries=max_retries,
|
|
123
|
+
retry_sleep=retry_sleep,
|
|
112
124
|
)
|
|
113
125
|
self._single_stage = SingleStageRetriever(
|
|
114
|
-
|
|
115
|
-
collection_name=
|
|
116
|
-
request_timeout=
|
|
126
|
+
qdrant_client=qdrant_client,
|
|
127
|
+
collection_name=collection_name,
|
|
128
|
+
request_timeout=request_timeout,
|
|
129
|
+
max_retries=max_retries,
|
|
130
|
+
retry_sleep=retry_sleep,
|
|
117
131
|
)
|
|
118
132
|
|
|
119
133
|
def build_filter(
|
|
@@ -143,14 +157,10 @@ class MultiVectorRetriever:
|
|
|
143
157
|
return_embeddings: bool = False,
|
|
144
158
|
) -> List[Dict[str, Any]]:
|
|
145
159
|
q = self.embedder.embed_query(query)
|
|
146
|
-
|
|
147
|
-
import torch
|
|
148
|
-
except ImportError:
|
|
149
|
-
torch = None
|
|
150
|
-
if torch is not None and isinstance(q, torch.Tensor):
|
|
160
|
+
if isinstance(q, torch.Tensor):
|
|
151
161
|
query_embedding = q.detach().cpu().numpy()
|
|
152
162
|
else:
|
|
153
|
-
query_embedding =
|
|
163
|
+
query_embedding = np.asarray(q)
|
|
154
164
|
|
|
155
165
|
return self.search_embedded(
|
|
156
166
|
query_embedding=query_embedding,
|
|
@@ -179,27 +189,17 @@ class MultiVectorRetriever:
|
|
|
179
189
|
return self._single_stage.search(
|
|
180
190
|
query_embedding=query_embedding,
|
|
181
191
|
top_k=top_k,
|
|
182
|
-
strategy="multi_vector",
|
|
183
192
|
filter_obj=filter_obj,
|
|
193
|
+
using="initial",
|
|
184
194
|
)
|
|
185
|
-
|
|
186
|
-
if mode == "single_tiles":
|
|
195
|
+
elif mode == "single_pooled":
|
|
187
196
|
return self._single_stage.search(
|
|
188
197
|
query_embedding=query_embedding,
|
|
189
198
|
top_k=top_k,
|
|
190
|
-
strategy="tiles_maxsim",
|
|
191
199
|
filter_obj=filter_obj,
|
|
200
|
+
using="mean_pooling",
|
|
192
201
|
)
|
|
193
|
-
|
|
194
|
-
if mode == "single_global":
|
|
195
|
-
return self._single_stage.search(
|
|
196
|
-
query_embedding=query_embedding,
|
|
197
|
-
top_k=top_k,
|
|
198
|
-
strategy="pooled_global",
|
|
199
|
-
filter_obj=filter_obj,
|
|
200
|
-
)
|
|
201
|
-
|
|
202
|
-
if mode == "two_stage":
|
|
202
|
+
elif mode == "two_stage":
|
|
203
203
|
return self._two_stage.search_server_side(
|
|
204
204
|
query_embedding=query_embedding,
|
|
205
205
|
top_k=top_k,
|
|
@@ -207,16 +207,14 @@ class MultiVectorRetriever:
|
|
|
207
207
|
filter_obj=filter_obj,
|
|
208
208
|
stage1_mode=stage1_mode,
|
|
209
209
|
)
|
|
210
|
-
|
|
211
|
-
if mode == "three_stage":
|
|
212
|
-
s1 = int(stage1_k) if stage1_k is not None else 1000
|
|
213
|
-
s2 = int(stage2_k) if stage2_k is not None else 300
|
|
210
|
+
elif mode == "three_stage":
|
|
214
211
|
return self._three_stage.search_server_side(
|
|
215
212
|
query_embedding=query_embedding,
|
|
216
213
|
top_k=top_k,
|
|
217
|
-
stage1_k=
|
|
218
|
-
stage2_k=
|
|
214
|
+
stage1_k=stage1_k,
|
|
215
|
+
stage2_k=stage2_k,
|
|
219
216
|
filter_obj=filter_obj,
|
|
217
|
+
stage1_mode=stage1_mode,
|
|
220
218
|
)
|
|
221
|
-
|
|
222
|
-
|
|
219
|
+
else:
|
|
220
|
+
raise ValueError(f"Unknown mode: {mode}")
|
|
@@ -30,6 +30,9 @@ class SingleStageRetriever:
|
|
|
30
30
|
Args:
|
|
31
31
|
qdrant_client: Connected Qdrant client
|
|
32
32
|
collection_name: Name of the Qdrant collection
|
|
33
|
+
request_timeout: Timeout for Qdrant requests (seconds)
|
|
34
|
+
max_retries: Number of retry attempts on failure
|
|
35
|
+
retry_sleep: Sleep time between retries (seconds)
|
|
33
36
|
|
|
34
37
|
Example:
|
|
35
38
|
>>> retriever = SingleStageRetriever(client, "my_collection")
|
|
@@ -41,10 +44,14 @@ class SingleStageRetriever:
|
|
|
41
44
|
qdrant_client,
|
|
42
45
|
collection_name: str,
|
|
43
46
|
request_timeout: int = 120,
|
|
47
|
+
max_retries: int = 3,
|
|
48
|
+
retry_sleep: float = 1.0,
|
|
44
49
|
):
|
|
45
50
|
self.client = qdrant_client
|
|
46
51
|
self.collection_name = collection_name
|
|
47
52
|
self.request_timeout = int(request_timeout)
|
|
53
|
+
self.max_retries = max_retries
|
|
54
|
+
self.retry_sleep = retry_sleep
|
|
48
55
|
|
|
49
56
|
def search(
|
|
50
57
|
self,
|
|
@@ -17,11 +17,17 @@ Research Context:
|
|
|
17
17
|
"""
|
|
18
18
|
|
|
19
19
|
import logging
|
|
20
|
+
import time
|
|
20
21
|
from typing import Any, Dict, List, Optional, Union
|
|
21
22
|
|
|
22
23
|
import numpy as np
|
|
23
24
|
import torch
|
|
24
25
|
|
|
26
|
+
from qdrant_client.http import models as qdrant_models
|
|
27
|
+
from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue
|
|
28
|
+
|
|
29
|
+
from visual_rag.embedding.pooling import compute_maxsim_score
|
|
30
|
+
|
|
25
31
|
logger = logging.getLogger(__name__)
|
|
26
32
|
|
|
27
33
|
|
|
@@ -82,8 +88,6 @@ class TwoStageRetriever:
|
|
|
82
88
|
self.retry_sleep = float(retry_sleep)
|
|
83
89
|
|
|
84
90
|
def _retry_call(self, fn):
|
|
85
|
-
import time
|
|
86
|
-
|
|
87
91
|
last_err = None
|
|
88
92
|
for attempt in range(self.max_retries):
|
|
89
93
|
try:
|
|
@@ -120,8 +124,6 @@ class TwoStageRetriever:
|
|
|
120
124
|
Returns:
|
|
121
125
|
List of results with scores
|
|
122
126
|
"""
|
|
123
|
-
from qdrant_client.http import models
|
|
124
|
-
|
|
125
127
|
query_np = self._to_numpy(query_embedding)
|
|
126
128
|
|
|
127
129
|
if prefetch_k is None:
|
|
@@ -155,9 +157,9 @@ class TwoStageRetriever:
|
|
|
155
157
|
limit=top_k,
|
|
156
158
|
query_filter=filter_obj,
|
|
157
159
|
with_payload=True,
|
|
158
|
-
search_params=
|
|
160
|
+
search_params=qdrant_models.SearchParams(exact=True),
|
|
159
161
|
prefetch=[
|
|
160
|
-
|
|
162
|
+
qdrant_models.Prefetch(
|
|
161
163
|
query=prefetch_query,
|
|
162
164
|
using=prefetch_using,
|
|
163
165
|
limit=prefetch_k,
|
|
@@ -363,8 +365,6 @@ class TwoStageRetriever:
|
|
|
363
365
|
return_embeddings: bool = False,
|
|
364
366
|
) -> List[Dict[str, Any]]:
|
|
365
367
|
"""Stage 2: Rerank with full multi-vector MaxSim scoring."""
|
|
366
|
-
from visual_rag.embedding.pooling import compute_maxsim_score
|
|
367
|
-
|
|
368
368
|
# Fetch full embeddings for candidates
|
|
369
369
|
candidate_ids = [c["id"] for c in candidates]
|
|
370
370
|
|
|
@@ -435,8 +435,6 @@ class TwoStageRetriever:
|
|
|
435
435
|
|
|
436
436
|
Supports single values or lists (using MatchAny).
|
|
437
437
|
"""
|
|
438
|
-
from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue
|
|
439
|
-
|
|
440
438
|
conditions = []
|
|
441
439
|
|
|
442
440
|
if year is not None:
|