mteb 2.5.2__py3-none-any.whl → 2.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/__init__.py +2 -0
- mteb/_create_dataloaders.py +17 -18
- mteb/_evaluators/any_sts_evaluator.py +3 -3
- mteb/_evaluators/clustering_evaluator.py +2 -2
- mteb/_evaluators/evaluator.py +4 -2
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +10 -8
- mteb/_evaluators/pair_classification_evaluator.py +5 -3
- mteb/_evaluators/retrieval_evaluator.py +2 -2
- mteb/_evaluators/retrieval_metrics.py +18 -17
- mteb/_evaluators/sklearn_evaluator.py +11 -10
- mteb/_evaluators/text/bitext_mining_evaluator.py +27 -18
- mteb/_evaluators/text/summarization_evaluator.py +23 -18
- mteb/_evaluators/zeroshot_classification_evaluator.py +5 -3
- mteb/abstasks/_data_filter/filters.py +1 -1
- mteb/abstasks/_data_filter/task_pipelines.py +3 -0
- mteb/abstasks/_statistics_calculation.py +18 -10
- mteb/abstasks/_stratification.py +18 -18
- mteb/abstasks/abstask.py +35 -28
- mteb/abstasks/aggregate_task_metadata.py +1 -9
- mteb/abstasks/aggregated_task.py +10 -29
- mteb/abstasks/classification.py +15 -10
- mteb/abstasks/clustering.py +19 -15
- mteb/abstasks/clustering_legacy.py +10 -10
- mteb/abstasks/image/image_text_pair_classification.py +7 -4
- mteb/abstasks/multilabel_classification.py +23 -19
- mteb/abstasks/pair_classification.py +20 -11
- mteb/abstasks/regression.py +4 -4
- mteb/abstasks/retrieval.py +28 -24
- mteb/abstasks/retrieval_dataset_loaders.py +2 -2
- mteb/abstasks/sts.py +8 -5
- mteb/abstasks/task_metadata.py +31 -33
- mteb/abstasks/text/bitext_mining.py +39 -28
- mteb/abstasks/text/reranking.py +8 -6
- mteb/abstasks/text/summarization.py +10 -5
- mteb/abstasks/zeroshot_classification.py +8 -4
- mteb/benchmarks/benchmark.py +4 -2
- mteb/benchmarks/benchmarks/__init__.py +4 -0
- mteb/benchmarks/benchmarks/benchmarks.py +112 -11
- mteb/benchmarks/get_benchmark.py +14 -55
- mteb/cache.py +182 -29
- mteb/cli/_display_tasks.py +2 -2
- mteb/cli/build_cli.py +110 -14
- mteb/cli/generate_model_card.py +43 -23
- mteb/deprecated_evaluator.py +63 -49
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
- mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
- mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
- mteb/evaluate.py +44 -33
- mteb/filter_tasks.py +25 -26
- mteb/get_tasks.py +29 -30
- mteb/languages/language_scripts.py +5 -3
- mteb/leaderboard/app.py +162 -34
- mteb/load_results.py +12 -12
- mteb/models/abs_encoder.py +10 -6
- mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
- mteb/models/cache_wrappers/cache_wrapper.py +2 -2
- mteb/models/get_model_meta.py +21 -3
- mteb/models/instruct_wrapper.py +28 -8
- mteb/models/model_implementations/align_models.py +1 -1
- mteb/models/model_implementations/andersborges.py +4 -4
- mteb/models/model_implementations/ara_models.py +1 -1
- mteb/models/model_implementations/arctic_models.py +8 -8
- mteb/models/model_implementations/b1ade_models.py +1 -1
- mteb/models/model_implementations/bge_models.py +45 -21
- mteb/models/model_implementations/bica_model.py +3 -3
- mteb/models/model_implementations/blip2_models.py +2 -2
- mteb/models/model_implementations/blip_models.py +16 -16
- mteb/models/model_implementations/bm25.py +4 -4
- mteb/models/model_implementations/bmretriever_models.py +6 -4
- mteb/models/model_implementations/cadet_models.py +1 -1
- mteb/models/model_implementations/cde_models.py +11 -4
- mteb/models/model_implementations/clip_models.py +6 -6
- mteb/models/model_implementations/clips_models.py +3 -3
- mteb/models/model_implementations/codefuse_models.py +5 -5
- mteb/models/model_implementations/codesage_models.py +3 -3
- mteb/models/model_implementations/cohere_models.py +5 -5
- mteb/models/model_implementations/cohere_v.py +2 -2
- mteb/models/model_implementations/colpali_models.py +3 -3
- mteb/models/model_implementations/colqwen_models.py +8 -8
- mteb/models/model_implementations/colsmol_models.py +2 -2
- mteb/models/model_implementations/conan_models.py +1 -1
- mteb/models/model_implementations/dino_models.py +42 -42
- mteb/models/model_implementations/e5_instruct.py +23 -4
- mteb/models/model_implementations/e5_models.py +9 -9
- mteb/models/model_implementations/e5_v.py +6 -6
- mteb/models/model_implementations/eagerworks_models.py +1 -1
- mteb/models/model_implementations/emillykkejensen_models.py +6 -6
- mteb/models/model_implementations/en_code_retriever.py +1 -1
- mteb/models/model_implementations/euler_models.py +2 -2
- mteb/models/model_implementations/fa_models.py +9 -9
- mteb/models/model_implementations/facebookai.py +14 -2
- mteb/models/model_implementations/geogpt_models.py +1 -1
- mteb/models/model_implementations/gme_v_models.py +6 -5
- mteb/models/model_implementations/google_models.py +1 -1
- mteb/models/model_implementations/granite_vision_embedding_models.py +1 -1
- mteb/models/model_implementations/gritlm_models.py +2 -2
- mteb/models/model_implementations/gte_models.py +25 -13
- mteb/models/model_implementations/hinvec_models.py +1 -1
- mteb/models/model_implementations/ibm_granite_models.py +30 -6
- mteb/models/model_implementations/inf_models.py +2 -2
- mteb/models/model_implementations/jasper_models.py +2 -2
- mteb/models/model_implementations/jina_clip.py +48 -10
- mteb/models/model_implementations/jina_models.py +18 -11
- mteb/models/model_implementations/kblab.py +12 -6
- mteb/models/model_implementations/kennethenevoldsen_models.py +4 -4
- mteb/models/model_implementations/kfst.py +1 -1
- mteb/models/model_implementations/kowshik24_models.py +1 -1
- mteb/models/model_implementations/lgai_embedding_models.py +1 -1
- mteb/models/model_implementations/linq_models.py +1 -1
- mteb/models/model_implementations/listconranker.py +1 -1
- mteb/models/model_implementations/llm2clip_models.py +6 -6
- mteb/models/model_implementations/llm2vec_models.py +8 -8
- mteb/models/model_implementations/mcinext_models.py +4 -1
- mteb/models/model_implementations/mdbr_models.py +17 -3
- mteb/models/model_implementations/misc_models.py +68 -68
- mteb/models/model_implementations/mixedbread_ai_models.py +332 -0
- mteb/models/model_implementations/mme5_models.py +1 -1
- mteb/models/model_implementations/moco_models.py +4 -4
- mteb/models/model_implementations/mod_models.py +1 -1
- mteb/models/model_implementations/model2vec_models.py +14 -14
- mteb/models/model_implementations/moka_models.py +1 -1
- mteb/models/model_implementations/nbailab.py +3 -3
- mteb/models/model_implementations/no_instruct_sentence_models.py +2 -2
- mteb/models/model_implementations/nomic_models.py +30 -15
- mteb/models/model_implementations/nomic_models_vision.py +1 -1
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +15 -9
- mteb/models/model_implementations/nvidia_models.py +151 -19
- mteb/models/model_implementations/octen_models.py +61 -2
- mteb/models/model_implementations/openclip_models.py +13 -13
- mteb/models/model_implementations/opensearch_neural_sparse_models.py +5 -5
- mteb/models/model_implementations/ops_moa_models.py +1 -1
- mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -1
- mteb/models/model_implementations/pawan_models.py +1 -1
- mteb/models/model_implementations/piccolo_models.py +1 -1
- mteb/models/model_implementations/pixie_models.py +56 -0
- mteb/models/model_implementations/promptriever_models.py +4 -4
- mteb/models/model_implementations/pylate_models.py +10 -9
- mteb/models/model_implementations/qodo_models.py +2 -2
- mteb/models/model_implementations/qtack_models.py +1 -1
- mteb/models/model_implementations/qwen3_models.py +3 -3
- mteb/models/model_implementations/qzhou_models.py +2 -2
- mteb/models/model_implementations/random_baseline.py +3 -3
- mteb/models/model_implementations/rasgaard_models.py +2 -2
- mteb/models/model_implementations/reasonir_model.py +1 -1
- mteb/models/model_implementations/repllama_models.py +3 -3
- mteb/models/model_implementations/rerankers_custom.py +12 -6
- mteb/models/model_implementations/rerankers_monot5_based.py +17 -17
- mteb/models/model_implementations/richinfoai_models.py +1 -1
- mteb/models/model_implementations/ru_sentence_models.py +20 -20
- mteb/models/model_implementations/ruri_models.py +10 -10
- mteb/models/model_implementations/salesforce_models.py +3 -3
- mteb/models/model_implementations/samilpwc_models.py +1 -1
- mteb/models/model_implementations/sarashina_embedding_models.py +2 -2
- mteb/models/model_implementations/searchmap_models.py +1 -1
- mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +113 -146
- mteb/models/model_implementations/sentence_transformers_models.py +124 -22
- mteb/models/model_implementations/shuu_model.py +1 -1
- mteb/models/model_implementations/siglip_models.py +20 -20
- mteb/models/model_implementations/slm_models.py +416 -0
- mteb/models/model_implementations/spartan8806_atles_champion.py +1 -1
- mteb/models/model_implementations/stella_models.py +17 -4
- mteb/models/model_implementations/tarka_models.py +2 -2
- mteb/models/model_implementations/text2vec_models.py +9 -3
- mteb/models/model_implementations/ua_sentence_models.py +1 -1
- mteb/models/model_implementations/uae_models.py +7 -1
- mteb/models/model_implementations/vdr_models.py +1 -1
- mteb/models/model_implementations/vi_vn_models.py +6 -6
- mteb/models/model_implementations/vlm2vec_models.py +3 -3
- mteb/models/model_implementations/voyage_models.py +84 -0
- mteb/models/model_implementations/voyage_v.py +9 -7
- mteb/models/model_implementations/youtu_models.py +1 -1
- mteb/models/model_implementations/yuan_models.py +1 -1
- mteb/models/model_implementations/yuan_models_en.py +1 -1
- mteb/models/model_meta.py +80 -31
- mteb/models/models_protocols.py +22 -6
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
- mteb/models/search_wrappers.py +33 -18
- mteb/models/sentence_transformer_wrapper.py +50 -25
- mteb/models/vllm_wrapper.py +327 -0
- mteb/py.typed +0 -0
- mteb/results/benchmark_results.py +29 -21
- mteb/results/model_result.py +52 -22
- mteb/results/task_result.py +80 -58
- mteb/similarity_functions.py +11 -7
- mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
- mteb/tasks/classification/est/estonian_valence.py +1 -1
- mteb/tasks/classification/kur/kurdish_sentiment_classification.py +2 -2
- mteb/tasks/classification/multilingual/scala_classification.py +1 -1
- mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/zho/cmteb_clustering.py +2 -2
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +12 -12
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
- mteb/tasks/retrieval/eng/__init__.py +2 -0
- mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
- mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
- mteb/tasks/retrieval/kor/__init__.py +15 -1
- mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
- mteb/tasks/retrieval/multilingual/__init__.py +2 -0
- mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
- mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +90 -100
- mteb/tasks/retrieval/nob/norquad.py +2 -2
- mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
- mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
- mteb/tasks/retrieval/vie/__init__.py +14 -6
- mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/fevervn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +48 -0
- mteb/tasks/retrieval/vie/nqvn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
- mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
- mteb/types/__init__.py +2 -0
- mteb/types/_encoder_io.py +12 -0
- mteb/types/_result.py +2 -1
- mteb/types/statistics.py +9 -3
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/METADATA +15 -4
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/RECORD +240 -219
- mteb/models/model_implementations/mxbai_models.py +0 -111
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/WHEEL +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/entry_points.txt +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/top_level.txt +0 -0
mteb/models/search_wrappers.py
CHANGED
|
@@ -14,6 +14,7 @@ from mteb.types import (
|
|
|
14
14
|
Array,
|
|
15
15
|
BatchedInput,
|
|
16
16
|
CorpusDatasetType,
|
|
17
|
+
EncodeKwargs,
|
|
17
18
|
PromptType,
|
|
18
19
|
QueryDatasetType,
|
|
19
20
|
RetrievalOutputType,
|
|
@@ -50,7 +51,7 @@ class SearchEncoderWrapper:
|
|
|
50
51
|
task_metadata: TaskMetadata,
|
|
51
52
|
hf_split: str,
|
|
52
53
|
hf_subset: str,
|
|
53
|
-
encode_kwargs:
|
|
54
|
+
encode_kwargs: EncodeKwargs,
|
|
54
55
|
) -> None:
|
|
55
56
|
"""Index the corpus for retrieval.
|
|
56
57
|
|
|
@@ -88,7 +89,7 @@ class SearchEncoderWrapper:
|
|
|
88
89
|
hf_split: str,
|
|
89
90
|
hf_subset: str,
|
|
90
91
|
top_k: int,
|
|
91
|
-
encode_kwargs:
|
|
92
|
+
encode_kwargs: EncodeKwargs,
|
|
92
93
|
top_ranked: TopRankedDocumentsType | None = None,
|
|
93
94
|
) -> RetrievalOutputType:
|
|
94
95
|
"""Search the corpus for the given queries.
|
|
@@ -200,7 +201,7 @@ class SearchEncoderWrapper:
|
|
|
200
201
|
# Reset the task corpus dataloader to None to free up memory
|
|
201
202
|
self.task_corpus = None
|
|
202
203
|
|
|
203
|
-
results = {qid: {} for qid in query_idx_to_id.values()}
|
|
204
|
+
results: RetrievalOutputType = {qid: {} for qid in query_idx_to_id.values()}
|
|
204
205
|
for qid in result_heaps:
|
|
205
206
|
for score, corpus_id in result_heaps[qid]:
|
|
206
207
|
results[qid][corpus_id] = score
|
|
@@ -215,16 +216,22 @@ class SearchEncoderWrapper:
|
|
|
215
216
|
hf_subset: str,
|
|
216
217
|
hf_split: str,
|
|
217
218
|
top_k: int,
|
|
218
|
-
encode_kwargs:
|
|
219
|
+
encode_kwargs: EncodeKwargs,
|
|
219
220
|
) -> dict[str, list[tuple[float, str]]]:
|
|
220
221
|
logger.info("Encoding Corpus in batches (this might take a while)...")
|
|
222
|
+
if self.task_corpus is None:
|
|
223
|
+
raise ValueError("Corpus must be indexed before searching.")
|
|
224
|
+
|
|
221
225
|
itr = range(0, len(self.task_corpus), self.corpus_chunk_size)
|
|
222
226
|
|
|
223
|
-
result_heaps
|
|
227
|
+
result_heaps: dict[str, list[tuple[float, str]]] = {
|
|
228
|
+
qid: [] for qid in query_idx_to_id.values()
|
|
229
|
+
}
|
|
224
230
|
for batch_num, corpus_start_idx in enumerate(itr):
|
|
225
231
|
logger.info(f"Encoding Batch {batch_num + 1}/{len(itr)}...")
|
|
226
232
|
corpus_end_idx = min(
|
|
227
|
-
corpus_start_idx + self.corpus_chunk_size,
|
|
233
|
+
corpus_start_idx + self.corpus_chunk_size,
|
|
234
|
+
len(self.task_corpus),
|
|
228
235
|
)
|
|
229
236
|
sub_corpus = self.task_corpus.select(
|
|
230
237
|
range(corpus_start_idx, corpus_end_idx)
|
|
@@ -249,7 +256,7 @@ class SearchEncoderWrapper:
|
|
|
249
256
|
scores = self.model.similarity(query_embeddings, sub_corpus_embeddings)
|
|
250
257
|
|
|
251
258
|
# get top-k values
|
|
252
|
-
|
|
259
|
+
cos_scores_top_k_values_tensor, cos_scores_top_k_idx_tensor = torch.topk(
|
|
253
260
|
torch.as_tensor(scores),
|
|
254
261
|
min(
|
|
255
262
|
top_k + 1,
|
|
@@ -258,8 +265,8 @@ class SearchEncoderWrapper:
|
|
|
258
265
|
dim=1,
|
|
259
266
|
largest=True,
|
|
260
267
|
)
|
|
261
|
-
cos_scores_top_k_idx =
|
|
262
|
-
cos_scores_top_k_values =
|
|
268
|
+
cos_scores_top_k_idx = cos_scores_top_k_idx_tensor.cpu().tolist()
|
|
269
|
+
cos_scores_top_k_values = cos_scores_top_k_values_tensor.cpu().tolist()
|
|
263
270
|
|
|
264
271
|
sub_corpus_ids = list(sub_corpus_ids)
|
|
265
272
|
result_heaps = self._sort_full_corpus_results(
|
|
@@ -312,14 +319,18 @@ class SearchEncoderWrapper:
|
|
|
312
319
|
task_metadata: TaskMetadata,
|
|
313
320
|
hf_subset: str,
|
|
314
321
|
hf_split: str,
|
|
315
|
-
encode_kwargs:
|
|
322
|
+
encode_kwargs: EncodeKwargs,
|
|
316
323
|
) -> dict[str, list[tuple[float, str]]]:
|
|
317
324
|
"""Rerank documents based on pre-ranked documents.
|
|
318
325
|
|
|
319
326
|
Returns:
|
|
320
327
|
A dictionary mapping query IDs to a list of tuples, each containing a relevance score and a document ID.
|
|
321
328
|
"""
|
|
322
|
-
|
|
329
|
+
if self.task_corpus is None:
|
|
330
|
+
raise ValueError("Corpus must be indexed before searching.")
|
|
331
|
+
result_heaps: dict[str, list[tuple[float, str]]] = {
|
|
332
|
+
qid: [] for qid in query_idx_to_id.values()
|
|
333
|
+
}
|
|
323
334
|
doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
|
|
324
335
|
|
|
325
336
|
all_doc_embeddings = self.model.encode(
|
|
@@ -340,7 +351,8 @@ class SearchEncoderWrapper:
|
|
|
340
351
|
for query_idx, query_embedding in enumerate(query_embeddings):
|
|
341
352
|
query_id = query_idx_to_id[query_idx]
|
|
342
353
|
if query_id not in top_ranked:
|
|
343
|
-
|
|
354
|
+
msg = f"No pre-ranked documents found for query {query_id}"
|
|
355
|
+
logger.warning(msg)
|
|
344
356
|
continue
|
|
345
357
|
|
|
346
358
|
ranked_ids = top_ranked[query_id]
|
|
@@ -386,12 +398,12 @@ class SearchEncoderWrapper:
|
|
|
386
398
|
|
|
387
399
|
def _rerank_sort_results(
|
|
388
400
|
self,
|
|
389
|
-
result_heaps: list[tuple[float, str]],
|
|
401
|
+
result_heaps: dict[str, list[tuple[float, str]]],
|
|
390
402
|
query_id: str,
|
|
391
403
|
ranked_ids: list[str],
|
|
392
404
|
scores_top_k_idx: torch.Tensor,
|
|
393
405
|
scores_top_k_values: torch.Tensor,
|
|
394
|
-
) -> list[tuple[float, str]]:
|
|
406
|
+
) -> dict[str, list[tuple[float, str]]]:
|
|
395
407
|
"""Sort the heap into descending order list.
|
|
396
408
|
|
|
397
409
|
Returns:
|
|
@@ -459,7 +471,7 @@ class SearchCrossEncoderWrapper:
|
|
|
459
471
|
task_metadata: TaskMetadata,
|
|
460
472
|
hf_split: str,
|
|
461
473
|
hf_subset: str,
|
|
462
|
-
encode_kwargs:
|
|
474
|
+
encode_kwargs: EncodeKwargs,
|
|
463
475
|
) -> None:
|
|
464
476
|
"""Index the corpus for retrieval.
|
|
465
477
|
|
|
@@ -480,7 +492,7 @@ class SearchCrossEncoderWrapper:
|
|
|
480
492
|
hf_split: str,
|
|
481
493
|
hf_subset: str,
|
|
482
494
|
top_k: int,
|
|
483
|
-
encode_kwargs:
|
|
495
|
+
encode_kwargs: EncodeKwargs,
|
|
484
496
|
top_ranked: TopRankedDocumentsType | None = None,
|
|
485
497
|
) -> RetrievalOutputType:
|
|
486
498
|
"""Search the corpus using the given queries.
|
|
@@ -502,6 +514,8 @@ class SearchCrossEncoderWrapper:
|
|
|
502
514
|
raise ValueError(
|
|
503
515
|
"CrossEncoder search requires top_ranked documents for reranking."
|
|
504
516
|
)
|
|
517
|
+
if self.task_corpus is None:
|
|
518
|
+
raise ValueError("Corpus must be indexed before searching.")
|
|
505
519
|
|
|
506
520
|
query_id_to_idx = {row["id"]: i for i, row in enumerate(queries)}
|
|
507
521
|
doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
|
|
@@ -511,7 +525,8 @@ class SearchCrossEncoderWrapper:
|
|
|
511
525
|
doc_pairs_ids: list[tuple[str, str]] = []
|
|
512
526
|
for query_id, corpus_ids in top_ranked.items():
|
|
513
527
|
if query_id not in top_ranked:
|
|
514
|
-
|
|
528
|
+
msg = f"No pre-ranked documents found for query {query_id}"
|
|
529
|
+
logger.warning(msg)
|
|
515
530
|
continue
|
|
516
531
|
|
|
517
532
|
query_idx = query_id_to_idx[query_id]
|
|
@@ -540,7 +555,7 @@ class SearchCrossEncoderWrapper:
|
|
|
540
555
|
hf_subset=hf_subset,
|
|
541
556
|
)
|
|
542
557
|
|
|
543
|
-
results = {qid: {} for qid in queries["id"]}
|
|
558
|
+
results: RetrievalOutputType = {qid: {} for qid in queries["id"]}
|
|
544
559
|
for (query_id, corpus_id), score in zip(doc_pairs_ids, predictions):
|
|
545
560
|
results[query_id][corpus_id] = float(score)
|
|
546
561
|
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
+
import warnings
|
|
4
5
|
from typing import TYPE_CHECKING, Any
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import torch
|
|
8
9
|
from packaging.version import Version
|
|
9
10
|
from torch.utils.data import DataLoader
|
|
11
|
+
from typing_extensions import Unpack
|
|
10
12
|
|
|
11
13
|
from mteb._log_once import LogOnce
|
|
12
14
|
from mteb.models import ModelMeta
|
|
13
|
-
from mteb.types import Array, BatchedInput, PromptType
|
|
15
|
+
from mteb.types import Array, BatchedInput, EncodeKwargs, PromptType
|
|
14
16
|
|
|
15
17
|
from .abs_encoder import AbsEncoder
|
|
16
18
|
|
|
@@ -25,17 +27,18 @@ SENTENCE_TRANSFORMERS_QUERY_ENCODE_VERSION = "5.0.0"
|
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
def sentence_transformers_loader(
|
|
28
|
-
model_name: str, revision: str | None = None, **kwargs
|
|
30
|
+
model_name: str, revision: str | None = None, device: str | None = None, **kwargs
|
|
29
31
|
) -> SentenceTransformerEncoderWrapper:
|
|
30
32
|
"""Loads a SentenceTransformer model and wraps it in a SentenceTransformerEncoderWrapper.
|
|
31
33
|
|
|
32
34
|
Args:
|
|
33
35
|
model_name: The name of the SentenceTransformer model to load.
|
|
34
36
|
revision: The revision of the model to load.
|
|
37
|
+
device: The device used to load the model.
|
|
35
38
|
kwargs: Additional arguments to pass to the SentenceTransformer model.
|
|
36
39
|
"""
|
|
37
40
|
return SentenceTransformerEncoderWrapper(
|
|
38
|
-
model=model_name, revision=revision, **kwargs
|
|
41
|
+
model=model_name, revision=revision, device=device, **kwargs
|
|
39
42
|
)
|
|
40
43
|
|
|
41
44
|
|
|
@@ -48,6 +51,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
48
51
|
self,
|
|
49
52
|
model: str | SentenceTransformer,
|
|
50
53
|
revision: str | None = None,
|
|
54
|
+
device: str | None = None,
|
|
51
55
|
model_prompts: dict[str, str] | None = None,
|
|
52
56
|
**kwargs,
|
|
53
57
|
) -> None:
|
|
@@ -56,6 +60,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
56
60
|
Args:
|
|
57
61
|
model: The SentenceTransformer model to use. Can be a string (model name), a SentenceTransformer model, or a CrossEncoder model.
|
|
58
62
|
revision: The revision of the model to use.
|
|
63
|
+
device: The device used to load the model.
|
|
59
64
|
model_prompts: A dictionary mapping task names to prompt names.
|
|
60
65
|
First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt,
|
|
61
66
|
then to the composed prompt of task type + prompt type, then to the specific task type prompt,
|
|
@@ -65,7 +70,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
65
70
|
from sentence_transformers import SentenceTransformer
|
|
66
71
|
|
|
67
72
|
if isinstance(model, str):
|
|
68
|
-
self.model = SentenceTransformer(
|
|
73
|
+
self.model = SentenceTransformer(
|
|
74
|
+
model, revision=revision, device=device, **kwargs
|
|
75
|
+
)
|
|
69
76
|
else:
|
|
70
77
|
self.model = model
|
|
71
78
|
|
|
@@ -75,9 +82,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
75
82
|
if built_in_prompts and not model_prompts:
|
|
76
83
|
model_prompts = built_in_prompts
|
|
77
84
|
elif model_prompts and built_in_prompts:
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
)
|
|
85
|
+
msg = f"Model prompts specified, these will overwrite the default model prompts. Current prompts will be:\n {model_prompts}"
|
|
86
|
+
logger.warning(msg)
|
|
87
|
+
warnings.warn(msg)
|
|
81
88
|
self.model.prompts = model_prompts
|
|
82
89
|
|
|
83
90
|
self.model_prompts, invalid_prompts = self.validate_task_to_prompt_name(
|
|
@@ -86,9 +93,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
86
93
|
|
|
87
94
|
if invalid_prompts:
|
|
88
95
|
invalid_prompts = "\n".join(invalid_prompts)
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
)
|
|
96
|
+
msg = f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
|
|
97
|
+
logger.warning(msg)
|
|
98
|
+
warnings.warn(msg)
|
|
92
99
|
|
|
93
100
|
if (
|
|
94
101
|
self.model_prompts
|
|
@@ -98,13 +105,15 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
98
105
|
or PromptType.document.value not in self.model_prompts
|
|
99
106
|
)
|
|
100
107
|
):
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
)
|
|
108
|
+
msg = f"SentenceTransformers that use prompts most often need to be configured with at least 'query' and 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
|
|
109
|
+
logger.warning(msg)
|
|
110
|
+
warnings.warn(msg)
|
|
105
111
|
|
|
112
|
+
def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:
|
|
113
|
+
"""Compute the similarity between two collections of embeddings."""
|
|
106
114
|
if hasattr(self.model, "similarity") and callable(self.model.similarity):
|
|
107
|
-
|
|
115
|
+
return self.model.similarity(embeddings1, embeddings2)
|
|
116
|
+
return super().similarity(embeddings1, embeddings2)
|
|
108
117
|
|
|
109
118
|
def encode(
|
|
110
119
|
self,
|
|
@@ -114,7 +123,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
114
123
|
hf_split: str,
|
|
115
124
|
hf_subset: str,
|
|
116
125
|
prompt_type: PromptType | None = None,
|
|
117
|
-
**kwargs:
|
|
126
|
+
**kwargs: Unpack[EncodeKwargs],
|
|
118
127
|
) -> Array:
|
|
119
128
|
"""Encodes the given sentences using the encoder.
|
|
120
129
|
|
|
@@ -150,7 +159,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
150
159
|
prompt_name = None
|
|
151
160
|
if self.model_prompts is not None:
|
|
152
161
|
prompt_name = self.get_prompt_name(task_metadata, prompt_type)
|
|
153
|
-
prompt = self.model_prompts.get(prompt_name, None)
|
|
162
|
+
prompt = self.model_prompts.get(prompt_name, None) # type: ignore[arg-type]
|
|
154
163
|
if prompt_name:
|
|
155
164
|
prompt_log = f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
|
|
156
165
|
else:
|
|
@@ -193,7 +202,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
|
|
|
193
202
|
hf_split: str,
|
|
194
203
|
hf_subset: str,
|
|
195
204
|
prompt_type: PromptType | None = None,
|
|
196
|
-
**kwargs:
|
|
205
|
+
**kwargs: Unpack[EncodeKwargs],
|
|
197
206
|
) -> Array:
|
|
198
207
|
"""Encodes the given sentences using the encoder.
|
|
199
208
|
|
|
@@ -221,7 +230,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
|
|
|
221
230
|
prompt_name = None
|
|
222
231
|
if self.model_prompts is not None:
|
|
223
232
|
prompt_name = self.get_prompt_name(task_metadata, prompt_type)
|
|
224
|
-
prompt = self.model_prompts.get(prompt_name, None)
|
|
233
|
+
prompt = self.model_prompts.get(prompt_name, None) # type: ignore[arg-type]
|
|
225
234
|
if prompt_name:
|
|
226
235
|
logger.info(
|
|
227
236
|
f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
|
|
@@ -234,7 +243,9 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
|
|
|
234
243
|
all_embeddings = []
|
|
235
244
|
for batch in inputs:
|
|
236
245
|
batch_column = next(iter(batch.keys()))
|
|
237
|
-
batched_input
|
|
246
|
+
batched_input: list[dict[str, Any]] = [
|
|
247
|
+
dict() for _ in range(len(batch[batch_column]))
|
|
248
|
+
]
|
|
238
249
|
|
|
239
250
|
# transform from {"text": [text1, text2], "image": [image1, image2]} to
|
|
240
251
|
# [{"text": text1, "image": image1}, {"text": text2, "image": image2}]
|
|
@@ -255,12 +266,24 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
|
|
|
255
266
|
|
|
256
267
|
|
|
257
268
|
class CrossEncoderWrapper:
|
|
258
|
-
"""Wrapper for CrossEncoder models.
|
|
269
|
+
"""Wrapper for CrossEncoder models.
|
|
270
|
+
|
|
271
|
+
Args:
|
|
272
|
+
model: The CrossEncoder model to use. Can be a string (model name) or a CrossEncoder model.
|
|
273
|
+
revision: The revision of the model to use.
|
|
274
|
+
device: The device used to load the model.
|
|
275
|
+
query_prefix: A prefix to add to all queries.
|
|
276
|
+
passage_prefix: A prefix to add to all passages.
|
|
277
|
+
**kwargs: Additional arguments to pass to the CrossEncoder model.
|
|
278
|
+
"""
|
|
259
279
|
|
|
260
280
|
def __init__(
|
|
261
281
|
self,
|
|
262
282
|
model: CrossEncoder | str,
|
|
263
283
|
revision: str | None = None,
|
|
284
|
+
device: str | None = None,
|
|
285
|
+
query_prefix: str = "",
|
|
286
|
+
passage_prefix: str = "",
|
|
264
287
|
**kwargs,
|
|
265
288
|
) -> None:
|
|
266
289
|
from sentence_transformers import CrossEncoder
|
|
@@ -268,9 +291,11 @@ class CrossEncoderWrapper:
|
|
|
268
291
|
if isinstance(model, CrossEncoder):
|
|
269
292
|
self.model = model
|
|
270
293
|
elif isinstance(model, str):
|
|
271
|
-
self.model = CrossEncoder(model, revision=revision, **kwargs)
|
|
294
|
+
self.model = CrossEncoder(model, revision=revision, device=device, **kwargs)
|
|
272
295
|
|
|
273
296
|
self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model)
|
|
297
|
+
self.query_prefix = query_prefix
|
|
298
|
+
self.passage_prefix = passage_prefix
|
|
274
299
|
|
|
275
300
|
def predict(
|
|
276
301
|
self,
|
|
@@ -281,7 +306,7 @@ class CrossEncoderWrapper:
|
|
|
281
306
|
hf_split: str,
|
|
282
307
|
hf_subset: str,
|
|
283
308
|
prompt_type: PromptType | None = None,
|
|
284
|
-
**kwargs:
|
|
309
|
+
**kwargs: Unpack[EncodeKwargs],
|
|
285
310
|
) -> Array:
|
|
286
311
|
"""Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.
|
|
287
312
|
|
|
@@ -299,10 +324,10 @@ class CrossEncoderWrapper:
|
|
|
299
324
|
The predicted relevance scores for each inputs pair.
|
|
300
325
|
"""
|
|
301
326
|
all_queries_with_instructions = [
|
|
302
|
-
text for batch in inputs1 for text in batch["text"]
|
|
327
|
+
self.query_prefix + text for batch in inputs1 for text in batch["text"]
|
|
303
328
|
]
|
|
304
329
|
all_corpus_with_instructions = [
|
|
305
|
-
text for batch in inputs2 for text in batch["text"]
|
|
330
|
+
self.passage_prefix + text for batch in inputs2 for text in batch["text"]
|
|
306
331
|
]
|
|
307
332
|
|
|
308
333
|
return self.model.predict(
|