visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
from urllib.parse import urlparse
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(frozen=True)
class QdrantConnection:
    """Immutable bundle of the settings needed to reach a Qdrant server.

    Instances are produced by ``_resolve_qdrant_connection`` and consumed when
    constructing a ``QdrantClient``.
    """

    # Qdrant server URL, passed to QdrantClient(url=...).
    url: str
    # API key passed to QdrantClient(api_key=...); None when none was resolved.
    api_key: Optional[str]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _maybe_load_dotenv() -> None:
|
|
15
|
+
try:
|
|
16
|
+
from dotenv import load_dotenv
|
|
17
|
+
except Exception:
|
|
18
|
+
return
|
|
19
|
+
try:
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
if Path(".env").exists():
|
|
23
|
+
load_dotenv(".env")
|
|
24
|
+
except Exception:
|
|
25
|
+
return
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _resolve_qdrant_connection(
    *,
    url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> QdrantConnection:
    """Work out which Qdrant URL and API key to use.

    ``.env`` is loaded first (best effort). Explicit arguments take priority;
    otherwise the first non-empty environment variable wins, checked in order:

    - URL: SIGIR_QDRANT_URL, DEST_QDRANT_URL, QDRANT_URL
    - key: SIGIR_QDRANT_KEY, SIGIR_QDRANT_API_KEY, DEST_QDRANT_API_KEY, QDRANT_API_KEY

    Raises:
        ValueError: if no URL can be resolved.
    """
    import os

    def _pick(explicit, env_names):
        # Same result as ``explicit or getenv(a) or getenv(b) ...``: the first
        # truthy candidate, otherwise the last candidate unchanged.
        candidates = [explicit] + [os.getenv(env_name) for env_name in env_names]
        return next((value for value in candidates if value), candidates[-1])

    _maybe_load_dotenv()

    resolved_url = _pick(url, ("SIGIR_QDRANT_URL", "DEST_QDRANT_URL", "QDRANT_URL"))
    if not resolved_url:
        raise ValueError(
            "Qdrant URL not set (pass url= or set SIGIR_QDRANT_URL/DEST_QDRANT_URL/QDRANT_URL)."
        )

    resolved_key = _pick(
        api_key,
        (
            "SIGIR_QDRANT_KEY",
            "SIGIR_QDRANT_API_KEY",
            "DEST_QDRANT_API_KEY",
            "QDRANT_API_KEY",
        ),
    )
    return QdrantConnection(url=str(resolved_url), api_key=resolved_key)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _infer_grpc_port(url: str) -> Optional[int]:
|
|
57
|
+
try:
|
|
58
|
+
if urlparse(url).port == 6333:
|
|
59
|
+
return 6334
|
|
60
|
+
except Exception:
|
|
61
|
+
return None
|
|
62
|
+
return None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class QdrantAdmin:
    """Administrative helpers for inspecting and patching Qdrant collections.

    The underlying ``qdrant_client.QdrantClient`` is available as ``self.client``.
    """

    def __init__(
        self,
        *,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        prefer_grpc: bool = False,
        timeout: int = 60,
    ):
        """Build the Qdrant client.

        Args:
            url: Qdrant URL; resolved from environment variables when omitted.
            api_key: Qdrant API key; resolved from environment variables when omitted.
            prefer_grpc: Ask the client to prefer gRPC. When the URL uses port
                6333, gRPC port 6334 is passed explicitly.
            timeout: Client timeout, coerced to ``int``.

        Raises:
            ImportError: if ``qdrant_client`` is not installed.
            ValueError: if no Qdrant URL can be resolved.
        """
        from qdrant_client import QdrantClient

        connection = _resolve_qdrant_connection(url=url, api_key=api_key)
        if prefer_grpc:
            grpc_port = _infer_grpc_port(connection.url)
        else:
            grpc_port = None

        client_options = {
            "url": connection.url,
            "api_key": connection.api_key,
            "prefer_grpc": bool(prefer_grpc),
            "grpc_port": grpc_port,
            "timeout": int(timeout),
            # Disables qdrant-client's client/server version compatibility check.
            "check_compatibility": False,
        }
        self.client = QdrantClient(**client_options)

    def get_collection_info(self, *, collection_name: str) -> Dict[str, Any]:
        """Fetch a collection's info and return it as a plain dict.

        Tries ``info.model_dump()`` and then ``info.dict()``; if both raise,
        returns ``{"collection": <name>, "raw": str(info)}``.
        """
        info = self.client.get_collection(collection_name)
        for serializer in ("model_dump", "dict"):
            try:
                return getattr(info, serializer)()
            except Exception:
                continue
        return {"collection": str(collection_name), "raw": str(info)}

    def modify_collection_config(
        self,
        *,
        collection_name: str,
        hnsw_config: Optional[Dict[str, Any]] = None,
        collection_params: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> bool:
        """
        Patch collection-level config via Qdrant update_collection.

        Supported keys:
        - hnsw_config: dict for HnswConfigDiff (e.g. on_disk, m, ef_construct, full_scan_threshold)
        - collection_params: dict for CollectionParamsDiff (e.g. on_disk_payload)

        Non-dict arguments are ignored. Returns ``bool`` of the
        ``update_collection`` response.

        Raises:
            ValueError: if neither ``hnsw_config`` nor ``collection_params`` is a dict.
        """
        from qdrant_client.http import models as m

        hnsw_diff = None
        if isinstance(hnsw_config, dict):
            hnsw_diff = m.HnswConfigDiff(**hnsw_config)

        params_diff = None
        if isinstance(collection_params, dict):
            params_diff = m.CollectionParamsDiff(**collection_params)

        if hnsw_diff is None and params_diff is None:
            raise ValueError("No changes provided (pass hnsw_config and/or collection_params).")

        response = self.client.update_collection(
            collection_name=str(collection_name),
            hnsw_config=hnsw_diff,
            collection_params=params_diff,
            timeout=None if timeout is None else int(timeout),
        )
        return bool(response)

    def modify_collection_vector_config(
        self,
        *,
        collection_name: str,
        vectors: Dict[str, Dict[str, Any]],
        timeout: Optional[int] = None,
    ) -> bool:
        """
        Patch vector params under params.vectors[vector_name] using Qdrant update_collection.

        Supported keys per vector:
        - on_disk: bool
        - hnsw_config: dict with optional keys: m, ef_construct, full_scan_threshold, on_disk

        One ``update_collection`` call is issued per vector, in iteration order.
        Returns True only if every call's response is truthy (True when
        ``vectors`` is empty).

        Raises:
            ValueError: if a requested vector name is not among the collection's
                named vectors (checked only when those names can be read), or
                if a per-vector config is not a dict.
        """
        from qdrant_client.http import models as m

        collection_name = str(collection_name)
        info = self.client.get_collection(collection_name)
        try:
            existing = set((info.config.params.vectors or {}).keys())
        except Exception:
            # e.g. unnamed single-vector collections: skip name validation.
            existing = set()

        requested = vectors or {}
        missing = []
        if existing:
            missing = [str(k) for k in requested.keys() if str(k) not in existing]
        if missing:
            raise ValueError(
                f"Vectors do not exist in collection '{collection_name}': {missing}. Existing: {sorted(existing)}"
            )

        outcomes = []
        for name, cfg in requested.items():
            if not isinstance(cfg, dict):
                raise ValueError(f"vectors['{name}'] must be a dict, got {type(cfg)}")
            nested_hnsw = cfg.get("hnsw_config")
            per_vector_hnsw = (
                m.HnswConfigDiff(**nested_hnsw) if isinstance(nested_hnsw, dict) else None
            )
            patch = {
                str(name): m.VectorParamsDiff(
                    on_disk=cfg.get("on_disk", None),
                    hnsw_config=per_vector_hnsw,
                )
            }
            reply = self.client.update_collection(
                collection_name=collection_name,
                vectors_config=patch,
                timeout=None if timeout is None else int(timeout),
            )
            outcomes.append(bool(reply))
        return all(outcomes)

    def ensure_collection_all_on_disk(
        self,
        *,
        collection_name: str,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Ensure:
        - All existing named vectors have on_disk=True and hnsw_config.on_disk=True
        - Collection hnsw_config.on_disk=True
        - Collection params.on_disk_payload=True
        Returns the post-update collection info (dict).
        """
        collection_name = str(collection_name)
        snapshot = self.client.get_collection(collection_name)
        try:
            vector_names = list((snapshot.config.params.vectors or {}).keys())
        except Exception:
            vector_names = []

        per_vector = {
            str(vname): {"on_disk": True, "hnsw_config": {"on_disk": True}}
            for vname in vector_names
        }
        if per_vector:
            self.modify_collection_vector_config(
                collection_name=collection_name, vectors=per_vector, timeout=timeout
            )

        self.modify_collection_config(
            collection_name=collection_name,
            hnsw_config={"on_disk": True},
            collection_params={"on_disk_payload": True},
            timeout=timeout,
        )

        return self.get_collection_info(collection_name=collection_name)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Retrieval module - Search and retrieval strategies.
|
|
3
|
+
|
|
4
|
+
Components:
|
|
5
|
+
- TwoStageRetriever: Pooled prefetch → MaxSim reranking (our novel contribution)
|
|
6
|
+
- SingleStageRetriever: Direct multi-vector or pooled search
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from visual_rag.retrieval.multi_vector import MultiVectorRetriever
|
|
10
|
+
from visual_rag.retrieval.single_stage import SingleStageRetriever
|
|
11
|
+
from visual_rag.retrieval.three_stage import ThreeStageRetriever
|
|
12
|
+
from visual_rag.retrieval.two_stage import TwoStageRetriever
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"TwoStageRetriever",
|
|
16
|
+
"SingleStageRetriever",
|
|
17
|
+
"MultiVectorRetriever",
|
|
18
|
+
"ThreeStageRetriever",
|
|
19
|
+
]
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
from urllib.parse import urlparse
|
|
4
|
+
|
|
5
|
+
from visual_rag.embedding.visual_embedder import VisualEmbedder
|
|
6
|
+
from visual_rag.retrieval.single_stage import SingleStageRetriever
|
|
7
|
+
from visual_rag.retrieval.three_stage import ThreeStageRetriever
|
|
8
|
+
from visual_rag.retrieval.two_stage import TwoStageRetriever
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MultiVectorRetriever:
    """Query front end that embeds text and dispatches to a retrieval strategy.

    Owns a Qdrant client, a ``VisualEmbedder`` and three strategy objects
    (``SingleStageRetriever``, ``TwoStageRetriever``, ``ThreeStageRetriever``)
    that share the same client and collection.
    """

    @staticmethod
    def _maybe_load_dotenv() -> None:
        """Load ``./.env`` if the ``dotenv`` package is importable and the file exists."""
        try:
            from dotenv import load_dotenv
        except ImportError:
            return
        if os.path.exists(".env"):
            load_dotenv(".env")

    def __init__(
        self,
        collection_name: str,
        model_name: str = "vidore/colSmol-500M",
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        prefer_grpc: bool = False,
        request_timeout: int = 120,
        max_retries: int = 3,
        retry_sleep: float = 0.5,
        qdrant_client=None,
        embedder: Optional["VisualEmbedder"] = None,
    ):
        """Create (or reuse) a Qdrant client and build the strategy objects.

        Args:
            collection_name: Qdrant collection to search.
            model_name: Model used to build a ``VisualEmbedder`` when ``embedder`` is None.
            qdrant_url: Qdrant URL; falls back to SIGIR_QDRANT_URL, DEST_QDRANT_URL,
                QDRANT_URL. Ignored when ``qdrant_client`` is given.
            qdrant_api_key: API key; falls back to SIGIR_QDRANT_KEY, SIGIR_QDRANT_API_KEY,
                DEST_QDRANT_API_KEY, QDRANT_API_KEY. Ignored when ``qdrant_client`` is given.
            prefer_grpc: Try gRPC first; falls back to HTTP if the gRPC probe is
                rejected with a permission-denied / 403 error.
            request_timeout: Timeout forwarded to the client and strategies.
            max_retries: Retry count forwarded to the two- and three-stage retrievers.
            retry_sleep: Retry delay forwarded to the two- and three-stage retrievers.
            qdrant_client: Pre-built client to use instead of creating one.
            embedder: Pre-built embedder to use instead of creating one.

        Raises:
            ImportError: if ``qdrant_client`` is not installed (only when no client is given).
            ValueError: if no Qdrant URL can be resolved (only when no client is given).
        """
        if qdrant_client is None:
            self._maybe_load_dotenv()
            try:
                from qdrant_client import QdrantClient
            except ImportError as e:
                raise ImportError(
                    "Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
                ) from e

            qdrant_url = (
                qdrant_url
                or os.getenv("SIGIR_QDRANT_URL")
                or os.getenv("DEST_QDRANT_URL")
                or os.getenv("QDRANT_URL")
            )
            if not qdrant_url:
                raise ValueError(
                    "QDRANT_URL is required (pass qdrant_url or set env var). "
                    "You can also set DEST_QDRANT_URL to override."
                )

            qdrant_api_key = (
                qdrant_api_key
                or os.getenv("SIGIR_QDRANT_KEY")
                or os.getenv("SIGIR_QDRANT_API_KEY")
                or os.getenv("DEST_QDRANT_API_KEY")
                or os.getenv("QDRANT_API_KEY")
            )

            # When the HTTP URL uses port 6333, pass gRPC port 6334 explicitly.
            grpc_port = None
            if prefer_grpc:
                try:
                    if urlparse(qdrant_url).port == 6333:
                        grpc_port = 6334
                except Exception:
                    grpc_port = None

            def _make_client(use_grpc: bool):
                return QdrantClient(
                    url=qdrant_url,
                    api_key=qdrant_api_key,
                    prefer_grpc=bool(use_grpc),
                    grpc_port=grpc_port,
                    timeout=int(request_timeout),
                    check_compatibility=False,
                )

            qdrant_client = _make_client(prefer_grpc)
            if prefer_grpc:
                # Probe the gRPC channel once; some deployments reject gRPC with
                # a permission error, in which case fall back to HTTP.
                try:
                    _ = qdrant_client.get_collections()
                except Exception as e:
                    msg = str(e)
                    if (
                        "StatusCode.PERMISSION_DENIED" in msg
                        or "http2 header with status: 403" in msg
                    ):
                        qdrant_client = _make_client(False)
                    else:
                        raise

        self.client = qdrant_client
        self.collection_name = collection_name
        self.embedder = embedder or VisualEmbedder(model_name=model_name)

        self._two_stage = TwoStageRetriever(
            self.client,
            collection_name=self.collection_name,
            request_timeout=int(request_timeout),
            max_retries=int(max_retries),
            retry_sleep=float(retry_sleep),
        )
        self._three_stage = ThreeStageRetriever(
            self.client,
            collection_name=self.collection_name,
            request_timeout=int(request_timeout),
            max_retries=int(max_retries),
            retry_sleep=float(retry_sleep),
        )
        self._single_stage = SingleStageRetriever(
            self.client,
            collection_name=self.collection_name,
            request_timeout=int(request_timeout),
        )

    def build_filter(
        self,
        year: Optional[Any] = None,
        source: Optional[str] = None,
        district: Optional[str] = None,
        filename: Optional[str] = None,
        has_text: Optional[bool] = None,
    ):
        """Build a Qdrant filter; delegates to ``TwoStageRetriever.build_filter``."""
        return self._two_stage.build_filter(
            year=year,
            source=source,
            district=district,
            filename=filename,
            has_text=has_text,
        )

    @staticmethod
    def _embedding_to_numpy(embedding):
        """Convert the embedder's query output to a numpy array.

        torch tensors are detached and moved to CPU. bfloat16 tensors are upcast
        to float32 because numpy has no bfloat16 dtype. Other objects exposing
        a callable ``.numpy()`` use it. Anything else, e.g. an ``np.ndarray``,
        which has no ``.numpy()``, goes through ``np.asarray``.
        """
        try:
            import torch
        except ImportError:
            torch = None
        if torch is not None and isinstance(embedding, torch.Tensor):
            tensor = embedding.detach().cpu()
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.float()
            return tensor.numpy()
        to_numpy = getattr(embedding, "numpy", None)
        if callable(to_numpy):
            return to_numpy()
        import numpy as np

        return np.asarray(embedding)

    def search(
        self,
        query: str,
        top_k: int = 10,
        mode: str = "single_full",
        prefetch_k: Optional[int] = None,
        stage1_mode: str = "pooled_query_vs_tiles",
        filter_obj=None,
        return_embeddings: bool = False,
        stage1_k: Optional[int] = None,
        stage2_k: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Embed ``query`` with the embedder and run ``search_embedded``.

        All arguments other than ``query`` are forwarded unchanged to
        ``search_embedded``; ``stage1_k``/``stage2_k`` only affect
        ``mode="three_stage"``.
        """
        query_embedding = self._embedding_to_numpy(self.embedder.embed_query(query))

        return self.search_embedded(
            query_embedding=query_embedding,
            top_k=top_k,
            mode=mode,
            prefetch_k=prefetch_k,
            stage1_mode=stage1_mode,
            stage1_k=stage1_k,
            stage2_k=stage2_k,
            filter_obj=filter_obj,
            return_embeddings=return_embeddings,
        )

    def search_embedded(
        self,
        *,
        query_embedding,
        top_k: int = 10,
        mode: str = "single_full",
        prefetch_k: Optional[int] = None,
        stage1_mode: str = "pooled_query_vs_tiles",
        stage1_k: Optional[int] = None,
        stage2_k: Optional[int] = None,
        filter_obj=None,
        return_embeddings: bool = False,
    ) -> List[Dict[str, Any]]:
        """Search with a precomputed query embedding.

        Modes:
            - "single_full": single-stage, strategy "multi_vector"
            - "single_tiles": single-stage, strategy "tiles_maxsim"
            - "single_global": single-stage, strategy "pooled_global"
            - "two_stage": ``TwoStageRetriever.search_server_side`` (uses prefetch_k, stage1_mode)
            - "three_stage": ``ThreeStageRetriever.search_server_side`` with
              stage1_k (default 1000) and stage2_k (default 300)

        NOTE(review): ``return_embeddings`` is accepted but not used by any mode here.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        if mode == "single_full":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                strategy="multi_vector",
                filter_obj=filter_obj,
            )

        if mode == "single_tiles":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                strategy="tiles_maxsim",
                filter_obj=filter_obj,
            )

        if mode == "single_global":
            return self._single_stage.search(
                query_embedding=query_embedding,
                top_k=top_k,
                strategy="pooled_global",
                filter_obj=filter_obj,
            )

        if mode == "two_stage":
            return self._two_stage.search_server_side(
                query_embedding=query_embedding,
                top_k=top_k,
                prefetch_k=prefetch_k,
                filter_obj=filter_obj,
                stage1_mode=stage1_mode,
            )

        if mode == "three_stage":
            s1 = int(stage1_k) if stage1_k is not None else 1000
            s2 = int(stage2_k) if stage2_k is not None else 300
            return self._three_stage.search_server_side(
                query_embedding=query_embedding,
                top_k=top_k,
                stage1_k=s1,
                stage2_k=s2,
                filter_obj=filter_obj,
            )

        raise ValueError(f"Unknown mode: {mode}")
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Single-Stage Retrieval for Visual Document Search.
|
|
3
|
+
|
|
4
|
+
Provides direct search without the two-stage complexity.
|
|
5
|
+
Use when:
|
|
6
|
+
- Collection is small (<10K documents)
|
|
7
|
+
- Latency is not critical
|
|
8
|
+
- Maximum accuracy is required
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from typing import Any, Dict, List, Union
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SingleStageRetriever:
    """
    Single-stage visual document retrieval using native Qdrant search.

    Supports strategies:
    - multi_vector: Native MaxSim on full embeddings (using="initial")
    - tiles_maxsim: Native MaxSim between query tokens and tile vectors (using="mean_pooling")
    - pooled_tile: Pooled query vs tile vectors (using="mean_pooling")
    - pooled_global: Pooled query vs global pooled doc vector (using="global_pooling")

    Args:
        qdrant_client: Connected Qdrant client
        collection_name: Name of the Qdrant collection
        request_timeout: Timeout passed to every ``query_points`` call (coerced to int)

    Example:
        >>> retriever = SingleStageRetriever(client, "my_collection")
        >>> results = retriever.search(query, top_k=10)
    """

    def __init__(
        self,
        qdrant_client,
        collection_name: str,
        request_timeout: int = 120,
    ):
        self.client = qdrant_client
        self.collection_name = collection_name
        self.request_timeout = int(request_timeout)

    def search(
        self,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 10,
        strategy: str = "multi_vector",
        filter_obj=None,
    ) -> List[Dict[str, Any]]:
        """
        Single-stage search with configurable strategy.

        Args:
            query_embedding: Query embeddings [num_tokens, dim]. The pooled
                strategies also accept an already-pooled 1-D [dim] vector,
                which is used as-is.
            top_k: Number of results
            strategy: "multi_vector", "tiles_maxsim", "pooled_tile", or "pooled_global"
            filter_obj: Qdrant filter

        Returns:
            List of dicts with keys "id", "score", "score_final" (equal to
            "score") and "payload".

        Raises:
            ValueError: for an unknown ``strategy``.
        """
        query_np = self._to_numpy(query_embedding)

        if strategy == "multi_vector":
            # Native multi-vector MaxSim
            vector_name = "initial"
            query_vector = query_np.tolist()
            logger.debug("🎯 Multi-vector search on '%s'", vector_name)

        elif strategy == "tiles_maxsim":
            # Native multi-vector MaxSim against tile vectors
            vector_name = "mean_pooling"
            query_vector = query_np.tolist()
            logger.debug("🎯 Tile MaxSim search on '%s'", vector_name)

        elif strategy == "pooled_tile":
            # Tile-level pooled
            vector_name = "mean_pooling"
            query_vector = self._mean_pool(query_np).tolist()
            logger.debug("🔍 Tile-pooled search on '%s'", vector_name)

        elif strategy == "pooled_global":
            # Global pooled vector (single vector)
            vector_name = "global_pooling"
            query_vector = self._mean_pool(query_np).tolist()
            logger.debug("🔍 Global-pooled search on '%s'", vector_name)

        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            using=vector_name,
            query_filter=filter_obj,
            limit=top_k,
            with_payload=True,
            with_vectors=False,
            timeout=self.request_timeout,
        ).points

        return [
            {
                "id": r.id,
                "score": r.score,
                "score_final": r.score,
                "payload": r.payload,
            }
            for r in results
        ]

    @staticmethod
    def _mean_pool(query_np: np.ndarray) -> np.ndarray:
        """Average token embeddings into one vector; a 1-D input is already pooled."""
        if query_np.ndim == 1:
            # mean(axis=0) on a 1-D vector would collapse it to a scalar.
            return query_np
        return query_np.mean(axis=0)

    def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert an embedding to a float32 numpy array.

        Tensors are detached first, because ``.numpy()`` refuses tensors that
        require grad. They are moved to CPU and cast to float32, which also
        covers bfloat16, a dtype numpy does not support. This matches the
        float32 dtype used for non-tensor input.
        """
        if isinstance(embedding, torch.Tensor):
            return embedding.detach().cpu().float().numpy()
        return np.array(embedding, dtype=np.float32)
|