visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any, Dict, List, Optional, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ThreeStageRetriever:
    """Three-stage server-side visual document retrieval against Qdrant.

    Funnel:
        Stage 1: fast prefetch on the single global pooled vector
            (``global_vector_name``) using the mean-pooled query.
        Stage 2: intermediate prefetch on ``experimental_vector_name``,
            restricted to the stage-1 candidate ids.
        Stage 3: exact (non-HNSW) rerank on the full multi-vector field
            (``full_vector_name``), restricted to the stage-2 candidates.

    Args:
        qdrant_client: Connected Qdrant client.
        collection_name: Target collection name.
        full_vector_name: Full multi-vector field used for exact rerank.
        experimental_vector_name: Intermediate pooled vector field.
        global_vector_name: Single global pooled vector field.
        request_timeout: Per-request timeout in seconds.
        max_retries: Attempts per Qdrant call (at least one attempt is
            always made, even if configured as 0 or negative).
        retry_sleep: Base sleep between retries; backoff is exponential.
    """

    def __init__(
        self,
        qdrant_client,
        collection_name: str,
        *,
        full_vector_name: str = "initial",
        experimental_vector_name: str = "experimental_pooling",
        global_vector_name: str = "global_pooling",
        request_timeout: int = 120,
        max_retries: int = 3,
        retry_sleep: float = 0.5,
    ):
        self.client = qdrant_client
        self.collection_name = collection_name
        self.full_vector_name = full_vector_name
        self.experimental_vector_name = experimental_vector_name
        self.global_vector_name = global_vector_name
        self.request_timeout = int(request_timeout)
        self.max_retries = int(max_retries)
        self.retry_sleep = float(retry_sleep)

        # Lazily inferred multivector flags. Currently unused by the search
        # path but kept for interface stability; populate them via
        # _infer_vector_is_multivector() when needed.
        self._global_is_multivector: Optional[bool] = None
        self._experimental_is_multivector: Optional[bool] = None

    def _retry_call(self, fn):
        """Invoke *fn* with retries and exponential backoff.

        Fix: always make at least one attempt. The previous version
        silently returned ``None`` without ever calling *fn* when
        ``max_retries`` was configured as 0 or negative.

        Raises:
            The last exception raised by *fn* if every attempt fails.
        """
        import time

        attempts = max(1, self.max_retries)
        last_err: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                return fn()
            except Exception as e:
                last_err = e
                if attempt >= attempts - 1:
                    break
                # Exponential backoff: retry_sleep, 2*retry_sleep, 4*...
                time.sleep(self.retry_sleep * (2**attempt))
        if last_err is not None:
            raise last_err

    def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a query embedding to a numpy array.

        bfloat16 tensors have no direct numpy equivalent, so they are
        upcast to float32 first; other inputs are coerced to float32.
        """
        if isinstance(embedding, torch.Tensor):
            if embedding.dtype == torch.bfloat16:
                return embedding.cpu().float().numpy()
            return embedding.cpu().numpy()
        return np.array(embedding, dtype=np.float32)

    def _infer_vector_is_multivector(self, vector_name: str) -> bool:
        """Return True if *vector_name* is configured as a multivector field.

        Handles both attribute-style and dict-style collection configs;
        any lookup failure is conservatively treated as "not multivector".
        """
        info = self.client.get_collection(self.collection_name)
        cfg = getattr(info, "config", None)
        params = getattr(cfg, "params", None) if cfg is not None else None
        vectors = getattr(params, "vectors", None) if params is not None else None
        v = None
        try:
            if isinstance(vectors, dict):
                v = vectors.get(vector_name)
            else:
                v = vectors[vector_name]
        except Exception:
            v = None
        mv = getattr(v, "multivector_config", None) if v is not None else None
        if mv is None and isinstance(v, dict):
            mv = v.get("multivector_config")
        return mv is not None

    def _and_filter(self, base_filter, ids: List[Any]):
        """AND the caller-supplied filter with a HasId restriction on *ids*."""
        from qdrant_client.http import models as m

        has_id = m.HasIdCondition(has_id=list(ids))
        if base_filter is None:
            return m.Filter(must=[has_id])
        return m.Filter(must=[base_filter, has_id])

    def search_server_side(
        self,
        *,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 100,
        stage1_k: int = 1000,
        stage2_k: int = 300,
        filter_obj=None,
    ) -> List[Dict[str, Any]]:
        """Run the three-stage funnel entirely server-side in Qdrant.

        Args:
            query_embedding: Query token embeddings [num_tokens, dim].
            top_k: Final number of results from stage 3.
            stage1_k: Candidate count for the global-vector prefetch.
            stage2_k: Candidate count for the experimental prefetch.
            filter_obj: Optional Qdrant filter applied at every stage.

        Returns:
            List of dicts with ``id``, per-stage scores, ``score_final``
            (the stage-3 exact score), and ``payload``.
        """
        from qdrant_client.http import models as m

        query_np = self._to_numpy(query_embedding)

        # Stage 1 compares the mean-pooled query against global vectors;
        # stages 2 and 3 reuse the full token list.
        stage1_query = query_np.mean(axis=0).tolist()
        tokens_query = query_np.tolist()

        logger.info("Stage 1: global prefetch %d", int(stage1_k))

        def _do_stage1():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=stage1_query,
                using=self.global_vector_name,
                limit=int(stage1_k),
                query_filter=filter_obj,
                with_payload=False,
                with_vectors=False,
                timeout=self.request_timeout,
            ).points

        s1 = self._retry_call(_do_stage1)
        if not s1:
            return []
        s1_ids = [p.id for p in s1]
        s1_score = {str(p.id): float(p.score) for p in s1}

        logger.info("Stage 2: experimental prefetch %d (restricted to stage1)", int(stage2_k))

        stage2_filter = self._and_filter(filter_obj, s1_ids)

        def _do_stage2():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=tokens_query,
                using=self.experimental_vector_name,
                # Never ask for more candidates than stage 1 produced.
                limit=int(min(int(stage2_k), len(s1_ids))),
                query_filter=stage2_filter,
                with_payload=False,
                with_vectors=False,
                timeout=self.request_timeout,
            ).points

        s2 = self._retry_call(_do_stage2)
        if not s2:
            return []
        s2_ids = [p.id for p in s2]
        s2_score = {str(p.id): float(p.score) for p in s2}

        logger.info("Stage 3: exact rerank on initial to top %d (restricted to stage2)", int(top_k))

        stage3_filter = self._and_filter(filter_obj, s2_ids)

        def _do_stage3():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=tokens_query,
                using=self.full_vector_name,
                limit=int(top_k),
                query_filter=stage3_filter,
                with_payload=True,
                with_vectors=False,
                # exact=True bypasses HNSW for a precise final ranking.
                search_params=m.SearchParams(exact=True),
                timeout=self.request_timeout,
            ).points

        s3 = self._retry_call(_do_stage3)
        out = []
        for p in s3:
            pid = str(p.id)
            out.append(
                {
                    "id": p.id,
                    "score_stage1": s1_score.get(pid),
                    "score_stage2": s2_score.get(pid),
                    "score_stage3": float(p.score),
                    "score_final": float(p.score),
                    "payload": p.payload,
                }
            )
        return out
|
|
@@ -0,0 +1,471 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Two-Stage Retrieval for Scalable Visual Document Search.
|
|
3
|
+
|
|
4
|
+
This is our NOVEL contribution:
|
|
5
|
+
- Stage 1: Fast prefetch using tile-level pooled vectors (mean_pooling)
|
|
6
|
+
- Stage 2: Exact reranking using full multi-vector embeddings (MaxSim)
|
|
7
|
+
|
|
8
|
+
Benefits:
|
|
9
|
+
- 5-10x faster than full MaxSim at scale
|
|
10
|
+
- Maintains 95%+ accuracy compared to full search
|
|
11
|
+
- Memory efficient (don't load all embeddings upfront)
|
|
12
|
+
|
|
13
|
+
Research Context:
|
|
14
|
+
- Different from HPC-ColPali (compression vs pooling)
|
|
15
|
+
- Inspired by text ColBERT two-stage retrieval
|
|
16
|
+
- Novel: tile-level pooling preserves spatial structure
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import logging
|
|
20
|
+
from typing import Any, Dict, List, Optional, Union
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TwoStageRetriever:
    """
    Two-stage visual document retrieval with pooling and reranking.

    Stage 1 (Prefetch):
        Uses tile-level mean-pooled vectors for fast HNSW search.
        Retrieves prefetch_k candidates (e.g., 100-500).

    Stage 2 (Rerank):
        Fetches full multi-vector embeddings for candidates.
        Computes exact MaxSim scores for precise ranking.
        Returns top_k results (e.g., 10).

    Args:
        qdrant_client: Connected Qdrant client
        collection_name: Name of the Qdrant collection
        full_vector_name: Name of full multi-vector field (default: "initial")
        pooled_vector_name: Name of pooled vector field (default: "mean_pooling")
        experimental_vector_name: Name of the experimental pooled field
        global_vector_name: Name of the single global pooled field
        request_timeout: Per-request timeout in seconds
        max_retries: Attempts per Qdrant call (at least one is always made)
        retry_sleep: Base sleep between retries; backoff is exponential

    Example:
        >>> retriever = TwoStageRetriever(client, "my_collection")
        >>>
        >>> # Two-stage search: prefetch 200, return top 10
        >>> results = retriever.search(
        ...     query_embedding=query,
        ...     top_k=10,
        ...     prefetch_k=200,
        ... )
    """

    def __init__(
        self,
        qdrant_client,
        collection_name: str,
        full_vector_name: str = "initial",
        pooled_vector_name: str = "mean_pooling",
        experimental_vector_name: str = "experimental_pooling",
        global_vector_name: str = "global_pooling",
        request_timeout: int = 120,
        max_retries: int = 3,
        retry_sleep: float = 0.5,
    ):
        self.client = qdrant_client
        self.collection_name = collection_name
        self.full_vector_name = full_vector_name
        self.pooled_vector_name = pooled_vector_name
        self.experimental_vector_name = experimental_vector_name
        self.global_vector_name = global_vector_name
        self.request_timeout = int(request_timeout)
        self.max_retries = int(max_retries)
        self.retry_sleep = float(retry_sleep)

    def _retry_call(self, fn):
        """Invoke *fn* with retries and exponential backoff.

        Fix: always make at least one attempt. The previous version
        silently returned ``None`` without ever calling *fn* when
        ``max_retries`` was configured as 0 or negative.

        Raises:
            The last exception raised by *fn* if every attempt fails.
        """
        import time

        attempts = max(1, self.max_retries)
        last_err: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                return fn()
            except Exception as e:
                last_err = e
                if attempt >= attempts - 1:
                    break
                time.sleep(self.retry_sleep * (2**attempt))
        if last_err is not None:
            raise last_err

    def _stage1_query_vector(self, query_np: np.ndarray, stage1_mode: str):
        """Resolve (query_vector, vector_field_name) for a stage-1 mode.

        "pooled_*" modes collapse the query tokens to a single mean
        vector; "tokens_*" modes send the full token list.

        Raises:
            ValueError: If *stage1_mode* is not a known mode.
        """
        # mode -> (use mean-pooled query?, vector field to search)
        mode_map = {
            "pooled_query_vs_tiles": (True, self.pooled_vector_name),
            "tokens_vs_tiles": (False, self.pooled_vector_name),
            "pooled_query_vs_experimental": (True, self.experimental_vector_name),
            "tokens_vs_experimental": (False, self.experimental_vector_name),
            "pooled_query_vs_global": (True, self.global_vector_name),
        }
        try:
            pooled, vector_name = mode_map[stage1_mode]
        except KeyError:
            raise ValueError(f"Unknown stage1_mode: {stage1_mode}") from None
        query_vector = query_np.mean(axis=0).tolist() if pooled else query_np.tolist()
        return query_vector, vector_name

    def search_server_side(
        self,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 10,
        prefetch_k: Optional[int] = None,
        filter_obj=None,
        stage1_mode: str = "pooled_query_vs_tiles",
    ) -> List[Dict[str, Any]]:
        """
        Two-stage retrieval using Qdrant's native prefetch (all server-side).

        This is MUCH faster than search() because it avoids network transfer
        of large multi-vector embeddings. All computation happens in Qdrant.

        Args:
            query_embedding: Query embeddings [num_tokens, dim]
            top_k: Final number of results
            prefetch_k: Candidates for stage 1 (default: 10x top_k)
            filter_obj: Qdrant filter
            stage1_mode: How to do stage 1 prefetch (see _stage1_query_vector)

        Returns:
            List of results with scores
        """
        from qdrant_client.http import models

        query_np = self._to_numpy(query_embedding)

        if prefetch_k is None:
            prefetch_k = max(100, top_k * 10)

        prefetch_query, prefetch_using = self._stage1_query_vector(query_np, stage1_mode)
        rerank_query = query_np.tolist()

        def _do_query():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=rerank_query,
                using=self.full_vector_name,
                limit=top_k,
                query_filter=filter_obj,
                with_payload=True,
                # exact=True: rerank the prefetched candidates precisely.
                search_params=models.SearchParams(exact=True),
                prefetch=[
                    models.Prefetch(
                        query=prefetch_query,
                        using=prefetch_using,
                        limit=prefetch_k,
                    )
                ],
                timeout=self.request_timeout,
            ).points

        results = self._retry_call(_do_query)

        # Stage-1 scores are not reported by the server-side prefetch API.
        return [
            {
                "id": r.id,
                "score_stage1": None,
                "score_stage2": r.score,
                "score_final": r.score,
                "payload": r.payload,
            }
            for r in results
        ]

    def search(
        self,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 10,
        prefetch_k: Optional[int] = None,
        filter_obj=None,
        use_reranking: bool = True,
        return_embeddings: bool = False,
        stage1_mode: str = "pooled_query_vs_tiles",
    ) -> List[Dict[str, Any]]:
        """
        Two-stage retrieval: prefetch with pooling, rerank with MaxSim.

        Args:
            query_embedding: Query embeddings [num_tokens, dim]
            top_k: Final number of results to return
            prefetch_k: Candidates for stage 1 (default: 10x top_k)
            filter_obj: Qdrant filter for metadata filtering
            use_reranking: Enable stage 2 reranking (default: True)
            return_embeddings: Include embeddings in results
            stage1_mode:
                - "pooled_query_vs_tiles": pool query to 1×dim and search tile vectors (using="mean_pooling")
                - "tokens_vs_tiles": search tile vectors with full query tokens (using="mean_pooling")
                - "pooled_query_vs_global": pool query to 1×dim and search global pooled doc vectors (using="global_pooling")

        Returns:
            List of results with scores and metadata:
            [
                {
                    "id": point_id,
                    "score_stage1": float,  # Pooled similarity
                    "score_stage2": float,  # MaxSim (if reranking)
                    "score_final": float,   # Final score used for ranking
                    "payload": {...},       # Document metadata
                },
                ...
            ]
        """
        query_np = self._to_numpy(query_embedding)

        # Auto-set prefetch_k: widen the funnel relative to top_k.
        if prefetch_k is None:
            prefetch_k = max(100, top_k * 10)

        logger.info("🔍 Stage 1: Prefetching %s candidates (%s)", prefetch_k, stage1_mode)
        candidates = self._stage1_prefetch(
            query_np=query_np,
            top_k=prefetch_k,
            filter_obj=filter_obj,
            stage1_mode=stage1_mode,
        )

        if not candidates:
            logger.warning("No candidates found in stage 1")
            return []

        logger.info("✅ Stage 1: Retrieved %d candidates", len(candidates))

        # Stage 2 only pays off when there are more candidates than top_k.
        if use_reranking and len(candidates) > top_k:
            logger.info("🎯 Stage 2: Reranking with MaxSim...")
            results = self._stage2_rerank(
                query_np=query_np,
                candidates=candidates,
                top_k=top_k,
                return_embeddings=return_embeddings,
            )
            logger.info("✅ Stage 2: Reranked to top %d results", len(results))
        else:
            results = candidates[:top_k]
            for r in results:
                r["score_final"] = r["score_stage1"]
            logger.info("⏭️ Skipping reranking, returning top %d", len(results))

        return results

    def search_single_stage(
        self,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 10,
        filter_obj=None,
        use_pooling: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Single-stage search (either pooled or full multi-vector).

        Fix: now honors ``self.request_timeout`` (previously a hardcoded
        120s) and goes through ``_retry_call`` like every other query
        path in this class.

        Args:
            query_embedding: Query embeddings
            top_k: Number of results
            filter_obj: Qdrant filter
            use_pooling: Use pooled vectors (faster) or full (more accurate)

        Returns:
            List of results
        """
        query_np = self._to_numpy(query_embedding)

        if use_pooling:
            vector_name = self.pooled_vector_name
            query_vector = query_np.mean(axis=0).tolist()
            logger.info("🔍 Pooled search: %s", vector_name)
        else:
            vector_name = self.full_vector_name
            query_vector = query_np.tolist()
            logger.info("🎯 Multi-vector search: %s", vector_name)

        def _do_query():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                using=vector_name,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
                timeout=self.request_timeout,
            ).points

        results = self._retry_call(_do_query)

        return [
            {
                "id": r.id,
                "score_stage1": r.score,
                "score_final": r.score,
                "payload": r.payload,
            }
            for r in results
        ]

    def _stage1_prefetch(
        self,
        query_np: np.ndarray,
        top_k: int,
        filter_obj=None,
        stage1_mode: str = "pooled_query_vs_tiles",
    ) -> List[Dict[str, Any]]:
        """Stage 1: Prefetch candidates via a pooled-vector HNSW search."""
        # Only the client-side modes are supported here; experimental
        # modes are handled by search_server_side().
        if stage1_mode == "pooled_query_vs_tiles":
            query_vector = query_np.mean(axis=0).tolist()
            vector_name = self.pooled_vector_name
        elif stage1_mode == "tokens_vs_tiles":
            query_vector = query_np.tolist()
            vector_name = self.pooled_vector_name
        elif stage1_mode == "pooled_query_vs_global":
            query_vector = query_np.mean(axis=0).tolist()
            vector_name = self.global_vector_name
        else:
            raise ValueError(f"Unknown stage1_mode: {stage1_mode}")

        def _do_query():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                using=vector_name,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
                timeout=self.request_timeout,
            ).points

        results = self._retry_call(_do_query)

        return [
            {
                "id": r.id,
                "score_stage1": r.score,
                "payload": r.payload,
            }
            for r in results
        ]

    def _stage2_rerank(
        self,
        query_np: np.ndarray,
        candidates: List[Dict[str, Any]],
        top_k: int,
        return_embeddings: bool = False,
    ) -> List[Dict[str, Any]]:
        """Stage 2: Rerank with full multi-vector MaxSim scoring.

        Fetches the full embeddings for all candidates in one retrieve
        call, then scores each with exact MaxSim; candidates whose
        embedding is missing fall back to their stage-1 score.
        """
        from visual_rag.embedding.pooling import compute_maxsim_score

        candidate_ids = [c["id"] for c in candidates]

        def _do_retrieve():
            return self.client.retrieve(
                collection_name=self.collection_name,
                ids=candidate_ids,
                with_payload=False,
                with_vectors=[self.full_vector_name],
                timeout=self.request_timeout,
            )

        points = self._retry_call(_do_retrieve)

        # Map point id -> full multi-vector embedding.
        id_to_embedding = {}
        for point in points:
            if point.vector and self.full_vector_name in point.vector:
                id_to_embedding[point.id] = np.array(
                    point.vector[self.full_vector_name], dtype=np.float32
                )

        reranked = []
        for candidate in candidates:
            doc_embedding = id_to_embedding.get(candidate["id"])

            if doc_embedding is None:
                # Fallback: keep the stage-1 score when the embedding
                # could not be fetched.
                candidate["score_stage2"] = candidate["score_stage1"]
                candidate["score_final"] = candidate["score_stage1"]
            else:
                maxsim_score = compute_maxsim_score(query_np, doc_embedding)
                candidate["score_stage2"] = maxsim_score
                candidate["score_final"] = maxsim_score

                if return_embeddings:
                    candidate["embedding"] = doc_embedding

            reranked.append(candidate)

        reranked.sort(key=lambda x: x["score_final"], reverse=True)

        return reranked[:top_k]

    def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert embedding to a numpy array (bfloat16 upcast to float32)."""
        if isinstance(embedding, torch.Tensor):
            if embedding.dtype == torch.bfloat16:
                return embedding.cpu().float().numpy()
            return embedding.cpu().numpy()
        return np.array(embedding, dtype=np.float32)

    def build_filter(
        self,
        year: Optional[Any] = None,
        source: Optional[str] = None,
        district: Optional[str] = None,
        filename: Optional[str] = None,
        has_text: Optional[bool] = None,
    ):
        """
        Build Qdrant filter from parameters.

        Supports single values or lists (using MatchAny). Year values
        given as strings are coerced to int to match the indexed payload.
        """
        from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue

        conditions = []

        if year is not None:
            if isinstance(year, list):
                year_values = [int(y) if isinstance(y, str) else y for y in year]
                conditions.append(FieldCondition(key="year", match=MatchAny(any=year_values)))
            else:
                year_value = int(year) if isinstance(year, str) else year
                conditions.append(FieldCondition(key="year", match=MatchValue(value=year_value)))

        # The remaining string fields share the same single-vs-list logic.
        for key, value in (("source", source), ("district", district), ("filename", filename)):
            if value is None:
                continue
            if isinstance(value, list):
                conditions.append(FieldCondition(key=key, match=MatchAny(any=value)))
            else:
                conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))

        if has_text is not None:
            conditions.append(FieldCondition(key="has_text", match=MatchValue(value=has_text)))

        return Filter(must=conditions) if conditions else None
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualization module - Saliency maps and attention visualization.
|
|
3
|
+
|
|
4
|
+
This module provides:
|
|
5
|
+
- Saliency map generation showing query-document relevance
|
|
6
|
+
- Attention heatmaps for visual token analysis
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from visual_rag.visualization.saliency import (
|
|
10
|
+
create_saliency_overlay,
|
|
11
|
+
generate_saliency_map,
|
|
12
|
+
visualize_search_results,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"generate_saliency_map",
|
|
17
|
+
"create_saliency_overlay",
|
|
18
|
+
"visualize_search_results",
|
|
19
|
+
]
|