mteb 2.2.2__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/__init__.py +4 -0
- mteb/descriptive_stats/Reranking/MultiLongDocReranking.json +466 -0
- mteb/evaluate.py +38 -7
- mteb/models/__init__.py +4 -1
- mteb/models/cache_wrappers/__init__.py +2 -1
- mteb/models/model_implementations/colpali_models.py +4 -4
- mteb/models/model_implementations/colqwen_models.py +206 -2
- mteb/models/model_implementations/eagerworks_models.py +163 -0
- mteb/models/model_implementations/euler_models.py +25 -0
- mteb/models/model_implementations/google_models.py +1 -1
- mteb/models/model_implementations/jina_models.py +203 -5
- mteb/models/model_implementations/nb_sbert.py +1 -1
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +10 -11
- mteb/models/model_implementations/nvidia_models.py +1 -1
- mteb/models/model_implementations/ops_moa_models.py +2 -2
- mteb/models/model_implementations/promptriever_models.py +4 -4
- mteb/models/model_implementations/qwen3_models.py +3 -3
- mteb/models/model_implementations/qzhou_models.py +1 -1
- mteb/models/model_implementations/random_baseline.py +8 -18
- mteb/models/model_implementations/vdr_models.py +1 -0
- mteb/models/model_implementations/yuan_models_en.py +57 -0
- mteb/models/search_encoder_index/__init__.py +7 -0
- mteb/models/search_encoder_index/search_backend_protocol.py +50 -0
- mteb/models/search_encoder_index/search_indexes/__init__.py +5 -0
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +157 -0
- mteb/models/search_wrappers.py +157 -41
- mteb/results/model_result.py +2 -1
- mteb/results/task_result.py +12 -0
- mteb/similarity_functions.py +49 -0
- mteb/tasks/reranking/multilingual/__init__.py +2 -0
- mteb/tasks/reranking/multilingual/multi_long_doc_reranking.py +70 -0
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +4 -0
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +56 -42
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +3 -3
- {mteb-2.2.2.dist-info → mteb-2.3.1.dist-info}/METADATA +6 -1
- {mteb-2.2.2.dist-info → mteb-2.3.1.dist-info}/RECORD +40 -31
- {mteb-2.2.2.dist-info → mteb-2.3.1.dist-info}/WHEEL +0 -0
- {mteb-2.2.2.dist-info → mteb-2.3.1.dist-info}/entry_points.txt +0 -0
- {mteb-2.2.2.dist-info → mteb-2.3.1.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.2.2.dist-info → mteb-2.3.1.dist-info}/top_level.txt +0 -0
|
@@ -117,19 +117,18 @@ class LlamaNemoretrieverColembed(AbsEncoder):
|
|
|
117
117
|
|
|
118
118
|
TRAINING_DATA = {
|
|
119
119
|
# from https://huggingface.co/datasets/vidore/colpali_train_set
|
|
120
|
-
"
|
|
121
|
-
"
|
|
122
|
-
"
|
|
123
|
-
"
|
|
124
|
-
"
|
|
125
|
-
"
|
|
120
|
+
"VidoreDocVQARetrieval",
|
|
121
|
+
"VidoreInfoVQARetrieval",
|
|
122
|
+
"VidoreTatdqaRetrieval",
|
|
123
|
+
"VidoreArxivQARetrieval",
|
|
124
|
+
"HotpotQA",
|
|
125
|
+
"MIRACLRetrieval",
|
|
126
126
|
"NQ",
|
|
127
|
-
"
|
|
127
|
+
"StackExchangeClustering",
|
|
128
128
|
"SQuAD",
|
|
129
129
|
"WebInstructSub",
|
|
130
130
|
"docmatix-ir",
|
|
131
|
-
"
|
|
132
|
-
"colpali_train_set", # as it contains PDFs
|
|
131
|
+
"VDRMultilingualRetrieval",
|
|
133
132
|
"VisRAG-Ret-Train-Synthetic-data",
|
|
134
133
|
"VisRAG-Ret-Train-In-domain-data",
|
|
135
134
|
"wiki-ss-nq",
|
|
@@ -146,7 +145,7 @@ llama_nemoretriever_colembed_1b_v1 = ModelMeta(
|
|
|
146
145
|
release_date="2025-06-27",
|
|
147
146
|
modalities=["image", "text"],
|
|
148
147
|
n_parameters=2_418_000_000,
|
|
149
|
-
memory_usage_mb=
|
|
148
|
+
memory_usage_mb=4610,
|
|
150
149
|
max_tokens=8192,
|
|
151
150
|
embed_dim=2048,
|
|
152
151
|
license="https://huggingface.co/nvidia/llama-nemoretriever-colembed-1b-v1/blob/main/LICENSE",
|
|
@@ -172,7 +171,7 @@ llama_nemoretriever_colembed_3b_v1 = ModelMeta(
|
|
|
172
171
|
release_date="2025-06-27",
|
|
173
172
|
modalities=["image", "text"],
|
|
174
173
|
n_parameters=4_407_000_000,
|
|
175
|
-
memory_usage_mb=
|
|
174
|
+
memory_usage_mb=8403,
|
|
176
175
|
max_tokens=8192,
|
|
177
176
|
embed_dim=3072,
|
|
178
177
|
license="https://huggingface.co/nvidia/llama-nemoretriever-colembed-1b-v1/blob/main/LICENSE",
|
|
@@ -146,7 +146,7 @@ NV_embed_v1 = ModelMeta(
|
|
|
146
146
|
revision="570834afd5fef5bf3a3c2311a2b6e0a66f6f4f2c",
|
|
147
147
|
release_date="2024-09-13", # initial commit of hf model.
|
|
148
148
|
n_parameters=7_850_000_000,
|
|
149
|
-
memory_usage_mb=
|
|
149
|
+
memory_usage_mb=14975,
|
|
150
150
|
embed_dim=4096,
|
|
151
151
|
license="cc-by-nc-4.0",
|
|
152
152
|
max_tokens=32768,
|
|
@@ -27,7 +27,7 @@ ops_moa_conan_embedding = ModelMeta(
|
|
|
27
27
|
languages=["zho-Hans"],
|
|
28
28
|
loader=OPSWrapper,
|
|
29
29
|
n_parameters=int(343 * 1e6),
|
|
30
|
-
memory_usage_mb=
|
|
30
|
+
memory_usage_mb=1308,
|
|
31
31
|
max_tokens=512,
|
|
32
32
|
embed_dim=1536,
|
|
33
33
|
license="cc-by-nc-4.0",
|
|
@@ -58,7 +58,7 @@ ops_moa_yuan_embedding = ModelMeta(
|
|
|
58
58
|
languages=["zho-Hans"],
|
|
59
59
|
loader=OPSWrapper,
|
|
60
60
|
n_parameters=int(343 * 1e6),
|
|
61
|
-
memory_usage_mb=
|
|
61
|
+
memory_usage_mb=1242,
|
|
62
62
|
max_tokens=512,
|
|
63
63
|
embed_dim=1536,
|
|
64
64
|
license="cc-by-nc-4.0",
|
|
@@ -80,7 +80,7 @@ promptriever_llama2 = ModelMeta(
|
|
|
80
80
|
revision="01c7f73d771dfac7d292323805ebc428287df4f9-30b14e3813c0fa45facfd01a594580c3fe5ecf23", # base-peft revision
|
|
81
81
|
release_date="2024-09-15",
|
|
82
82
|
n_parameters=7_000_000_000,
|
|
83
|
-
memory_usage_mb=
|
|
83
|
+
memory_usage_mb=26703,
|
|
84
84
|
max_tokens=4096,
|
|
85
85
|
embed_dim=4096,
|
|
86
86
|
license="apache-2.0",
|
|
@@ -115,7 +115,7 @@ promptriever_llama3 = ModelMeta(
|
|
|
115
115
|
},
|
|
116
116
|
release_date="2024-09-15",
|
|
117
117
|
n_parameters=8_000_000_000,
|
|
118
|
-
memory_usage_mb=
|
|
118
|
+
memory_usage_mb=30518,
|
|
119
119
|
max_tokens=8192,
|
|
120
120
|
embed_dim=4096,
|
|
121
121
|
license="apache-2.0",
|
|
@@ -143,7 +143,7 @@ promptriever_llama3_instruct = ModelMeta(
|
|
|
143
143
|
revision="5206a32e0bd3067aef1ce90f5528ade7d866253f-8b677258615625122c2eb7329292b8c402612c21", # base-peft revision
|
|
144
144
|
release_date="2024-09-15",
|
|
145
145
|
n_parameters=8_000_000_000,
|
|
146
|
-
memory_usage_mb=
|
|
146
|
+
memory_usage_mb=30518,
|
|
147
147
|
max_tokens=8192,
|
|
148
148
|
embed_dim=4096,
|
|
149
149
|
training_datasets={
|
|
@@ -175,7 +175,7 @@ promptriever_mistral_v1 = ModelMeta(
|
|
|
175
175
|
revision="7231864981174d9bee8c7687c24c8344414eae6b-876d63e49b6115ecb6839893a56298fadee7e8f5", # base-peft revision
|
|
176
176
|
release_date="2024-09-15",
|
|
177
177
|
n_parameters=7_000_000_000,
|
|
178
|
-
memory_usage_mb=
|
|
178
|
+
memory_usage_mb=26703,
|
|
179
179
|
training_datasets={
|
|
180
180
|
# "samaya-ai/msmarco-w-instructions",
|
|
181
181
|
"mMARCO-NL", # translation not trained on
|
|
@@ -139,7 +139,7 @@ Qwen3_Embedding_0B6 = ModelMeta(
|
|
|
139
139
|
revision="b22da495047858cce924d27d76261e96be6febc0", # Commit of @tomaarsen
|
|
140
140
|
release_date="2025-06-05",
|
|
141
141
|
n_parameters=595776512,
|
|
142
|
-
memory_usage_mb=
|
|
142
|
+
memory_usage_mb=1136,
|
|
143
143
|
embed_dim=1024,
|
|
144
144
|
max_tokens=32768,
|
|
145
145
|
license="apache-2.0",
|
|
@@ -161,7 +161,7 @@ Qwen3_Embedding_4B = ModelMeta(
|
|
|
161
161
|
revision="636cd9bf47d976946cdbb2b0c3ca0cb2f8eea5ff", # Commit of @tomaarsen
|
|
162
162
|
release_date="2025-06-05",
|
|
163
163
|
n_parameters=4021774336,
|
|
164
|
-
memory_usage_mb=
|
|
164
|
+
memory_usage_mb=7671,
|
|
165
165
|
embed_dim=2560,
|
|
166
166
|
max_tokens=32768,
|
|
167
167
|
license="apache-2.0",
|
|
@@ -183,7 +183,7 @@ Qwen3_Embedding_8B = ModelMeta(
|
|
|
183
183
|
revision="4e423935c619ae4df87b646a3ce949610c66241c", # Commit of @tomaarsen
|
|
184
184
|
release_date="2025-06-05",
|
|
185
185
|
n_parameters=7567295488,
|
|
186
|
-
memory_usage_mb=
|
|
186
|
+
memory_usage_mb=14433,
|
|
187
187
|
embed_dim=4096,
|
|
188
188
|
max_tokens=32768,
|
|
189
189
|
license="apache-2.0",
|
|
@@ -63,7 +63,7 @@ QZhou_Embedding = ModelMeta(
|
|
|
63
63
|
revision="f1e6c03ee3882e7b9fa5cec91217715272e433b8",
|
|
64
64
|
release_date="2025-08-24",
|
|
65
65
|
n_parameters=7_070_619_136,
|
|
66
|
-
memory_usage_mb=
|
|
66
|
+
memory_usage_mb=14436,
|
|
67
67
|
embed_dim=3584,
|
|
68
68
|
license="apache-2.0",
|
|
69
69
|
max_tokens=8192,
|
|
@@ -8,6 +8,10 @@ from torch.utils.data import DataLoader
|
|
|
8
8
|
|
|
9
9
|
from mteb.abstasks.task_metadata import TaskMetadata
|
|
10
10
|
from mteb.models.model_meta import ModelMeta
|
|
11
|
+
from mteb.similarity_functions import (
|
|
12
|
+
select_pairwise_similarity,
|
|
13
|
+
select_similarity,
|
|
14
|
+
)
|
|
11
15
|
from mteb.types._encoder_io import Array, BatchedInput, PromptType
|
|
12
16
|
|
|
13
17
|
|
|
@@ -155,15 +159,9 @@ class RandomEncoderBaseline:
|
|
|
155
159
|
Returns:
|
|
156
160
|
Cosine similarity matrix between the two sets of embeddings
|
|
157
161
|
"""
|
|
158
|
-
|
|
159
|
-
embeddings1
|
|
160
|
-
)
|
|
161
|
-
norm2 = np.linalg.norm(
|
|
162
|
-
embeddings2.reshape(-1, self.embedding_dim), axis=1, keepdims=True
|
|
162
|
+
return select_similarity(
|
|
163
|
+
embeddings1, embeddings2, self.mteb_model_meta.similarity_fn_name
|
|
163
164
|
)
|
|
164
|
-
normalized1 = embeddings1 / (norm1 + 1e-10)
|
|
165
|
-
normalized2 = embeddings2 / (norm2 + 1e-10)
|
|
166
|
-
return np.dot(normalized1, normalized2.T)
|
|
167
165
|
|
|
168
166
|
def similarity_pairwise(
|
|
169
167
|
self,
|
|
@@ -179,17 +177,9 @@ class RandomEncoderBaseline:
|
|
|
179
177
|
Returns:
|
|
180
178
|
Cosine similarity for each pair of embeddings
|
|
181
179
|
"""
|
|
182
|
-
|
|
183
|
-
embeddings1
|
|
184
|
-
)
|
|
185
|
-
norm2 = np.linalg.norm(
|
|
186
|
-
embeddings2.reshape(-1, self.embedding_dim), axis=1, keepdims=True
|
|
180
|
+
return select_pairwise_similarity(
|
|
181
|
+
embeddings1, embeddings2, self.mteb_model_meta.similarity_fn_name
|
|
187
182
|
)
|
|
188
|
-
normalized1 = embeddings1 / (norm1 + 1e-10)
|
|
189
|
-
normalized2 = embeddings2 / (norm2 + 1e-10)
|
|
190
|
-
normalized1 = np.asarray(normalized1)
|
|
191
|
-
normalized2 = np.asarray(normalized2)
|
|
192
|
-
return np.sum(normalized1 * normalized2, axis=1)
|
|
193
183
|
|
|
194
184
|
|
|
195
185
|
random_encoder_baseline = ModelMeta(
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from mteb.models.instruct_wrapper import InstructSentenceTransformerModel
|
|
2
|
+
from mteb.models.model_meta import ModelMeta
|
|
3
|
+
from mteb.models.models_protocols import PromptType
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def instruction_template(
    instruction: str, prompt_type: PromptType | None = None
) -> str:
    """Format an instruction as the prompt prefix placed before a query.

    Args:
        instruction: Instruction text. A dict of instructions is also handled:
            the entry for ``prompt_type`` is used, or the first value when
            ``prompt_type`` is None.
        prompt_type: Prompt type of the input being encoded, if known.

    Returns:
        ``"Instruct: <instruction>\\nQuery:"``, or an empty string when there is
        no instruction or the input is a document.
    """
    # No instruction at all -> no prefix.
    if not instruction:
        return ""
    # Documents are encoded without an instruction prefix.
    if prompt_type == PromptType.document:
        return ""

    if isinstance(instruction, dict):
        instruction = (
            next(iter(instruction.values()))  # TODO
            if prompt_type is None
            else instruction[prompt_type]
        )
    return f"Instruct: {instruction}\nQuery:"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
training_data = {
|
|
20
|
+
"T2Retrieval",
|
|
21
|
+
"DuRetrieval",
|
|
22
|
+
"MMarcoReranking",
|
|
23
|
+
"CMedQAv2-reranking",
|
|
24
|
+
"NQ",
|
|
25
|
+
"MSMARCO",
|
|
26
|
+
"HotpotQA",
|
|
27
|
+
"MrTidyRetrieval",
|
|
28
|
+
"MIRACLRetrieval",
|
|
29
|
+
"CodeSearchNet",
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
yuan_embedding_2_en = ModelMeta(
|
|
34
|
+
loader=InstructSentenceTransformerModel,
|
|
35
|
+
loader_kwargs=dict(
|
|
36
|
+
instruction_template=instruction_template,
|
|
37
|
+
apply_instruction_to_passages=False,
|
|
38
|
+
),
|
|
39
|
+
name="IEITYuan/Yuan-embedding-2.0-en",
|
|
40
|
+
languages=["eng-Latn"],
|
|
41
|
+
open_weights=True,
|
|
42
|
+
revision="b2fd15da3bcae3473c8529593825c15068f09fce",
|
|
43
|
+
release_date="2025-11-27",
|
|
44
|
+
n_parameters=595776512,
|
|
45
|
+
memory_usage_mb=2272,
|
|
46
|
+
embed_dim=1024,
|
|
47
|
+
max_tokens=2048,
|
|
48
|
+
license="apache-2.0",
|
|
49
|
+
reference="https://huggingface.co/IEITYuan/Yuan-embedding-2.0-en",
|
|
50
|
+
similarity_fn_name="cosine",
|
|
51
|
+
framework=["Sentence Transformers", "PyTorch"],
|
|
52
|
+
use_instructions=True,
|
|
53
|
+
public_training_code=None,
|
|
54
|
+
public_training_data=None,
|
|
55
|
+
training_datasets=training_data,
|
|
56
|
+
adapted_from="Qwen/Qwen3-Embedding-0.6B",
|
|
57
|
+
)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Protocol
|
|
3
|
+
|
|
4
|
+
from mteb.types import Array, TopRankedDocumentsType
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class IndexEncoderSearchProtocol(Protocol):
    """Structural interface for search backends used in encoder-based retrieval."""

    def add_documents(self, embeddings: Array, idxs: list[str]) -> None:
        """Add documents to the search backend.

        Args:
            embeddings: Embeddings of the documents to add.
            idxs: IDs of the documents to add.
        """

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search the added corpus embeddings, or rerank top-ranked documents.

        Two modes are supported:
            - Full-corpus mode: `top_ranked=None`, uses added corpus embeddings.
            - Reranking mode: `top_ranked` contains mapping {query_id: [doc_ids]}.

        Args:
            embeddings: Query embeddings, shape (num_queries, dim).
            top_k: Number of top results to return.
            similarity_fn: Function to compute similarity between query and corpus.
            top_ranked: Mapping of query_id -> list of candidate doc_ids.
                Used for reranking.
            query_idx_to_id: Mapping of query index -> query_id. Used for reranking.

        Returns:
            A tuple (top_k_values, top_k_indices), for each query:
                - top_k_values: List of top-k similarity scores.
                - top_k_indices: List of indices of the top-k documents in the
                  added corpus.
        """

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from mteb._requires_package import requires_package
|
|
8
|
+
from mteb.models.model_meta import ScoringFunction
|
|
9
|
+
from mteb.models.models_protocols import EncoderProtocol
|
|
10
|
+
from mteb.types import Array, TopRankedDocumentsType
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FaissSearchIndex:
    """FAISS-based backend for encoder-based search.

    Supports both full-corpus retrieval and reranking (via `top_ranked`).

    Notes:
        - Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2).
        - When the model's similarity function is cosine, document and query
          embeddings are L2-normalized by this class before being indexed or
          searched, so callers do not need to normalize them.
    """

    # Set to True in __init__ when the model uses cosine similarity.
    _normalize: bool = False

    def __init__(self, model: EncoderProtocol) -> None:
        """Select the FAISS index type matching the model's similarity function.

        Args:
            model: Encoder whose `mteb_model_meta.similarity_fn_name` decides
                which flat index is used.

        Raises:
            ValueError: If the similarity function is not dot product, cosine
                or euclidean.
        """
        requires_package(
            self,
            "faiss",
            "FAISS-based search",
            install_instruction="pip install mteb[faiss-cpu]",
        )

        import faiss
        from faiss import IndexFlatIP, IndexFlatL2

        similarity_fn_name = model.mteb_model_meta.similarity_fn_name

        # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
        if similarity_fn_name is ScoringFunction.DOT_PRODUCT:
            self.index_type = IndexFlatIP
        elif similarity_fn_name is ScoringFunction.COSINE:
            # Cosine similarity == inner product of L2-normalized vectors.
            self.index_type = IndexFlatIP
            self._normalize = True
        elif similarity_fn_name is ScoringFunction.EUCLIDEAN:
            self.index_type = IndexFlatL2
        else:
            raise ValueError(
                f"FAISS backend does not support similarity function {similarity_fn_name}. "
                f"Available: {ScoringFunction.DOT_PRODUCT}, {ScoringFunction.COSINE}, "
                f"{ScoringFunction.EUCLIDEAN}."
            )

        self.idxs: list[str] = []
        self.index: faiss.Index | None = None

    def add_documents(self, embeddings: Array, idxs: list[str]) -> None:
        """Add all document embeddings and their IDs to FAISS index.

        Args:
            embeddings: Document embeddings; indexed as a 2-D array whose second
                axis is the embedding dimension.
            idxs: IDs of the documents being added.
        """
        import faiss

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        # Private float32, C-contiguous copy: FAISS works on float32 and
        # normalize_L2 operates in place, so the caller's data stays untouched.
        embeddings = np.array(embeddings, dtype=np.float32, order="C")

        if self._normalize:
            faiss.normalize_L2(embeddings)

        dim = embeddings.shape[1]
        if self.index is None:
            self.index = self.index_type(dim)

        self.index.add(embeddings)
        # Record IDs only once the vectors are in the index, so the ID list and
        # the index cannot get out of sync if `add` raises.
        self.idxs.extend(idxs)
        logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.")

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search using FAISS.

        Args:
            embeddings: Query embeddings, one row per query.
            top_k: Number of results to return per query.
            similarity_fn: Unused here; scoring is done by the FAISS index
                chosen in `__init__`.
            top_ranked: Mapping of query_id -> candidate doc_ids. When given,
                each query is scored only against its own candidates.
            query_idx_to_id: Mapping of query index -> query_id. Required when
                `top_ranked` is given.

        Returns:
            A tuple (scores, indices) of per-query lists. For the euclidean
            index, scores are negated distances so that larger is better. In
            reranking mode the indices are positions within each query's
            candidate list.

        Raises:
            ValueError: If no documents have been added, or if `top_ranked` is
                given without `query_idx_to_id`.
        """
        import faiss

        if self.index is None:
            raise ValueError("No index built. Call add_documents() first.")

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        # Cast *before* normalizing: normalize_L2 works in place on float32
        # arrays. Copying also keeps the caller's array/tensor unmodified.
        embeddings = np.array(embeddings, dtype=np.float32, order="C")

        if self._normalize:
            faiss.normalize_L2(embeddings)

        if top_ranked is not None:
            if query_idx_to_id is None:
                raise ValueError("query_idx_to_id must be provided when reranking.")

            similarities, ids = self._reranking(
                embeddings,
                top_k,
                top_ranked=top_ranked,
                query_idx_to_id=query_idx_to_id,
            )
        else:
            similarities, ids = self.index.search(embeddings, top_k)
            similarities = similarities.tolist()
            ids = ids.tolist()

        if issubclass(self.index_type, faiss.IndexFlatL2):
            # IndexFlatL2 reports squared distances; convert to negated
            # distances. Done row by row because reranking can yield rows of
            # different lengths, and the result must stay a list of lists.
            similarities = [
                (-np.sqrt(np.maximum(row, 0))).tolist() for row in similarities
            ]

        return similarities, ids

    def _reranking(
        self,
        embeddings: Array,
        top_k: int,
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Score each query against its own candidate documents only.

        Args:
            embeddings: Query embeddings, one row per query.
            top_k: Maximum number of candidates to return per query.
            top_ranked: Mapping of query_id -> candidate doc_ids.
            query_idx_to_id: Mapping of query index -> query_id.

        Returns:
            A tuple (scores, indices) of per-query lists. Indices are positions
            within the query's candidate list; both lists are empty for a query
            without candidates.
        """
        doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}
        scores_all: list[list[float]] = []
        idxs_all: list[list[int]] = []
        d = self.index.d  # embedding dimension, identical for every query

        for query_idx, query_emb in enumerate(embeddings):
            query_id = query_idx_to_id[query_idx]
            ranked_ids = top_ranked.get(query_id)
            if not ranked_ids:
                logger.warning(f"No top-ranked documents for query {query_id}")
                scores_all.append([])
                idxs_all.append([])
                continue

            # Rebuild a small temporary index holding just this query's
            # candidates, using the vectors already stored in the main index.
            candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
            candidate_embs = np.vstack(
                [self.index.reconstruct(idx) for idx in candidate_indices]
            )
            sub_reranking_index = self.index_type(d)
            sub_reranking_index.add(candidate_embs)

            # Search returns scores and indices in one call
            scores, local_indices = sub_reranking_index.search(
                np.ascontiguousarray(query_emb.reshape(1, -1), dtype=np.float32),
                min(top_k, len(candidate_indices)),
            )
            # faiss will output 2d arrays even for single query
            scores_all.append(scores[0].tolist())
            idxs_all.append(local_indices[0].tolist())

        return scores_all, idxs_all

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
        self.index = None
        self.idxs = []
|