mteb-2.5.2-py3-none-any.whl → mteb-2.7.2-py3-none-any.whl
This diff covers two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the registry.
- mteb/__init__.py +2 -0
- mteb/_create_dataloaders.py +17 -18
- mteb/_evaluators/any_sts_evaluator.py +3 -3
- mteb/_evaluators/clustering_evaluator.py +2 -2
- mteb/_evaluators/evaluator.py +4 -2
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +10 -8
- mteb/_evaluators/pair_classification_evaluator.py +5 -3
- mteb/_evaluators/retrieval_evaluator.py +2 -2
- mteb/_evaluators/retrieval_metrics.py +18 -17
- mteb/_evaluators/sklearn_evaluator.py +11 -10
- mteb/_evaluators/text/bitext_mining_evaluator.py +27 -18
- mteb/_evaluators/text/summarization_evaluator.py +23 -18
- mteb/_evaluators/zeroshot_classification_evaluator.py +5 -3
- mteb/abstasks/_data_filter/filters.py +1 -1
- mteb/abstasks/_data_filter/task_pipelines.py +3 -0
- mteb/abstasks/_statistics_calculation.py +18 -10
- mteb/abstasks/_stratification.py +18 -18
- mteb/abstasks/abstask.py +35 -28
- mteb/abstasks/aggregate_task_metadata.py +1 -9
- mteb/abstasks/aggregated_task.py +10 -29
- mteb/abstasks/classification.py +15 -10
- mteb/abstasks/clustering.py +19 -15
- mteb/abstasks/clustering_legacy.py +10 -10
- mteb/abstasks/image/image_text_pair_classification.py +7 -4
- mteb/abstasks/multilabel_classification.py +23 -19
- mteb/abstasks/pair_classification.py +20 -11
- mteb/abstasks/regression.py +4 -4
- mteb/abstasks/retrieval.py +28 -24
- mteb/abstasks/retrieval_dataset_loaders.py +2 -2
- mteb/abstasks/sts.py +8 -5
- mteb/abstasks/task_metadata.py +31 -33
- mteb/abstasks/text/bitext_mining.py +39 -28
- mteb/abstasks/text/reranking.py +8 -6
- mteb/abstasks/text/summarization.py +10 -5
- mteb/abstasks/zeroshot_classification.py +8 -4
- mteb/benchmarks/benchmark.py +4 -2
- mteb/benchmarks/benchmarks/__init__.py +4 -0
- mteb/benchmarks/benchmarks/benchmarks.py +112 -11
- mteb/benchmarks/get_benchmark.py +14 -55
- mteb/cache.py +182 -29
- mteb/cli/_display_tasks.py +2 -2
- mteb/cli/build_cli.py +110 -14
- mteb/cli/generate_model_card.py +43 -23
- mteb/deprecated_evaluator.py +63 -49
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
- mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
- mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
- mteb/evaluate.py +44 -33
- mteb/filter_tasks.py +25 -26
- mteb/get_tasks.py +29 -30
- mteb/languages/language_scripts.py +5 -3
- mteb/leaderboard/app.py +162 -34
- mteb/load_results.py +12 -12
- mteb/models/abs_encoder.py +10 -6
- mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
- mteb/models/cache_wrappers/cache_wrapper.py +2 -2
- mteb/models/get_model_meta.py +21 -3
- mteb/models/instruct_wrapper.py +28 -8
- mteb/models/model_implementations/align_models.py +1 -1
- mteb/models/model_implementations/andersborges.py +4 -4
- mteb/models/model_implementations/ara_models.py +1 -1
- mteb/models/model_implementations/arctic_models.py +8 -8
- mteb/models/model_implementations/b1ade_models.py +1 -1
- mteb/models/model_implementations/bge_models.py +45 -21
- mteb/models/model_implementations/bica_model.py +3 -3
- mteb/models/model_implementations/blip2_models.py +2 -2
- mteb/models/model_implementations/blip_models.py +16 -16
- mteb/models/model_implementations/bm25.py +4 -4
- mteb/models/model_implementations/bmretriever_models.py +6 -4
- mteb/models/model_implementations/cadet_models.py +1 -1
- mteb/models/model_implementations/cde_models.py +11 -4
- mteb/models/model_implementations/clip_models.py +6 -6
- mteb/models/model_implementations/clips_models.py +3 -3
- mteb/models/model_implementations/codefuse_models.py +5 -5
- mteb/models/model_implementations/codesage_models.py +3 -3
- mteb/models/model_implementations/cohere_models.py +5 -5
- mteb/models/model_implementations/cohere_v.py +2 -2
- mteb/models/model_implementations/colpali_models.py +3 -3
- mteb/models/model_implementations/colqwen_models.py +8 -8
- mteb/models/model_implementations/colsmol_models.py +2 -2
- mteb/models/model_implementations/conan_models.py +1 -1
- mteb/models/model_implementations/dino_models.py +42 -42
- mteb/models/model_implementations/e5_instruct.py +23 -4
- mteb/models/model_implementations/e5_models.py +9 -9
- mteb/models/model_implementations/e5_v.py +6 -6
- mteb/models/model_implementations/eagerworks_models.py +1 -1
- mteb/models/model_implementations/emillykkejensen_models.py +6 -6
- mteb/models/model_implementations/en_code_retriever.py +1 -1
- mteb/models/model_implementations/euler_models.py +2 -2
- mteb/models/model_implementations/fa_models.py +9 -9
- mteb/models/model_implementations/facebookai.py +14 -2
- mteb/models/model_implementations/geogpt_models.py +1 -1
- mteb/models/model_implementations/gme_v_models.py +6 -5
- mteb/models/model_implementations/google_models.py +1 -1
- mteb/models/model_implementations/granite_vision_embedding_models.py +1 -1
- mteb/models/model_implementations/gritlm_models.py +2 -2
- mteb/models/model_implementations/gte_models.py +25 -13
- mteb/models/model_implementations/hinvec_models.py +1 -1
- mteb/models/model_implementations/ibm_granite_models.py +30 -6
- mteb/models/model_implementations/inf_models.py +2 -2
- mteb/models/model_implementations/jasper_models.py +2 -2
- mteb/models/model_implementations/jina_clip.py +48 -10
- mteb/models/model_implementations/jina_models.py +18 -11
- mteb/models/model_implementations/kblab.py +12 -6
- mteb/models/model_implementations/kennethenevoldsen_models.py +4 -4
- mteb/models/model_implementations/kfst.py +1 -1
- mteb/models/model_implementations/kowshik24_models.py +1 -1
- mteb/models/model_implementations/lgai_embedding_models.py +1 -1
- mteb/models/model_implementations/linq_models.py +1 -1
- mteb/models/model_implementations/listconranker.py +1 -1
- mteb/models/model_implementations/llm2clip_models.py +6 -6
- mteb/models/model_implementations/llm2vec_models.py +8 -8
- mteb/models/model_implementations/mcinext_models.py +4 -1
- mteb/models/model_implementations/mdbr_models.py +17 -3
- mteb/models/model_implementations/misc_models.py +68 -68
- mteb/models/model_implementations/mixedbread_ai_models.py +332 -0
- mteb/models/model_implementations/mme5_models.py +1 -1
- mteb/models/model_implementations/moco_models.py +4 -4
- mteb/models/model_implementations/mod_models.py +1 -1
- mteb/models/model_implementations/model2vec_models.py +14 -14
- mteb/models/model_implementations/moka_models.py +1 -1
- mteb/models/model_implementations/nbailab.py +3 -3
- mteb/models/model_implementations/no_instruct_sentence_models.py +2 -2
- mteb/models/model_implementations/nomic_models.py +30 -15
- mteb/models/model_implementations/nomic_models_vision.py +1 -1
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +15 -9
- mteb/models/model_implementations/nvidia_models.py +151 -19
- mteb/models/model_implementations/octen_models.py +61 -2
- mteb/models/model_implementations/openclip_models.py +13 -13
- mteb/models/model_implementations/opensearch_neural_sparse_models.py +5 -5
- mteb/models/model_implementations/ops_moa_models.py +1 -1
- mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -1
- mteb/models/model_implementations/pawan_models.py +1 -1
- mteb/models/model_implementations/piccolo_models.py +1 -1
- mteb/models/model_implementations/pixie_models.py +56 -0
- mteb/models/model_implementations/promptriever_models.py +4 -4
- mteb/models/model_implementations/pylate_models.py +10 -9
- mteb/models/model_implementations/qodo_models.py +2 -2
- mteb/models/model_implementations/qtack_models.py +1 -1
- mteb/models/model_implementations/qwen3_models.py +3 -3
- mteb/models/model_implementations/qzhou_models.py +2 -2
- mteb/models/model_implementations/random_baseline.py +3 -3
- mteb/models/model_implementations/rasgaard_models.py +2 -2
- mteb/models/model_implementations/reasonir_model.py +1 -1
- mteb/models/model_implementations/repllama_models.py +3 -3
- mteb/models/model_implementations/rerankers_custom.py +12 -6
- mteb/models/model_implementations/rerankers_monot5_based.py +17 -17
- mteb/models/model_implementations/richinfoai_models.py +1 -1
- mteb/models/model_implementations/ru_sentence_models.py +20 -20
- mteb/models/model_implementations/ruri_models.py +10 -10
- mteb/models/model_implementations/salesforce_models.py +3 -3
- mteb/models/model_implementations/samilpwc_models.py +1 -1
- mteb/models/model_implementations/sarashina_embedding_models.py +2 -2
- mteb/models/model_implementations/searchmap_models.py +1 -1
- mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +113 -146
- mteb/models/model_implementations/sentence_transformers_models.py +124 -22
- mteb/models/model_implementations/shuu_model.py +1 -1
- mteb/models/model_implementations/siglip_models.py +20 -20
- mteb/models/model_implementations/slm_models.py +416 -0
- mteb/models/model_implementations/spartan8806_atles_champion.py +1 -1
- mteb/models/model_implementations/stella_models.py +17 -4
- mteb/models/model_implementations/tarka_models.py +2 -2
- mteb/models/model_implementations/text2vec_models.py +9 -3
- mteb/models/model_implementations/ua_sentence_models.py +1 -1
- mteb/models/model_implementations/uae_models.py +7 -1
- mteb/models/model_implementations/vdr_models.py +1 -1
- mteb/models/model_implementations/vi_vn_models.py +6 -6
- mteb/models/model_implementations/vlm2vec_models.py +3 -3
- mteb/models/model_implementations/voyage_models.py +84 -0
- mteb/models/model_implementations/voyage_v.py +9 -7
- mteb/models/model_implementations/youtu_models.py +1 -1
- mteb/models/model_implementations/yuan_models.py +1 -1
- mteb/models/model_implementations/yuan_models_en.py +1 -1
- mteb/models/model_meta.py +80 -31
- mteb/models/models_protocols.py +22 -6
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
- mteb/models/search_wrappers.py +33 -18
- mteb/models/sentence_transformer_wrapper.py +50 -25
- mteb/models/vllm_wrapper.py +327 -0
- mteb/py.typed +0 -0
- mteb/results/benchmark_results.py +29 -21
- mteb/results/model_result.py +52 -22
- mteb/results/task_result.py +80 -58
- mteb/similarity_functions.py +11 -7
- mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
- mteb/tasks/classification/est/estonian_valence.py +1 -1
- mteb/tasks/classification/kur/kurdish_sentiment_classification.py +2 -2
- mteb/tasks/classification/multilingual/scala_classification.py +1 -1
- mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/zho/cmteb_clustering.py +2 -2
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +12 -12
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
- mteb/tasks/retrieval/eng/__init__.py +2 -0
- mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
- mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
- mteb/tasks/retrieval/kor/__init__.py +15 -1
- mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
- mteb/tasks/retrieval/multilingual/__init__.py +2 -0
- mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
- mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +90 -100
- mteb/tasks/retrieval/nob/norquad.py +2 -2
- mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
- mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
- mteb/tasks/retrieval/vie/__init__.py +14 -6
- mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/fevervn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +48 -0
- mteb/tasks/retrieval/vie/nqvn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
- mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
- mteb/types/__init__.py +2 -0
- mteb/types/_encoder_io.py +12 -0
- mteb/types/_result.py +2 -1
- mteb/types/statistics.py +9 -3
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/METADATA +15 -4
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/RECORD +240 -219
- mteb/models/model_implementations/mxbai_models.py +0 -111
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/WHEEL +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/entry_points.txt +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/top_level.txt +0 -0
mteb/__init__.py
CHANGED
@@ -3,6 +3,7 @@ from importlib.metadata import version
 from mteb import types
 from mteb.abstasks import AbsTask
 from mteb.abstasks.task_metadata import TaskMetadata
+from mteb.cache import ResultCache
 from mteb.deprecated_evaluator import MTEB
 from mteb.evaluate import evaluate
 from mteb.filter_tasks import filter_tasks
@@ -33,6 +34,7 @@ __all__ = [
     "CrossEncoderProtocol",
     "EncoderProtocol",
     "IndexEncoderSearchProtocol",
+    "ResultCache",
     "SearchProtocol",
     "SentenceTransformerEncoderWrapper",
     "TaskMetadata",
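The practical upshot of this change is that ResultCache is re-exported at the package root. A minimal sketch of the two now-equivalent import paths (nothing here beyond what the hunks above show):

    import mteb
    from mteb.cache import ResultCache

    # New in 2.7.x: the same class is also exposed at the top level.
    assert mteb.ResultCache is ResultCache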
mteb/_create_dataloaders.py
CHANGED
@@ -1,4 +1,5 @@
 import logging
+import warnings
 from collections.abc import Callable
 from typing import Any, cast
@@ -22,7 +23,7 @@ logger = logging.getLogger(__name__)
 def _create_dataloader_from_texts(
     text: list[str],
     batch_size: int = 32,
-    **kwargs:
+    **kwargs: Any,
 ) -> DataLoader[TextInput]:
     """Create a dataloader from a list of text.

@@ -113,11 +114,8 @@ def _create_text_dataloader_for_queries(
 )


-_warned_about_user_role = False
-
-
 def _convert_conv_history_to_query(
-    row: dict[str, list[str] | Conversation],
+    row: dict[str, str | list[str] | Conversation],
 ) -> dict[str, str | Conversation]:
     """Convert a conversation history to a single query string.

@@ -127,21 +125,18 @@ def _convert_conv_history_to_query(
     Returns:
         The updated row with the "query" and "text" fields set to the conversation string, and the "conversation" field set to the list of ConversationTurn.
     """
-    global _warned_about_user_role
-
     conversation = row["text"]
     # if it's a list of strings, just join them
     if isinstance(conversation, list) and isinstance(conversation[0], str):
-
-        conv_str = "; ".join(
+        conversation_ = cast(list[str], conversation)
+        conv_str = "; ".join(conversation_)
         current_conversation = [
-            ConversationTurn(role="user", content=message) for message in
+            ConversationTurn(role="user", content=message) for message in conversation_
         ]
-
-
-
-
-        _warned_about_user_role = True
+        warnings.warn(
+            "Conversations are a list of strings. Used 'user' role for all turns.",
+            category=UserWarning,
+        )
     # otherwise, it's a list of dictionaries, which we need to convert to strings
     elif isinstance(conversation, list) and isinstance(conversation[0], dict):
         conv = []
@@ -178,7 +173,7 @@ def _convert_conv_history_to_query(

     row["text"] = conv_str
     row["conversation"] = current_conversation
-    return row
+    return cast(dict[str, str | list[ConversationTurn]], row)
@@ -196,7 +191,8 @@ def _create_dataloader_for_queries_conversation(
     """
     return DataLoader(
         queries.map(
-            _convert_conv_history_to_query,
+            _convert_conv_history_to_query,
+            desc="Converting conversations to queries",
         ),
         collate_fn=_custom_collate_fn,
         batch_size=batch_size,
@@ -366,6 +362,9 @@ def _create_document_dataloader(
         task_metadata: Metadata of the task to determine the document type.
         input_column: The column to use as input. If None, it will use the first column that matches the modality.
         batch_size: Batch size for the dataloader.
+
+    Returns:
+        A dataloader for the documents.
     """
     document_type = task_metadata.get_modalities(PromptType.document)
     if document_type == ["text"]:  # text only
@@ -388,7 +387,7 @@ def create_dataloader(
     prompt_type: PromptType | None = None,
     input_column: str | None = None,
     batch_size: int = 32,
-    **kwargs:
+    **kwargs: Any,
 ) -> DataLoader[BatchedInput]:
     """Create a dataloader from a dataset.
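Behavioral note on the `_convert_conv_history_to_query` hunks: the module-level `_warned_about_user_role` flag is gone, and the notice is now an ordinary `UserWarning` raised via `warnings.warn`. A sketch of how a caller could silence it with the stdlib machinery (the surrounding call is a placeholder, not an mteb API):

    import warnings

    with warnings.catch_warnings():
        # The "Used 'user' role for all turns" notice is a standard
        # UserWarning after this change, so normal warning filters apply.
        warnings.simplefilter("ignore", category=UserWarning)
        ...  # build conversational query dataloaders here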
mteb/_evaluators/any_sts_evaluator.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import
+from typing import TypedDict

 from datasets import Dataset
 from sklearn.metrics.pairwise import (
@@ -12,7 +12,7 @@ from mteb._create_dataloaders import create_dataloader
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
 from mteb.similarity_functions import compute_pairwise_similarity
-from mteb.types import PromptType
+from mteb.types import EncodeKwargs, PromptType

 from .evaluator import Evaluator

@@ -60,7 +60,7 @@ class AnySTSEvaluator(Evaluator):
         self,
         model: EncoderProtocol,
         *,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> STSEvaluatorScores:
         logger.info("Running semantic similarity - Encoding samples (1/2)")
         embeddings1 = model.encode(
mteb/_evaluators/clustering_evaluator.py
CHANGED
@@ -1,5 +1,4 @@
 import logging
-from typing import Any

 from datasets import Dataset
 from sklearn import cluster
@@ -7,6 +6,7 @@ from sklearn import cluster
 from mteb._create_dataloaders import create_dataloader
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
+from mteb.types import EncodeKwargs

 from .evaluator import Evaluator

@@ -38,7 +38,7 @@ class ClusteringEvaluator(Evaluator):
         self,
         model: EncoderProtocol,
         *,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> list[int]:
         data_loader = create_dataloader(
             self.dataset,
mteb/_evaluators/evaluator.py
CHANGED
@@ -1,8 +1,10 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping
 from typing import Any

 from mteb.abstasks.abstask import _set_seed
 from mteb.models import EncoderProtocol
+from mteb.types import EncodeKwargs


 class Evaluator(ABC):
@@ -17,8 +19,8 @@ class Evaluator(ABC):

     @abstractmethod
     def __call__(
-        self, model: EncoderProtocol, *, encode_kwargs:
-    ) ->
+        self, model: EncoderProtocol, *, encode_kwargs: EncodeKwargs
+    ) -> Mapping[str, float] | Iterable[Any]:
         """This is called during training to evaluate the model.

         It returns scores.
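The abstract `__call__` now takes a keyword-only `encode_kwargs: EncodeKwargs` and declares a `Mapping[str, float] | Iterable[Any]` return. A sketch of a conforming subclass under the new signature (the class and its constant score are illustrative only; check the `Evaluator` base-class constructor before copying):

    from collections.abc import Mapping

    from mteb._evaluators.evaluator import Evaluator
    from mteb.models import EncoderProtocol
    from mteb.types import EncodeKwargs


    class ConstantEvaluator(Evaluator):
        def __call__(
            self, model: EncoderProtocol, *, encode_kwargs: EncodeKwargs
        ) -> Mapping[str, float]:
            # A real evaluator would encode its dataset with model.encode(...)
            # and compute task metrics from the embeddings.
            return {"main_score": 0.0}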
mteb/_evaluators/image/imagetext_pairclassification_evaluator.py
CHANGED
@@ -1,20 +1,22 @@
 from __future__ import annotations

 import logging
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any

 import torch
 import torch.nn.functional as F
-from datasets import Dataset
 from torch.utils.data import DataLoader

 from mteb._create_dataloaders import (
+    _create_dataloader_from_texts,
     _transform_image_to_rgb,
 )
 from mteb._evaluators.evaluator import Evaluator
 from mteb._requires_package import requires_image_dependencies
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models.models_protocols import EncoderProtocol
+from mteb.types import EncodeKwargs

 if TYPE_CHECKING:
     from PIL.Image import Image
@@ -61,8 +63,8 @@ class ImageTextPairClassificationEvaluator(Evaluator):
     def __init__(
         self,
         dataset,
-        images_column_names: str |
-        texts_column_names: str |
+        images_column_names: str | Sequence[str],
+        texts_column_names: str | Sequence[str],
         num_images_per_sample: int,
         num_texts_per_sample: int,
         task_metadata: TaskMetadata,
@@ -82,10 +84,11 @@ class ImageTextPairClassificationEvaluator(Evaluator):
         self.hf_split = hf_split
         self.hf_subset = hf_subset

-    def __call__(
+    def __call__(  # type: ignore[override]
         self,
         model: EncoderProtocol,
-
+        *,
+        encode_kwargs: EncodeKwargs,
     ) -> list[torch.Tensor]:
         images = []
         if isinstance(self.images_column_names, str):
@@ -106,8 +109,8 @@ class ImageTextPairClassificationEvaluator(Evaluator):
             texts.append(row[col])

         text_embeddings = model.encode(
-
-
+            _create_dataloader_from_texts(
+                texts,
                 **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
@@ -128,7 +131,6 @@ class ImageTextPairClassificationEvaluator(Evaluator):
             DataLoader(
                 CustomImageDataset(images),
                 collate_fn=lambda x: {"image": [item["image"] for item in x]},
-                **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
             hf_subset=self.hf_subset,
mteb/_evaluators/pair_classification_evaluator.py
CHANGED
@@ -14,7 +14,7 @@ from mteb._evaluators.evaluator import Evaluator
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
 from mteb.similarity_functions import compute_pairwise_similarity
-from mteb.types import PromptType
+from mteb.types import EncodeKwargs, PromptType

 logger = logging.getLogger(__name__)

@@ -85,7 +85,7 @@ class PairClassificationEvaluator(Evaluator):
     def __call__(
         self,
         model: EncoderProtocol,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> PairClassificationDistances:
         logger.info("Running pair classification - Encoding samples (1/2)")
         embeddings1 = model.encode(
@@ -148,7 +148,9 @@
         hf_subset: str,
         **encode_kwargs: Any,
     ) -> np.ndarray:
-        index_map
+        index_map = {}
+        all_unique_texts: list[str] = []
+        all_texts_indexes = []
         for text in all_texts:
             text_hash = hash(text)
             if text_hash not in index_map:
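The last hunk initializes the structures for hash-based text deduplication: each unique text is encoded once, and every original position maps back to the index of its embedding. A standalone sketch of that pattern, on toy data rather than the actual method:

    all_texts = ["a", "b", "a"]
    index_map: dict[int, int] = {}
    all_unique_texts: list[str] = []
    all_texts_indexes: list[int] = []
    for text in all_texts:
        text_hash = hash(text)
        if text_hash not in index_map:
            # First time we see this text: remember its embedding slot.
            index_map[text_hash] = len(all_unique_texts)
            all_unique_texts.append(text)
        all_texts_indexes.append(index_map[text_hash])
    assert all_unique_texts == ["a", "b"]
    assert all_texts_indexes == [0, 1, 0]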
mteb/_evaluators/retrieval_evaluator.py
CHANGED
@@ -1,11 +1,11 @@
 import logging
 from collections.abc import Sequence
-from typing import Any

 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import SearchProtocol
 from mteb.types import (
     CorpusDatasetType,
+    EncodeKwargs,
     QueryDatasetType,
     RelevantDocumentsType,
     RetrievalEvaluationResult,
@@ -48,7 +48,7 @@ class RetrievalEvaluator(Evaluator):
     def __call__(  # type: ignore[override]
         self,
         search_model: SearchProtocol,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> RetrievalOutputType:
         logger.info("Running retrieval task - Indexing corpus...")
         search_model.index(
mteb/_evaluators/retrieval_metrics.py
CHANGED
@@ -1,5 +1,6 @@
 import logging
 from collections import defaultdict
+from collections.abc import Mapping
 from typing import Any

 import numpy as np
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)

 def mrr(
     qrels: RelevantDocumentsType,
-    results:
+    results: Mapping[str, Mapping[str, float]],
     k_values: list[int],
 ) -> dict[str, list[float]]:
     mrr_metrics = defaultdict(list)
@@ -32,7 +33,7 @@ def mrr(
             doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0
         }
         for k in k_values:
-            rr = 0
+            rr = 0.0
             for rank, hit in enumerate(top_hits[query_id][0:k]):
                 if hit[0] in query_relevant_docs:
                     rr = 1.0 / (rank + 1)
@@ -45,8 +46,8 @@ def recall_cap(
     qrels: RelevantDocumentsType,
     results: dict[str, dict[str, float]],
     k_values: list[int],
-) -> dict[str, list[float]]:
-    capped_recall = defaultdict(list)
+) -> dict[str, list[float | None]]:
+    capped_recall: dict[str, list[float | None]] = defaultdict(list)

     k_max = max(k_values)
@@ -139,7 +140,7 @@ def calculate_pmrr(original_run, new_run, changed_qrels):
     changes = []
     for qid in changed_qrels.keys():
         if qid + "-og" not in original_run or qid + "-changed" not in new_run:
-
+            logger.warning(f"Query {qid} not found in the runs for calculating p-MRR")
             continue
         original_qid_run = original_run[qid + "-og"]
         new_qid_run = new_run[qid + "-changed"]
@@ -188,7 +189,7 @@ def evaluate_p_mrr_change(
     Returns:
         A dictionary with the scores, including "p-MRR", "og" and "changed" keys.
     """
-    followir_scores = defaultdict(dict)
+    followir_scores: dict[str, float | dict[str, float]] = defaultdict(dict)

     qrels_sep = {
         "og": {k: v for k, v in qrels.items() if k.endswith("-og")},
@@ -227,7 +228,7 @@
             ndcg, _map, recall, precision, naucs, avg_mrr, naucs_mrr, cv_recall, {}
         )
         for key, value in scores_dict.items():
-            followir_scores[name][key] = value
+            followir_scores[name][key] = value  # type: ignore[index]

     return followir_scores
@@ -254,8 +255,8 @@ def confidence_scores(sim_scores: list[float]) -> dict[str, float]:
     sim_scores_sorted = sorted(sim_scores)[::-1]

     cs_max = sim_scores_sorted[0]
-    cs_std = np.std(sim_scores)
-    cs_diff1 =
+    cs_std = float(np.std(sim_scores))
+    cs_diff1 = 0.0
     if len(sim_scores) > 1:
         cs_diff1 = sim_scores_sorted[0] - sim_scores_sorted[1]
     elif len(sim_scores) == 1:
@@ -410,7 +411,7 @@ def make_score_dict(
     cv_recall: dict[str, float],
     task_scores: dict[str, float],
     previous_results_model_meta: dict[str, Any] | None = None,
-) -> dict[str,
+) -> dict[str, Any]:
     return {
         **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
         **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
@@ -528,7 +529,7 @@ def max_over_subqueries(


 def calculate_retrieval_scores(
-    results:
+    results: Mapping[str, Mapping[str, float]],
     qrels: RelevantDocumentsType,
     k_values: list[int],
     skip_first_result: bool = False,
@@ -576,7 +577,7 @@


 def evaluate_abstention(
-    results:
+    results: Mapping[str, Mapping[str, float]],
     metric_scores: dict[str, list[float]],
 ) -> dict[str, float]:
     """Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997
@@ -591,21 +592,21 @@
     all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())]
     all_conf_scores = [confidence_scores(sim_scores) for sim_scores in all_sim_scores]
     conf_fcts = list(all_conf_scores[0].keys())
-
+    all_conf_scores_ = {
         fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts
     }
-
+    metric_scores_ = {k: np.array(v) for k, v in metric_scores.items()}
     naucs = {}

-    for metric_name, scores in
-        for fct, conf_scores in
+    for metric_name, scores in metric_scores_.items():
+        for fct, conf_scores in all_conf_scores_.items():
             naucs[f"nAUC_{metric_name}_{fct}"] = nauc(conf_scores, scores)

     return naucs


 def calculate_cv_recall(
-    results:
+    results: Mapping[str, Mapping[str, float]],
     qrels: RelevantDocumentsType,
     k_values: list[int],
     skip_first_result: bool = False,
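For reference, the `rr = 0.0` fix in `mrr` concerns the per-query reciprocal rank: the score is 1/rank of the first relevant document within the top k, and 0.0 when nothing relevant appears. A self-contained sketch of that logic (toy helper, not the mteb function itself):

    def rr_at_k(ranked_doc_ids: list[str], relevant: set[str], k: int) -> float:
        rr = 0.0  # float literal, matching the rr = 0.0 change above
        for rank, doc_id in enumerate(ranked_doc_ids[:k]):
            if doc_id in relevant:
                rr = 1.0 / (rank + 1)
                break
        return rr

    assert rr_at_k(["d3", "d1", "d2"], {"d1"}, k=3) == 0.5  # hit at rank 2
    assert rr_at_k(["d3", "d2"], {"d1"}, k=2) == 0.0        # no relevant hit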
mteb/_evaluators/sklearn_evaluator.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Protocol
+from typing import Any, Protocol, cast

 import numpy as np
 from datasets import Dataset
@@ -9,7 +9,7 @@ from typing_extensions import Self
 from mteb._create_dataloaders import create_dataloader
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
-from mteb.types import BatchedInput
+from mteb.types import Array, BatchedInput, EncodeKwargs

 from .evaluator import Evaluator

@@ -17,11 +17,11 @@ logger = logging.getLogger(__name__)


 class SklearnModelProtocol(Protocol):
-    def fit(self, X:
-    def predict(self, X:
+    def fit(self, X: Array, y: np.ndarray | list[int]) -> None: ...  # noqa: N803
+    def predict(self, X: Array) -> np.ndarray: ...  # noqa: N803
     def get_params(self) -> dict[str, Any]: ...
-    def set_params(self, **kwargs: dict[str, Any]) -> Self: ...
-    def score(self, X:
+    def set_params(self, random_state: int, **kwargs: dict[str, Any]) -> Self: ...
+    def score(self, X: Array, y: np.ndarray | list[int]) -> float: ...  # noqa: N803


 class SklearnEvaluator(Evaluator):
@@ -50,7 +50,7 @@ class SklearnEvaluator(Evaluator):
         self.evaluator_model = evaluator_model

     def create_dataloaders(
-        self, encode_kwargs:
+        self, encode_kwargs: EncodeKwargs
     ) -> tuple[DataLoader[BatchedInput], DataLoader[BatchedInput]]:
         dataloader_train = create_dataloader(
             self.train_dataset,
@@ -70,9 +70,9 @@
         self,
         model: EncoderProtocol,
         *,
-        encode_kwargs:
-        test_cache:
-    ) -> tuple[np.ndarray,
+        encode_kwargs: EncodeKwargs,
+        test_cache: Array | None = None,
+    ) -> tuple[np.ndarray, Array]:
         """Classification evaluation by training a sklearn classifier on the embeddings of the training set and evaluating on the embeddings of the test set.

         Args:
@@ -104,6 +104,7 @@
             hf_subset=self.hf_subset,
             **encode_kwargs,
         )
+        test_cache = cast(Array, test_cache)

         logger.info("Running - Fitting classifier...")
         y_train = self.train_dataset[self.label_column_name]
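Since `SklearnModelProtocol` is a structural `Protocol`, any estimator exposing `fit`/`predict`/`get_params`/`set_params`/`score` satisfies it without inheriting from it; ordinary scikit-learn estimators qualify. A quick illustration (the model choice is arbitrary):

    from sklearn.linear_model import LogisticRegression

    clf = LogisticRegression(max_iter=1000)
    clf.set_params(random_state=42)  # the protocol now spells out random_state
    print(sorted(clf.get_params())[:3])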
mteb/_evaluators/text/bitext_mining_evaluator.py
CHANGED
@@ -1,7 +1,5 @@
 import logging
-from typing import Any

-import numpy as np
 import torch
 from datasets import Dataset
 from tqdm.auto import tqdm
@@ -10,6 +8,7 @@ from mteb._create_dataloaders import _create_dataloader_from_texts
 from mteb._evaluators.evaluator import Evaluator
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
+from mteb.types import Array, EncodeKwargs

 logger = logging.getLogger(__name__)

@@ -33,7 +32,10 @@ class BitextMiningEvaluator(Evaluator):
         self.task_metadata = task_metadata

     def __call__(
-        self,
+        self,
+        model: EncoderProtocol,
+        *,
+        encode_kwargs: EncodeKwargs,
     ) -> dict[str, list[dict[str, float]]]:
         pair_elements = {p for pair in self.pairs for p in pair}
         if isinstance(self.sentences, Dataset):
@@ -69,11 +71,11 @@

     def _similarity_search(
         self,
-        query_embeddings:
-        corpus_embeddings:
+        query_embeddings: Array,
+        corpus_embeddings: Array,
         model: EncoderProtocol,
         query_chunk_size: int = 100,
-        corpus_chunk_size: int =
+        corpus_chunk_size: int = 500_000,
     ) -> list[dict[str, float]]:
         """This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings.

@@ -104,13 +106,15 @@
         ):
             query_embeddings = query_embeddings.to(corpus_embeddings.device)

-        queries_result_list
+        queries_result_list: list[list[dict[str, float]]] = [
+            [] for _ in range(len(query_embeddings))
+        ]

         for query_start_idx in range(0, len(query_embeddings), query_chunk_size):
             # Iterate over chunks of the corpus
             for corpus_start_idx in range(0, len(corpus_embeddings), corpus_chunk_size):
                 # Compute cosine similarities
-                similarity_scores = model.similarity(
+                similarity_scores = model.similarity(
                     query_embeddings[
                         query_start_idx : query_start_idx + query_chunk_size
                     ],
@@ -120,15 +124,17 @@
                 )

                 # Get top-k scores
-
-                torch.
-
-
-
+                cos_scores_top_k_values_tensor, cos_scores_top_k_idx_tensor = (
+                    torch.topk(
+                        torch.tensor(similarity_scores),
+                        1,
+                        dim=1,
+                        largest=True,
+                        sorted=False,
+                    )
                 )
-                cos_scores_top_k_values =
-                cos_scores_top_k_idx =
+                cos_scores_top_k_values = cos_scores_top_k_values_tensor.cpu().tolist()
+                cos_scores_top_k_idx = cos_scores_top_k_idx_tensor.cpu().tolist()

                 for query_itr in range(len(similarity_scores)):
                     for sub_corpus_id, score in zip(
@@ -141,11 +147,14 @@
                             {"corpus_id": corpus_id, "score": score}
                         )

+        result_queries_list: list[dict[str, float]] = [
+            {} for _ in range(len(query_embeddings))
+        ]
         # Sort and strip to top_k results
         for idx in range(len(queries_result_list)):
             queries_result_list[idx] = sorted(
                 queries_result_list[idx], key=lambda x: x["score"], reverse=True
             )
-
+            result_queries_list[idx] = queries_result_list[idx][0]

-        return
+        return result_queries_list
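The reconstructed top-k block in `_similarity_search` keeps only the single best corpus hit per query via `torch.topk(..., 1, dim=1, largest=True, sorted=False)`. A minimal illustration of that selection on a toy score matrix:

    import torch

    # Rows are queries, columns are corpus documents.
    scores = torch.tensor([[0.1, 0.9, 0.3],
                           [0.7, 0.2, 0.5]])
    values, idx = torch.topk(scores, 1, dim=1, largest=True, sorted=False)
    print(values.squeeze(1).tolist())  # [0.9, 0.7] - best score per query
    print(idx.squeeze(1).tolist())     # [1, 0]     - index of the best document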