visual_rag_toolkit-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
visual_rag/qdrant_admin.py
@@ -0,0 +1,222 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, Optional
+ from urllib.parse import urlparse
+
+
+ @dataclass(frozen=True)
+ class QdrantConnection:
+     url: str
+     api_key: Optional[str]
+
+
+ def _maybe_load_dotenv() -> None:
+     try:
+         from dotenv import load_dotenv
+     except Exception:
+         return
+     try:
+         from pathlib import Path
+
+         if Path(".env").exists():
+             load_dotenv(".env")
+     except Exception:
+         return
+
+
+ def _resolve_qdrant_connection(
+     *,
+     url: Optional[str] = None,
+     api_key: Optional[str] = None,
+ ) -> QdrantConnection:
+     import os
+
+     _maybe_load_dotenv()
+     resolved_url = (
+         url
+         or os.getenv("SIGIR_QDRANT_URL")
+         or os.getenv("DEST_QDRANT_URL")
+         or os.getenv("QDRANT_URL")
+     )
+     if not resolved_url:
+         raise ValueError(
+             "Qdrant URL not set (pass url= or set SIGIR_QDRANT_URL/DEST_QDRANT_URL/QDRANT_URL)."
+         )
+     resolved_key = (
+         api_key
+         or os.getenv("SIGIR_QDRANT_KEY")
+         or os.getenv("SIGIR_QDRANT_API_KEY")
+         or os.getenv("DEST_QDRANT_API_KEY")
+         or os.getenv("QDRANT_API_KEY")
+     )
+     return QdrantConnection(url=str(resolved_url), api_key=resolved_key)
+
+
+ def _infer_grpc_port(url: str) -> Optional[int]:
+     try:
+         if urlparse(url).port == 6333:
+             return 6334
+     except Exception:
+         return None
+     return None
+
+
+ class QdrantAdmin:
+     def __init__(
+         self,
+         *,
+         url: Optional[str] = None,
+         api_key: Optional[str] = None,
+         prefer_grpc: bool = False,
+         timeout: int = 60,
+     ):
+         from qdrant_client import QdrantClient
+
+         conn = _resolve_qdrant_connection(url=url, api_key=api_key)
+         grpc_port = _infer_grpc_port(conn.url) if prefer_grpc else None
+         self.client = QdrantClient(
+             url=conn.url,
+             api_key=conn.api_key,
+             prefer_grpc=bool(prefer_grpc),
+             grpc_port=grpc_port,
+             timeout=int(timeout),
+             check_compatibility=False,
+         )
+
+     def get_collection_info(self, *, collection_name: str) -> Dict[str, Any]:
+         info = self.client.get_collection(collection_name)
+         try:
+             return info.model_dump()
+         except Exception:
+             try:
+                 return info.dict()
+             except Exception:
+                 return {"collection": str(collection_name), "raw": str(info)}
+
+     def modify_collection_config(
+         self,
+         *,
+         collection_name: str,
+         hnsw_config: Optional[Dict[str, Any]] = None,
+         collection_params: Optional[Dict[str, Any]] = None,
+         timeout: Optional[int] = None,
+     ) -> bool:
+         """
+         Patch collection-level config via Qdrant update_collection.
+
+         Supported keys:
+         - hnsw_config: dict for HnswConfigDiff (e.g. on_disk, m, ef_construct, full_scan_threshold)
+         - collection_params: dict for CollectionParamsDiff (e.g. on_disk_payload)
+         """
+         from qdrant_client.http import models as m
+
+         hnsw_diff = m.HnswConfigDiff(**hnsw_config) if isinstance(hnsw_config, dict) else None
+         params_diff = (
+             m.CollectionParamsDiff(**collection_params)
+             if isinstance(collection_params, dict)
+             else None
+         )
+         if hnsw_diff is None and params_diff is None:
+             raise ValueError("No changes provided (pass hnsw_config and/or collection_params).")
+         return bool(
+             self.client.update_collection(
+                 collection_name=str(collection_name),
+                 hnsw_config=hnsw_diff,
+                 collection_params=params_diff,
+                 timeout=int(timeout) if timeout is not None else None,
+             )
+         )
+
+     def modify_collection_vector_config(
+         self,
+         *,
+         collection_name: str,
+         vectors: Dict[str, Dict[str, Any]],
+         timeout: Optional[int] = None,
+     ) -> bool:
+         """
+         Patch vector params under params.vectors[vector_name] using Qdrant update_collection.
+
+         Supported keys per vector:
+         - on_disk: bool
+         - hnsw_config: dict with optional keys: m, ef_construct, full_scan_threshold, on_disk
+         """
+         from qdrant_client.http import models as m
+
+         collection_name = str(collection_name)
+         info = self.client.get_collection(collection_name)
+         existing = set()
+         try:
+             existing = set((info.config.params.vectors or {}).keys())
+         except Exception:
+             existing = set()
+
+         missing = [str(k) for k in (vectors or {}).keys() if existing and str(k) not in existing]
+         if missing:
+             raise ValueError(
+                 f"Vectors do not exist in collection '{collection_name}': {missing}. Existing: {sorted(existing)}"
+             )
+
+         ok = True
+         for name, cfg in (vectors or {}).items():
+             if not isinstance(cfg, dict):
+                 raise ValueError(f"vectors['{name}'] must be a dict, got {type(cfg)}")
+             hnsw_cfg = cfg.get("hnsw_config")
+             hnsw_diff = m.HnswConfigDiff(**hnsw_cfg) if isinstance(hnsw_cfg, dict) else None
+             vectors_diff = {
+                 str(name): m.VectorParamsDiff(
+                     on_disk=cfg.get("on_disk", None),
+                     hnsw_config=hnsw_diff,
+                 )
+             }
+
+             ok = (
+                 bool(
+                     self.client.update_collection(
+                         collection_name=collection_name,
+                         vectors_config=vectors_diff,
+                         timeout=int(timeout) if timeout is not None else None,
+                     )
+                 )
+                 and ok
+             )
+
+         return ok
+
+     def ensure_collection_all_on_disk(
+         self,
+         *,
+         collection_name: str,
+         timeout: Optional[int] = None,
+     ) -> Dict[str, Any]:
+         """
+         Ensure:
+         - All existing named vectors have on_disk=True and hnsw_config.on_disk=True
+         - Collection hnsw_config.on_disk=True
+         - Collection params.on_disk_payload=True
+         Returns the post-update collection info (dict).
+         """
+         collection_name = str(collection_name)
+         info = self.client.get_collection(collection_name)
+         vectors = {}
+         try:
+             existing = list((info.config.params.vectors or {}).keys())
+         except Exception:
+             existing = []
+         for vname in existing:
+             vectors[str(vname)] = {"on_disk": True, "hnsw_config": {"on_disk": True}}
+
+         if vectors:
+             self.modify_collection_vector_config(
+                 collection_name=collection_name, vectors=vectors, timeout=timeout
+             )
+
+         self.modify_collection_config(
+             collection_name=collection_name,
+             hnsw_config={"on_disk": True},
+             collection_params={"on_disk_payload": True},
+             timeout=timeout,
+         )
+
+         return self.get_collection_info(collection_name=collection_name)
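
For orientation, a minimal usage sketch of the admin helper above; the collection name is a placeholder, the qdrant-client extra must be installed, and Qdrant credentials are expected in the environment or a local .env file:

from visual_rag.qdrant_admin import QdrantAdmin

# Connection details are resolved from SIGIR_QDRANT_URL / DEST_QDRANT_URL / QDRANT_URL
# and the matching API-key variables (or from an explicit url=/api_key= pair).
admin = QdrantAdmin(prefer_grpc=False, timeout=60)

# Move every named vector, the HNSW index, and the payload storage to disk,
# then print the resulting collection config as a dict.
info = admin.ensure_collection_all_on_disk(collection_name="my_collection")
print(info)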
visual_rag/retrieval/__init__.py
@@ -0,0 +1,19 @@
+ """
+ Retrieval module - Search and retrieval strategies.
+
+ Components:
+ - TwoStageRetriever: Pooled prefetch → MaxSim reranking (our novel contribution)
+ - SingleStageRetriever: Direct multi-vector or pooled search
+ """
+
+ from visual_rag.retrieval.multi_vector import MultiVectorRetriever
+ from visual_rag.retrieval.single_stage import SingleStageRetriever
+ from visual_rag.retrieval.three_stage import ThreeStageRetriever
+ from visual_rag.retrieval.two_stage import TwoStageRetriever
+
+ __all__ = [
+     "TwoStageRetriever",
+     "SingleStageRetriever",
+     "MultiVectorRetriever",
+     "ThreeStageRetriever",
+ ]
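
A small sketch of the two-stage retriever exported above, wired up the same way MultiVectorRetriever does further down in this diff; the URL, collection name, and query shape are placeholders, and a real query embedding would come from VisualEmbedder.embed_query:

import numpy as np
from qdrant_client import QdrantClient
from visual_rag.retrieval import TwoStageRetriever

client = QdrantClient(url="http://localhost:6333")
retriever = TwoStageRetriever(
    client,
    collection_name="my_collection",
    request_timeout=120,
    max_retries=3,
    retry_sleep=0.5,
)

# Placeholder multi-vector query embedding with shape [num_tokens, dim].
query_embedding = np.random.rand(16, 128).astype(np.float32)

# Stage 1 prefetches with pooled vectors, stage 2 reranks with MaxSim server-side.
hits = retriever.search_server_side(
    query_embedding=query_embedding,
    top_k=10,
    prefetch_k=200,
    stage1_mode="pooled_query_vs_tiles",
)
print(len(hits))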
visual_rag/retrieval/multi_vector.py
@@ -0,0 +1,222 @@
+ import os
+ from typing import Any, Dict, List, Optional
+ from urllib.parse import urlparse
+
+ from visual_rag.embedding.visual_embedder import VisualEmbedder
+ from visual_rag.retrieval.single_stage import SingleStageRetriever
+ from visual_rag.retrieval.three_stage import ThreeStageRetriever
+ from visual_rag.retrieval.two_stage import TwoStageRetriever
+
+
+ class MultiVectorRetriever:
+     @staticmethod
+     def _maybe_load_dotenv() -> None:
+         try:
+             from dotenv import load_dotenv
+         except ImportError:
+             return
+         if os.path.exists(".env"):
+             load_dotenv(".env")
+
+     def __init__(
+         self,
+         collection_name: str,
+         model_name: str = "vidore/colSmol-500M",
+         qdrant_url: Optional[str] = None,
+         qdrant_api_key: Optional[str] = None,
+         prefer_grpc: bool = False,
+         request_timeout: int = 120,
+         max_retries: int = 3,
+         retry_sleep: float = 0.5,
+         qdrant_client=None,
+         embedder: Optional[VisualEmbedder] = None,
+     ):
+         if qdrant_client is None:
+             self._maybe_load_dotenv()
+             try:
+                 from qdrant_client import QdrantClient
+             except ImportError as e:
+                 raise ImportError(
+                     "Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
+                 ) from e
+
+             qdrant_url = (
+                 qdrant_url
+                 or os.getenv("SIGIR_QDRANT_URL")
+                 or os.getenv("DEST_QDRANT_URL")
+                 or os.getenv("QDRANT_URL")
+             )
+             if not qdrant_url:
+                 raise ValueError(
+                     "QDRANT_URL is required (pass qdrant_url or set env var). "
+                     "You can also set DEST_QDRANT_URL to override."
+                 )
+
+             qdrant_api_key = (
+                 qdrant_api_key
+                 or os.getenv("SIGIR_QDRANT_KEY")
+                 or os.getenv("SIGIR_QDRANT_API_KEY")
+                 or os.getenv("DEST_QDRANT_API_KEY")
+                 or os.getenv("QDRANT_API_KEY")
+             )
+
+             grpc_port = None
+             if prefer_grpc:
+                 try:
+                     if urlparse(qdrant_url).port == 6333:
+                         grpc_port = 6334
+                 except Exception:
+                     grpc_port = None
+
+             def _make_client(use_grpc: bool):
+                 return QdrantClient(
+                     url=qdrant_url,
+                     api_key=qdrant_api_key,
+                     prefer_grpc=bool(use_grpc),
+                     grpc_port=grpc_port,
+                     timeout=int(request_timeout),
+                     check_compatibility=False,
+                 )
+
+             qdrant_client = _make_client(prefer_grpc)
+             if prefer_grpc:
+                 try:
+                     _ = qdrant_client.get_collections()
+                 except Exception as e:
+                     msg = str(e)
+                     if (
+                         "StatusCode.PERMISSION_DENIED" in msg
+                         or "http2 header with status: 403" in msg
+                     ):
+                         qdrant_client = _make_client(False)
+                     else:
+                         raise
+
+         self.client = qdrant_client
+         self.collection_name = collection_name
+         self.embedder = embedder or VisualEmbedder(model_name=model_name)
+
+         self._two_stage = TwoStageRetriever(
+             self.client,
+             collection_name=self.collection_name,
+             request_timeout=int(request_timeout),
+             max_retries=int(max_retries),
+             retry_sleep=float(retry_sleep),
+         )
+         self._three_stage = ThreeStageRetriever(
+             self.client,
+             collection_name=self.collection_name,
+             request_timeout=int(request_timeout),
+             max_retries=int(max_retries),
+             retry_sleep=float(retry_sleep),
+         )
+         self._single_stage = SingleStageRetriever(
+             self.client,
+             collection_name=self.collection_name,
+             request_timeout=int(request_timeout),
+         )
+
+     def build_filter(
+         self,
+         year: Optional[Any] = None,
+         source: Optional[str] = None,
+         district: Optional[str] = None,
+         filename: Optional[str] = None,
+         has_text: Optional[bool] = None,
+     ):
+         return self._two_stage.build_filter(
+             year=year,
+             source=source,
+             district=district,
+             filename=filename,
+             has_text=has_text,
+         )
+
+     def search(
+         self,
+         query: str,
+         top_k: int = 10,
+         mode: str = "single_full",
+         prefetch_k: Optional[int] = None,
+         stage1_mode: str = "pooled_query_vs_tiles",
+         filter_obj=None,
+         return_embeddings: bool = False,
+     ) -> List[Dict[str, Any]]:
+         q = self.embedder.embed_query(query)
+         try:
+             import torch
+         except ImportError:
+             torch = None
+         if torch is not None and isinstance(q, torch.Tensor):
+             query_embedding = q.detach().cpu().numpy()
+         else:
+             query_embedding = q.numpy()
+
+         return self.search_embedded(
+             query_embedding=query_embedding,
+             top_k=top_k,
+             mode=mode,
+             prefetch_k=prefetch_k,
+             stage1_mode=stage1_mode,
+             filter_obj=filter_obj,
+             return_embeddings=return_embeddings,
+         )
+
+     def search_embedded(
+         self,
+         *,
+         query_embedding,
+         top_k: int = 10,
+         mode: str = "single_full",
+         prefetch_k: Optional[int] = None,
+         stage1_mode: str = "pooled_query_vs_tiles",
+         stage1_k: Optional[int] = None,
+         stage2_k: Optional[int] = None,
+         filter_obj=None,
+         return_embeddings: bool = False,
+     ) -> List[Dict[str, Any]]:
+         if mode == "single_full":
+             return self._single_stage.search(
+                 query_embedding=query_embedding,
+                 top_k=top_k,
+                 strategy="multi_vector",
+                 filter_obj=filter_obj,
+             )
+
+         if mode == "single_tiles":
+             return self._single_stage.search(
+                 query_embedding=query_embedding,
+                 top_k=top_k,
+                 strategy="tiles_maxsim",
+                 filter_obj=filter_obj,
+             )
+
+         if mode == "single_global":
+             return self._single_stage.search(
+                 query_embedding=query_embedding,
+                 top_k=top_k,
+                 strategy="pooled_global",
+                 filter_obj=filter_obj,
+             )
+
+         if mode == "two_stage":
+             return self._two_stage.search_server_side(
+                 query_embedding=query_embedding,
+                 top_k=top_k,
+                 prefetch_k=prefetch_k,
+                 filter_obj=filter_obj,
+                 stage1_mode=stage1_mode,
+             )
+
+         if mode == "three_stage":
+             s1 = int(stage1_k) if stage1_k is not None else 1000
+             s2 = int(stage2_k) if stage2_k is not None else 300
+             return self._three_stage.search_server_side(
+                 query_embedding=query_embedding,
+                 top_k=top_k,
+                 stage1_k=s1,
+                 stage2_k=s2,
+                 filter_obj=filter_obj,
+             )
+
+         raise ValueError(f"Unknown mode: {mode}")
visual_rag/retrieval/single_stage.py
@@ -0,0 +1,126 @@
+ """
+ Single-Stage Retrieval for Visual Document Search.
+
+ Provides direct search without the two-stage complexity.
+ Use when:
+ - Collection is small (<10K documents)
+ - Latency is not critical
+ - Maximum accuracy is required
+ """
+
+ import logging
+ from typing import Any, Dict, List, Union
+
+ import numpy as np
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ class SingleStageRetriever:
+     """
+     Single-stage visual document retrieval using native Qdrant search.
+
+     Supports strategies:
+     - multi_vector: Native MaxSim on full embeddings (using="initial")
+     - tiles_maxsim: Native MaxSim between query tokens and tile vectors (using="mean_pooling")
+     - pooled_tile: Pooled query vs tile vectors (using="mean_pooling")
+     - pooled_global: Pooled query vs global pooled doc vector (using="global_pooling")
+
+     Args:
+         qdrant_client: Connected Qdrant client
+         collection_name: Name of the Qdrant collection
+
+     Example:
+         >>> retriever = SingleStageRetriever(client, "my_collection")
+         >>> results = retriever.search(query, top_k=10)
+     """
+
+     def __init__(
+         self,
+         qdrant_client,
+         collection_name: str,
+         request_timeout: int = 120,
+     ):
+         self.client = qdrant_client
+         self.collection_name = collection_name
+         self.request_timeout = int(request_timeout)
+
+     def search(
+         self,
+         query_embedding: Union[torch.Tensor, np.ndarray],
+         top_k: int = 10,
+         strategy: str = "multi_vector",
+         filter_obj=None,
+     ) -> List[Dict[str, Any]]:
+         """
+         Single-stage search with configurable strategy.
+
+         Args:
+             query_embedding: Query embeddings [num_tokens, dim]
+             top_k: Number of results
+             strategy: "multi_vector", "tiles_maxsim", "pooled_tile", or "pooled_global"
+             filter_obj: Qdrant filter
+
+         Returns:
+             List of results with scores and metadata
+         """
+         query_np = self._to_numpy(query_embedding)
+
+         if strategy == "multi_vector":
+             # Native multi-vector MaxSim
+             vector_name = "initial"
+             query_vector = query_np.tolist()
+             logger.debug(f"🎯 Multi-vector search on '{vector_name}'")
+
+         elif strategy == "tiles_maxsim":
+             # Native multi-vector MaxSim against tile vectors
+             vector_name = "mean_pooling"
+             query_vector = query_np.tolist()
+             logger.debug(f"🎯 Tile MaxSim search on '{vector_name}'")
+
+         elif strategy == "pooled_tile":
+             # Tile-level pooled
+             vector_name = "mean_pooling"
+             query_pooled = query_np.mean(axis=0)
+             query_vector = query_pooled.tolist()
+             logger.debug(f"🔍 Tile-pooled search on '{vector_name}'")
+
+         elif strategy == "pooled_global":
+             # Global pooled vector (single vector)
+             vector_name = "global_pooling"
+             query_pooled = query_np.mean(axis=0)
+             query_vector = query_pooled.tolist()
+             logger.debug(f"🔍 Global-pooled search on '{vector_name}'")
+
+         else:
+             raise ValueError(f"Unknown strategy: {strategy}")
+
+         results = self.client.query_points(
+             collection_name=self.collection_name,
+             query=query_vector,
+             using=vector_name,
+             query_filter=filter_obj,
+             limit=top_k,
+             with_payload=True,
+             with_vectors=False,
+             timeout=self.request_timeout,
+         ).points
+
+         return [
+             {
+                 "id": r.id,
+                 "score": r.score,
+                 "score_final": r.score,
+                 "payload": r.payload,
+             }
+             for r in results
+         ]
+
+     def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
+         """Convert embedding to numpy array."""
+         if isinstance(embedding, torch.Tensor):
+             if embedding.dtype == torch.bfloat16:
+                 return embedding.cpu().float().numpy()
+             return embedding.cpu().numpy()
+         return np.array(embedding, dtype=np.float32)
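
To illustrate the four strategies side by side, a small sketch driving the retriever directly; the URL, collection name, and embedding shape are placeholders, and this assumes the collection was indexed with the toolkit's named-vector layout (initial / mean_pooling / global_pooling):

import numpy as np
from qdrant_client import QdrantClient
from visual_rag.retrieval import SingleStageRetriever

client = QdrantClient(url="http://localhost:6333")
retriever = SingleStageRetriever(client, collection_name="my_collection")

# Placeholder query embedding with shape [num_tokens, dim]; in practice it comes
# from VisualEmbedder.embed_query().
query_embedding = np.random.rand(16, 128).astype(np.float32)

for strategy in ("multi_vector", "tiles_maxsim", "pooled_tile", "pooled_global"):
    results = retriever.search(query_embedding, top_k=3, strategy=strategy)
    print(strategy, [round(r["score"], 3) for r in results])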