mteb 2.5.2__py3-none-any.whl → 2.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/_create_dataloaders.py +10 -15
- mteb/_evaluators/any_sts_evaluator.py +1 -4
- mteb/_evaluators/evaluator.py +2 -1
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
- mteb/_evaluators/pair_classification_evaluator.py +3 -1
- mteb/_evaluators/retrieval_metrics.py +17 -16
- mteb/_evaluators/sklearn_evaluator.py +9 -8
- mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
- mteb/_evaluators/text/summarization_evaluator.py +20 -16
- mteb/abstasks/_data_filter/filters.py +1 -1
- mteb/abstasks/_data_filter/task_pipelines.py +3 -0
- mteb/abstasks/_statistics_calculation.py +18 -10
- mteb/abstasks/_stratification.py +18 -18
- mteb/abstasks/abstask.py +33 -27
- mteb/abstasks/aggregate_task_metadata.py +1 -9
- mteb/abstasks/aggregated_task.py +7 -26
- mteb/abstasks/classification.py +10 -4
- mteb/abstasks/clustering.py +18 -14
- mteb/abstasks/clustering_legacy.py +8 -8
- mteb/abstasks/image/image_text_pair_classification.py +5 -3
- mteb/abstasks/multilabel_classification.py +20 -16
- mteb/abstasks/pair_classification.py +18 -9
- mteb/abstasks/regression.py +3 -3
- mteb/abstasks/retrieval.py +12 -9
- mteb/abstasks/sts.py +6 -3
- mteb/abstasks/task_metadata.py +22 -19
- mteb/abstasks/text/bitext_mining.py +36 -25
- mteb/abstasks/text/reranking.py +7 -5
- mteb/abstasks/text/summarization.py +8 -3
- mteb/abstasks/zeroshot_classification.py +5 -2
- mteb/benchmarks/benchmark.py +2 -2
- mteb/cache.py +27 -22
- mteb/cli/_display_tasks.py +2 -2
- mteb/cli/build_cli.py +15 -10
- mteb/cli/generate_model_card.py +10 -7
- mteb/deprecated_evaluator.py +60 -46
- mteb/evaluate.py +39 -30
- mteb/filter_tasks.py +25 -26
- mteb/get_tasks.py +29 -30
- mteb/languages/language_scripts.py +5 -3
- mteb/leaderboard/app.py +1 -1
- mteb/load_results.py +12 -12
- mteb/models/abs_encoder.py +7 -5
- mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
- mteb/models/cache_wrappers/cache_wrapper.py +2 -2
- mteb/models/get_model_meta.py +8 -1
- mteb/models/instruct_wrapper.py +11 -5
- mteb/models/model_implementations/andersborges.py +2 -2
- mteb/models/model_implementations/blip_models.py +8 -8
- mteb/models/model_implementations/bm25.py +1 -1
- mteb/models/model_implementations/clip_models.py +3 -3
- mteb/models/model_implementations/cohere_models.py +1 -1
- mteb/models/model_implementations/cohere_v.py +2 -2
- mteb/models/model_implementations/dino_models.py +23 -23
- mteb/models/model_implementations/emillykkejensen_models.py +3 -3
- mteb/models/model_implementations/gme_v_models.py +4 -3
- mteb/models/model_implementations/jina_clip.py +1 -1
- mteb/models/model_implementations/jina_models.py +1 -1
- mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
- mteb/models/model_implementations/llm2clip_models.py +3 -3
- mteb/models/model_implementations/mcinext_models.py +4 -1
- mteb/models/model_implementations/moco_models.py +2 -2
- mteb/models/model_implementations/model2vec_models.py +1 -1
- mteb/models/model_implementations/nomic_models.py +8 -8
- mteb/models/model_implementations/openclip_models.py +7 -7
- mteb/models/model_implementations/random_baseline.py +3 -3
- mteb/models/model_implementations/rasgaard_models.py +1 -1
- mteb/models/model_implementations/repllama_models.py +2 -2
- mteb/models/model_implementations/rerankers_custom.py +3 -3
- mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
- mteb/models/model_implementations/siglip_models.py +10 -10
- mteb/models/model_implementations/vlm2vec_models.py +1 -1
- mteb/models/model_implementations/voyage_v.py +4 -4
- mteb/models/model_meta.py +14 -13
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
- mteb/models/search_wrappers.py +26 -12
- mteb/models/sentence_transformer_wrapper.py +19 -14
- mteb/py.typed +0 -0
- mteb/results/benchmark_results.py +28 -20
- mteb/results/model_result.py +52 -22
- mteb/results/task_result.py +55 -58
- mteb/similarity_functions.py +11 -7
- mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
- mteb/tasks/classification/est/estonian_valence.py +1 -1
- mteb/tasks/classification/multilingual/scala_classification.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +12 -12
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
- mteb/tasks/retrieval/nob/norquad.py +2 -2
- mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
- mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
- mteb/types/_result.py +2 -1
- mteb/types/statistics.py +9 -3
- {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/METADATA +1 -1
- {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/RECORD +104 -103
- {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/WHEEL +0 -0
- {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/entry_points.txt +0 -0
- {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/top_level.txt +0 -0
mteb/evaluate.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
+
import warnings
|
|
4
5
|
from collections.abc import Iterable
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from time import time
|
|
@@ -13,11 +14,10 @@ from mteb._helpful_enum import HelpfulStrEnum
|
|
|
13
14
|
from mteb.abstasks import AbsTaskRetrieval
|
|
14
15
|
from mteb.abstasks.abstask import AbsTask
|
|
15
16
|
from mteb.abstasks.aggregated_task import AbsTaskAggregate
|
|
17
|
+
from mteb.benchmarks.benchmark import Benchmark
|
|
16
18
|
from mteb.cache import ResultCache
|
|
17
19
|
from mteb.models.model_meta import ModelMeta
|
|
18
20
|
from mteb.models.models_protocols import (
|
|
19
|
-
CrossEncoderProtocol,
|
|
20
|
-
EncoderProtocol,
|
|
21
21
|
MTEBModels,
|
|
22
22
|
)
|
|
23
23
|
from mteb.models.sentence_transformer_wrapper import (
|
|
@@ -57,27 +57,26 @@ def _sanitize_model(
|
|
|
57
57
|
) -> tuple[MTEBModels | ModelMeta, ModelMeta, ModelName, Revision]:
|
|
58
58
|
from sentence_transformers import CrossEncoder, SentenceTransformer
|
|
59
59
|
|
|
60
|
+
wrapped_model: MTEBModels | ModelMeta
|
|
60
61
|
if isinstance(model, SentenceTransformer):
|
|
61
|
-
|
|
62
|
-
meta =
|
|
63
|
-
_mdl = cast(EncoderProtocol, _mdl)
|
|
64
|
-
model = _mdl
|
|
62
|
+
wrapped_model = SentenceTransformerEncoderWrapper(model)
|
|
63
|
+
meta = wrapped_model.mteb_model_meta
|
|
65
64
|
elif isinstance(model, CrossEncoder):
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
meta = _mdl.mteb_model_meta
|
|
69
|
-
model = _mdl
|
|
65
|
+
wrapped_model = CrossEncoderWrapper(model)
|
|
66
|
+
meta = wrapped_model.mteb_model_meta
|
|
70
67
|
elif hasattr(model, "mteb_model_meta"):
|
|
71
|
-
meta = model
|
|
68
|
+
meta = getattr(model, "mteb_model_meta")
|
|
72
69
|
if not isinstance(meta, ModelMeta):
|
|
73
|
-
meta = ModelMeta.
|
|
70
|
+
meta = ModelMeta._from_hub(None)
|
|
71
|
+
wrapped_model = cast(MTEBModels | ModelMeta, model)
|
|
74
72
|
else:
|
|
75
|
-
meta = ModelMeta.
|
|
73
|
+
meta = ModelMeta._from_hub(None) if not isinstance(model, ModelMeta) else model
|
|
74
|
+
wrapped_model = meta
|
|
76
75
|
|
|
77
76
|
model_name = cast(str, meta.name)
|
|
78
77
|
model_revision = cast(str, meta.revision)
|
|
79
78
|
|
|
80
|
-
return
|
|
79
|
+
return wrapped_model, meta, model_name, model_revision
|
|
81
80
|
|
|
82
81
|
|
|
83
82
|
def _evaluate_task(
|
|
@@ -123,7 +122,8 @@ def _evaluate_task(
|
|
|
123
122
|
prediction_folder=prediction_folder,
|
|
124
123
|
public_only=public_only,
|
|
125
124
|
)
|
|
126
|
-
result
|
|
125
|
+
if isinstance(result, TaskResult):
|
|
126
|
+
result.kg_co2_emissions = tracker.final_emissions
|
|
127
127
|
return result
|
|
128
128
|
|
|
129
129
|
task_results = {}
|
|
@@ -136,10 +136,12 @@ def _evaluate_task(
|
|
|
136
136
|
task.load_data()
|
|
137
137
|
except DatasetNotFoundError as e:
|
|
138
138
|
if not task.metadata.is_public and public_only is None:
|
|
139
|
-
|
|
139
|
+
msg = (
|
|
140
140
|
f"Dataset for private task '{task.metadata.name}' not found. "
|
|
141
141
|
"Make sure you have access to the dataset and that you have set up the authentication correctly. To disable this warning set `public_only=False`"
|
|
142
142
|
)
|
|
143
|
+
logger.warning(msg)
|
|
144
|
+
warnings.warn(msg)
|
|
143
145
|
return TaskError(
|
|
144
146
|
task_name=task.metadata.name,
|
|
145
147
|
exception=str(e),
|
|
@@ -147,7 +149,7 @@ def _evaluate_task(
|
|
|
147
149
|
if public_only is False:
|
|
148
150
|
raise e
|
|
149
151
|
|
|
150
|
-
evaluation_time = 0
|
|
152
|
+
evaluation_time = 0.0
|
|
151
153
|
|
|
152
154
|
for split, hf_subsets in splits.items():
|
|
153
155
|
tick = time()
|
|
@@ -194,12 +196,18 @@ def _check_model_modalities(
|
|
|
194
196
|
return
|
|
195
197
|
|
|
196
198
|
model_modalities = set(model.modalities)
|
|
199
|
+
check_tasks: Iterable[AbsTask] = []
|
|
197
200
|
if isinstance(tasks, AbsTask):
|
|
198
|
-
|
|
201
|
+
check_tasks = [tasks]
|
|
202
|
+
elif isinstance(tasks, Benchmark):
|
|
203
|
+
benchmark = cast(Benchmark, tasks)
|
|
204
|
+
check_tasks = benchmark.tasks
|
|
205
|
+
else:
|
|
206
|
+
check_tasks = cast(Iterable[AbsTask], tasks)
|
|
199
207
|
|
|
200
208
|
warnings, errors = [], []
|
|
201
209
|
|
|
202
|
-
for task in
|
|
210
|
+
for task in check_tasks:
|
|
203
211
|
# only retrieval tasks have different modalities for query and document and can be run with partial overlaps
|
|
204
212
|
if isinstance(task, AbsTaskRetrieval):
|
|
205
213
|
query_mods = set(task.metadata.get_modalities(PromptType.query))
|
|
@@ -332,10 +340,10 @@ def evaluate(
|
|
|
332
340
|
|
|
333
341
|
# AbsTaskAggregate is a special case where we have to run multiple tasks and combine the results
|
|
334
342
|
if isinstance(tasks, AbsTaskAggregate):
|
|
335
|
-
|
|
343
|
+
aggregated_task = cast(AbsTaskAggregate, tasks)
|
|
336
344
|
results = evaluate(
|
|
337
345
|
model,
|
|
338
|
-
|
|
346
|
+
aggregated_task.metadata.tasks,
|
|
339
347
|
co2_tracker=co2_tracker,
|
|
340
348
|
raise_error=raise_error,
|
|
341
349
|
encode_kwargs=encode_kwargs,
|
|
@@ -345,17 +353,18 @@ def evaluate(
|
|
|
345
353
|
show_progress_bar=show_progress_bar,
|
|
346
354
|
public_only=public_only,
|
|
347
355
|
)
|
|
348
|
-
|
|
356
|
+
combined_results = aggregated_task.combine_task_results(results.task_results)
|
|
349
357
|
return ModelResult(
|
|
350
358
|
model_name=results.model_name,
|
|
351
359
|
model_revision=results.model_revision,
|
|
352
|
-
task_results=[
|
|
360
|
+
task_results=[combined_results],
|
|
353
361
|
)
|
|
354
362
|
|
|
355
363
|
if isinstance(tasks, AbsTask):
|
|
356
364
|
task = tasks
|
|
357
365
|
else:
|
|
358
|
-
|
|
366
|
+
tasks = cast(Iterable[AbsTask], tasks)
|
|
367
|
+
evaluate_results = []
|
|
359
368
|
exceptions = []
|
|
360
369
|
tasks_tqdm = tqdm(
|
|
361
370
|
tasks,
|
|
@@ -376,23 +385,23 @@ def evaluate(
|
|
|
376
385
|
show_progress_bar=False,
|
|
377
386
|
public_only=public_only,
|
|
378
387
|
)
|
|
379
|
-
|
|
388
|
+
evaluate_results.extend(_res.task_results)
|
|
380
389
|
if _res.exceptions:
|
|
381
390
|
exceptions.extend(_res.exceptions)
|
|
382
391
|
return ModelResult(
|
|
383
392
|
model_name=_res.model_name,
|
|
384
393
|
model_revision=_res.model_revision,
|
|
385
|
-
task_results=
|
|
394
|
+
task_results=evaluate_results,
|
|
386
395
|
exceptions=exceptions,
|
|
387
396
|
)
|
|
388
397
|
|
|
389
398
|
overwrite_strategy = OverwriteStrategy.from_str(overwrite_strategy)
|
|
390
399
|
|
|
391
|
-
existing_results = None
|
|
400
|
+
existing_results: TaskResult | None = None
|
|
392
401
|
if cache and overwrite_strategy != OverwriteStrategy.ALWAYS:
|
|
393
|
-
|
|
394
|
-
if
|
|
395
|
-
existing_results =
|
|
402
|
+
cache_results = cache.load_task_result(task.metadata.name, meta)
|
|
403
|
+
if cache_results:
|
|
404
|
+
existing_results = cache_results
|
|
396
405
|
|
|
397
406
|
if (
|
|
398
407
|
existing_results
|
mteb/filter_tasks.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""This script contains functions that are used to get an overview of the MTEB benchmark."""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
from collections.abc import Sequence
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
5
5
|
from typing import overload
|
|
6
6
|
|
|
7
7
|
from mteb.abstasks import (
|
|
@@ -34,14 +34,14 @@ def _check_is_valid_language(lang: str) -> None:
|
|
|
34
34
|
|
|
35
35
|
@overload
|
|
36
36
|
def filter_tasks(
|
|
37
|
-
tasks:
|
|
37
|
+
tasks: Iterable[AbsTask],
|
|
38
38
|
*,
|
|
39
|
-
languages:
|
|
40
|
-
script:
|
|
41
|
-
domains:
|
|
42
|
-
task_types:
|
|
43
|
-
categories:
|
|
44
|
-
modalities:
|
|
39
|
+
languages: Sequence[str] | None = None,
|
|
40
|
+
script: Sequence[str] | None = None,
|
|
41
|
+
domains: Iterable[TaskDomain] | None = None,
|
|
42
|
+
task_types: Iterable[TaskType] | None = None,
|
|
43
|
+
categories: Iterable[TaskCategory] | None = None,
|
|
44
|
+
modalities: Iterable[Modalities] | None = None,
|
|
45
45
|
exclusive_modality_filter: bool = False,
|
|
46
46
|
exclude_superseded: bool = False,
|
|
47
47
|
exclude_aggregate: bool = False,
|
|
@@ -51,14 +51,14 @@ def filter_tasks(
|
|
|
51
51
|
|
|
52
52
|
@overload
|
|
53
53
|
def filter_tasks(
|
|
54
|
-
tasks:
|
|
54
|
+
tasks: Iterable[type[AbsTask]],
|
|
55
55
|
*,
|
|
56
|
-
languages:
|
|
57
|
-
script:
|
|
58
|
-
domains:
|
|
59
|
-
task_types:
|
|
60
|
-
categories:
|
|
61
|
-
modalities:
|
|
56
|
+
languages: Sequence[str] | None = None,
|
|
57
|
+
script: Sequence[str] | None = None,
|
|
58
|
+
domains: Iterable[TaskDomain] | None = None,
|
|
59
|
+
task_types: Iterable[TaskType] | None = None,
|
|
60
|
+
categories: Iterable[TaskCategory] | None = None,
|
|
61
|
+
modalities: Iterable[Modalities] | None = None,
|
|
62
62
|
exclusive_modality_filter: bool = False,
|
|
63
63
|
exclude_superseded: bool = False,
|
|
64
64
|
exclude_aggregate: bool = False,
|
|
@@ -67,14 +67,14 @@ def filter_tasks(
|
|
|
67
67
|
|
|
68
68
|
|
|
69
69
|
def filter_tasks(
|
|
70
|
-
tasks:
|
|
70
|
+
tasks: Iterable[AbsTask] | Iterable[type[AbsTask]],
|
|
71
71
|
*,
|
|
72
|
-
languages:
|
|
73
|
-
script:
|
|
74
|
-
domains:
|
|
75
|
-
task_types:
|
|
76
|
-
categories:
|
|
77
|
-
modalities:
|
|
72
|
+
languages: Sequence[str] | None = None,
|
|
73
|
+
script: Sequence[str] | None = None,
|
|
74
|
+
domains: Iterable[TaskDomain] | None = None,
|
|
75
|
+
task_types: Iterable[TaskType] | None = None,
|
|
76
|
+
categories: Iterable[TaskCategory] | None = None,
|
|
77
|
+
modalities: Iterable[Modalities] | None = None,
|
|
78
78
|
exclusive_modality_filter: bool = False,
|
|
79
79
|
exclude_superseded: bool = False,
|
|
80
80
|
exclude_aggregate: bool = False,
|
|
@@ -92,7 +92,6 @@ def filter_tasks(
|
|
|
92
92
|
task_types: A string specifying the type of task e.g. "Classification" or "Retrieval". If None, all tasks are included.
|
|
93
93
|
categories: A list of task categories these include "t2t" (text to text), "t2i" (text to image). See TaskMetadata for the full list.
|
|
94
94
|
exclude_superseded: A boolean flag to exclude datasets which are superseded by another.
|
|
95
|
-
eval_splits: A list of evaluation splits to include. If None, all splits are included.
|
|
96
95
|
modalities: A list of modalities to include. If None, all modalities are included.
|
|
97
96
|
exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
|
|
98
97
|
task's modalities and ALL task modalities are in filter modalities (exact match).
|
|
@@ -113,12 +112,12 @@ def filter_tasks(
|
|
|
113
112
|
"""
|
|
114
113
|
langs_to_keep = None
|
|
115
114
|
if languages:
|
|
116
|
-
[_check_is_valid_language(lang) for lang in languages]
|
|
115
|
+
[_check_is_valid_language(lang) for lang in languages] # type: ignore[func-returns-value]
|
|
117
116
|
langs_to_keep = set(languages)
|
|
118
117
|
|
|
119
118
|
script_to_keep = None
|
|
120
119
|
if script:
|
|
121
|
-
[_check_is_valid_script(s) for s in script]
|
|
120
|
+
[_check_is_valid_script(s) for s in script] # type: ignore[func-returns-value]
|
|
122
121
|
script_to_keep = set(script)
|
|
123
122
|
|
|
124
123
|
domains_to_keep = None
|
|
@@ -178,4 +177,4 @@ def filter_tasks(
|
|
|
178
177
|
|
|
179
178
|
_tasks.append(t)
|
|
180
179
|
|
|
181
|
-
return _tasks
|
|
180
|
+
return _tasks # type: ignore[return-value] # type checker cannot infer the overload return type
|
mteb/get_tasks.py
CHANGED
|
@@ -2,8 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
import difflib
|
|
4
4
|
import logging
|
|
5
|
+
import warnings
|
|
5
6
|
from collections import Counter, defaultdict
|
|
6
|
-
from collections.abc import Sequence
|
|
7
|
+
from collections.abc import Iterable, Sequence
|
|
7
8
|
from typing import Any
|
|
8
9
|
|
|
9
10
|
import pandas as pd
|
|
@@ -22,12 +23,11 @@ logger = logging.getLogger(__name__)
|
|
|
22
23
|
def _gather_tasks() -> tuple[type[AbsTask], ...]:
|
|
23
24
|
import mteb.tasks as tasks
|
|
24
25
|
|
|
25
|
-
|
|
26
|
+
return tuple(
|
|
26
27
|
t
|
|
27
28
|
for t in tasks.__dict__.values()
|
|
28
29
|
if isinstance(t, type) and issubclass(t, AbsTask)
|
|
29
|
-
|
|
30
|
-
return tuple(tasks)
|
|
30
|
+
)
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
def _create_name_to_task_mapping(
|
|
@@ -43,7 +43,7 @@ def _create_name_to_task_mapping(
|
|
|
43
43
|
return metadata_names
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
def _create_similar_tasks(tasks:
|
|
46
|
+
def _create_similar_tasks(tasks: Iterable[type[AbsTask]]) -> dict[str, list[str]]:
|
|
47
47
|
"""Create a dictionary of similar tasks.
|
|
48
48
|
|
|
49
49
|
Returns:
|
|
@@ -194,9 +194,8 @@ class MTEBTasks(tuple[AbsTask]):
|
|
|
194
194
|
string with a LaTeX table.
|
|
195
195
|
"""
|
|
196
196
|
if include_citation_in_name and "name" in properties:
|
|
197
|
-
properties
|
|
198
|
-
df =
|
|
199
|
-
df["name"] = df["name"] + " " + df["intext_citation"]
|
|
197
|
+
df = self.to_dataframe(tuple(properties) + ("intext_citation",))
|
|
198
|
+
df["name"] = df["name"] + " " + df["intext_citation"] # type: ignore[operator]
|
|
200
199
|
df = df.drop(columns=["intext_citation"])
|
|
201
200
|
else:
|
|
202
201
|
df = self.to_dataframe(properties)
|
|
@@ -221,17 +220,17 @@ class MTEBTasks(tuple[AbsTask]):
|
|
|
221
220
|
|
|
222
221
|
|
|
223
222
|
def get_tasks(
|
|
224
|
-
tasks:
|
|
223
|
+
tasks: Sequence[str] | None = None,
|
|
225
224
|
*,
|
|
226
|
-
languages:
|
|
227
|
-
script:
|
|
228
|
-
domains:
|
|
229
|
-
task_types:
|
|
230
|
-
categories:
|
|
225
|
+
languages: Sequence[str] | None = None,
|
|
226
|
+
script: Sequence[str] | None = None,
|
|
227
|
+
domains: Sequence[TaskDomain] | None = None,
|
|
228
|
+
task_types: Sequence[TaskType] | None = None,
|
|
229
|
+
categories: Sequence[TaskCategory] | None = None,
|
|
231
230
|
exclude_superseded: bool = True,
|
|
232
|
-
eval_splits:
|
|
231
|
+
eval_splits: Sequence[str] | None = None,
|
|
233
232
|
exclusive_language_filter: bool = False,
|
|
234
|
-
modalities:
|
|
233
|
+
modalities: Sequence[Modalities] | None = None,
|
|
235
234
|
exclusive_modality_filter: bool = False,
|
|
236
235
|
exclude_aggregate: bool = False,
|
|
237
236
|
exclude_private: bool = True,
|
|
@@ -287,7 +286,7 @@ def get_tasks(
|
|
|
287
286
|
]
|
|
288
287
|
return MTEBTasks(_tasks)
|
|
289
288
|
|
|
290
|
-
|
|
289
|
+
tasks_: Sequence[type[AbsTask]] = filter_tasks(
|
|
291
290
|
TASK_LIST,
|
|
292
291
|
languages=languages,
|
|
293
292
|
script=script,
|
|
@@ -300,12 +299,12 @@ def get_tasks(
|
|
|
300
299
|
exclude_aggregate=exclude_aggregate,
|
|
301
300
|
exclude_private=exclude_private,
|
|
302
301
|
)
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
302
|
+
return MTEBTasks(
|
|
303
|
+
[
|
|
304
|
+
cls().filter_languages(languages, script).filter_eval_splits(eval_splits)
|
|
305
|
+
for cls in tasks_
|
|
306
|
+
]
|
|
307
|
+
)
|
|
309
308
|
|
|
310
309
|
|
|
311
310
|
_TASK_RENAMES = {"PersianTextTone": "SynPerTextToneClassification"}
|
|
@@ -313,10 +312,10 @@ _TASK_RENAMES = {"PersianTextTone": "SynPerTextToneClassification"}
|
|
|
313
312
|
|
|
314
313
|
def get_task(
|
|
315
314
|
task_name: str,
|
|
316
|
-
languages:
|
|
317
|
-
script:
|
|
318
|
-
eval_splits:
|
|
319
|
-
hf_subsets:
|
|
315
|
+
languages: Sequence[str] | None = None,
|
|
316
|
+
script: Sequence[str] | None = None,
|
|
317
|
+
eval_splits: Sequence[str] | None = None,
|
|
318
|
+
hf_subsets: Sequence[str] | None = None,
|
|
320
319
|
exclusive_language_filter: bool = False,
|
|
321
320
|
) -> AbsTask:
|
|
322
321
|
"""Get a task by name.
|
|
@@ -340,9 +339,9 @@ def get_task(
|
|
|
340
339
|
"""
|
|
341
340
|
if task_name in _TASK_RENAMES:
|
|
342
341
|
_task_name = _TASK_RENAMES[task_name]
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
)
|
|
342
|
+
msg = f"The task with the given name '{task_name}' has been renamed to '{_task_name}'. To prevent this warning use the new name."
|
|
343
|
+
logger.warning(msg)
|
|
344
|
+
warnings.warn(msg)
|
|
346
345
|
|
|
347
346
|
if task_name not in _TASKS_REGISTRY:
|
|
348
347
|
close_matches = difflib.get_close_matches(task_name, _TASKS_REGISTRY.keys())
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
from collections.abc import Iterable
|
|
1
|
+
from collections.abc import Iterable, Sequence
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
|
|
4
4
|
from typing_extensions import Self
|
|
5
5
|
|
|
6
|
-
from mteb.languages import check_language_code
|
|
6
|
+
from mteb.languages.check_language_code import check_language_code
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
@dataclass
|
|
@@ -25,7 +25,9 @@ class LanguageScripts:
|
|
|
25
25
|
|
|
26
26
|
@classmethod
|
|
27
27
|
def from_languages_and_scripts(
|
|
28
|
-
cls,
|
|
28
|
+
cls,
|
|
29
|
+
languages: Sequence[str] | None = None,
|
|
30
|
+
scripts: Sequence[str] | None = None,
|
|
29
31
|
) -> Self:
|
|
30
32
|
"""Create a LanguageScripts object from lists of languages and scripts.
|
|
31
33
|
|
mteb/leaderboard/app.py
CHANGED
|
@@ -169,7 +169,7 @@ def _update_task_info(task_names: str) -> gr.DataFrame:
|
|
|
169
169
|
df = df.drop(columns="reference")
|
|
170
170
|
return gr.DataFrame(
|
|
171
171
|
df,
|
|
172
|
-
datatype=["markdown"] + ["str"] * (len(df.columns) - 1),
|
|
172
|
+
datatype=["markdown"] + ["str"] * (len(df.columns) - 1),
|
|
173
173
|
buttons=["copy", "fullscreen"],
|
|
174
174
|
show_search="filter",
|
|
175
175
|
)
|
mteb/load_results.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import sys
|
|
4
|
-
from collections.abc import Sequence
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
7
|
from mteb.abstasks.abstask import AbsTask
|
|
@@ -45,8 +45,8 @@ def _model_name_and_revision(
|
|
|
45
45
|
def load_results(
|
|
46
46
|
results_repo: str = "https://github.com/embeddings-benchmark/results",
|
|
47
47
|
download_latest: bool = True,
|
|
48
|
-
models:
|
|
49
|
-
tasks:
|
|
48
|
+
models: Iterable[ModelMeta] | Sequence[str] | None = None,
|
|
49
|
+
tasks: Iterable[AbsTask] | Sequence[str] | None = None,
|
|
50
50
|
validate_and_filter: bool = True,
|
|
51
51
|
require_model_meta: bool = True,
|
|
52
52
|
only_main_score: bool = False,
|
|
@@ -83,21 +83,21 @@ def load_results(
|
|
|
83
83
|
|
|
84
84
|
if models is not None:
|
|
85
85
|
models_to_keep = {}
|
|
86
|
-
for
|
|
87
|
-
if isinstance(
|
|
88
|
-
models_to_keep[
|
|
86
|
+
for model in models:
|
|
87
|
+
if isinstance(model, ModelMeta):
|
|
88
|
+
models_to_keep[model.name] = model.revision
|
|
89
89
|
else:
|
|
90
|
-
models_to_keep[
|
|
90
|
+
models_to_keep[model] = None
|
|
91
91
|
else:
|
|
92
92
|
models_to_keep = None
|
|
93
93
|
|
|
94
|
-
task_names = {}
|
|
94
|
+
task_names: dict[str, AbsTask | None] = {}
|
|
95
95
|
if tasks is not None:
|
|
96
|
-
for
|
|
97
|
-
if isinstance(
|
|
98
|
-
task_names[
|
|
96
|
+
for task_ in tasks:
|
|
97
|
+
if isinstance(task_, AbsTask):
|
|
98
|
+
task_names[task_.metadata.name] = task_
|
|
99
99
|
else:
|
|
100
|
-
task_names[
|
|
100
|
+
task_names[task_] = None
|
|
101
101
|
|
|
102
102
|
model_results = []
|
|
103
103
|
for model_path in model_paths:
|
mteb/models/abs_encoder.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import warnings
|
|
2
3
|
from abc import ABC, abstractmethod
|
|
3
4
|
from collections.abc import Callable, Sequence
|
|
4
5
|
from typing import Any, Literal, cast, get_args, overload
|
|
@@ -43,7 +44,7 @@ class AbsEncoder(ABC):
|
|
|
43
44
|
model: Any
|
|
44
45
|
mteb_model_meta: ModelMeta | None = None
|
|
45
46
|
model_prompts: dict[str, str] | None = None
|
|
46
|
-
instruction_template: str | Callable[[str, PromptType], str] | None = None
|
|
47
|
+
instruction_template: str | Callable[[str, PromptType | None], str] | None = None
|
|
47
48
|
prompts_dict: dict[str, str] | None = None
|
|
48
49
|
|
|
49
50
|
def get_prompt_name(
|
|
@@ -110,7 +111,7 @@ class AbsEncoder(ABC):
|
|
|
110
111
|
if not self.model_prompts:
|
|
111
112
|
return None
|
|
112
113
|
prompt_name = self.get_prompt_name(task_metadata, prompt_type)
|
|
113
|
-
return self.model_prompts.get(prompt_name)
|
|
114
|
+
return self.model_prompts.get(prompt_name) if prompt_name else None
|
|
114
115
|
|
|
115
116
|
@staticmethod
|
|
116
117
|
@overload
|
|
@@ -187,6 +188,7 @@ class AbsEncoder(ABC):
|
|
|
187
188
|
except KeyError:
|
|
188
189
|
msg = f"Task name {task_name} is not valid. {valid_keys_msg}"
|
|
189
190
|
logger.warning(msg)
|
|
191
|
+
warnings.warn(msg)
|
|
190
192
|
invalid_task_messages.add(msg)
|
|
191
193
|
invalid_keys.add(task_key)
|
|
192
194
|
|
|
@@ -232,9 +234,9 @@ class AbsEncoder(ABC):
|
|
|
232
234
|
if isinstance(prompt, dict) and prompt_type:
|
|
233
235
|
if prompt.get(prompt_type.value):
|
|
234
236
|
return prompt[prompt_type.value]
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
)
|
|
237
|
+
msg = f"Prompt type '{prompt_type}' not found in task metadata for task '{task_metadata.name}'."
|
|
238
|
+
logger.warning(msg)
|
|
239
|
+
warnings.warn(msg)
|
|
238
240
|
return ""
|
|
239
241
|
|
|
240
242
|
if prompt:
|
|
@@ -5,8 +5,6 @@ from typing import Any, Protocol, runtime_checkable
|
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
|
|
8
|
-
from mteb.types import BatchedInput
|
|
9
|
-
|
|
10
8
|
|
|
11
9
|
@runtime_checkable
|
|
12
10
|
class CacheBackendProtocol(Protocol):
|
|
@@ -26,7 +24,7 @@ class CacheBackendProtocol(Protocol):
|
|
|
26
24
|
**kwargs: Additional backend-specific arguments.
|
|
27
25
|
"""
|
|
28
26
|
|
|
29
|
-
def add(self, item: list[
|
|
27
|
+
def add(self, item: list[dict[str, Any]], vectors: np.ndarray) -> None:
|
|
30
28
|
"""Add a vector to the cache.
|
|
31
29
|
|
|
32
30
|
Args:
|
|
@@ -34,7 +32,7 @@ class CacheBackendProtocol(Protocol):
|
|
|
34
32
|
vectors: Embedding vector of shape (dim,) or (1, dim).
|
|
35
33
|
"""
|
|
36
34
|
|
|
37
|
-
def get_vector(self, item:
|
|
35
|
+
def get_vector(self, item: dict[str, Any]) -> np.ndarray | None:
|
|
38
36
|
"""Retrieve the cached vector for the given item.
|
|
39
37
|
|
|
40
38
|
Args:
|
|
@@ -53,5 +51,5 @@ class CacheBackendProtocol(Protocol):
|
|
|
53
51
|
def close(self) -> None:
|
|
54
52
|
"""Release resources or flush data."""
|
|
55
53
|
|
|
56
|
-
def __contains__(self, item:
|
|
54
|
+
def __contains__(self, item: dict[str, Any]) -> bool:
|
|
57
55
|
"""Check whether the cache contains an item."""
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import hashlib
|
|
2
|
+
from collections.abc import Mapping
|
|
3
|
+
from typing import Any
|
|
2
4
|
|
|
3
|
-
from mteb.types import BatchedInput
|
|
4
5
|
|
|
5
|
-
|
|
6
|
-
def _hash_item(item: BatchedInput) -> str:
|
|
6
|
+
def _hash_item(item: Mapping[str, Any]) -> str:
|
|
7
7
|
item_hash = ""
|
|
8
8
|
if "text" in item:
|
|
9
|
-
|
|
9
|
+
item_text: str = item["text"]
|
|
10
|
+
item_hash = hashlib.sha256(item_text.encode()).hexdigest()
|
|
10
11
|
|
|
11
12
|
if "image" in item:
|
|
12
13
|
from PIL import Image
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import warnings
|
|
3
4
|
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
4
6
|
|
|
5
7
|
import numpy as np
|
|
6
8
|
|
|
@@ -36,7 +38,7 @@ class FaissCache:
|
|
|
36
38
|
logger.info(f"Initialized FAISS VectorCacheMap in {self.directory}")
|
|
37
39
|
self.load()
|
|
38
40
|
|
|
39
|
-
def add(self, items: list[
|
|
41
|
+
def add(self, items: list[dict[str, Any]], vectors: np.ndarray) -> None:
|
|
40
42
|
"""Add vector to FAISS index."""
|
|
41
43
|
import faiss
|
|
42
44
|
|
|
@@ -71,7 +73,9 @@ class FaissCache:
|
|
|
71
73
|
try:
|
|
72
74
|
return self.index.reconstruct(idx)
|
|
73
75
|
except Exception:
|
|
74
|
-
|
|
76
|
+
msg = f"Vector id {idx} missing for hash {item_hash}"
|
|
77
|
+
logger.warning(msg)
|
|
78
|
+
warnings.warn(msg)
|
|
75
79
|
return None
|
|
76
80
|
|
|
77
81
|
def save(self) -> None:
|