mteb 2.5.3__py3-none-any.whl → 2.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/_create_dataloaders.py +10 -15
- mteb/_evaluators/any_sts_evaluator.py +1 -4
- mteb/_evaluators/evaluator.py +2 -1
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
- mteb/_evaluators/pair_classification_evaluator.py +3 -1
- mteb/_evaluators/retrieval_metrics.py +17 -16
- mteb/_evaluators/sklearn_evaluator.py +9 -8
- mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
- mteb/_evaluators/text/summarization_evaluator.py +20 -16
- mteb/abstasks/_data_filter/filters.py +1 -1
- mteb/abstasks/_data_filter/task_pipelines.py +3 -0
- mteb/abstasks/_statistics_calculation.py +18 -10
- mteb/abstasks/_stratification.py +18 -18
- mteb/abstasks/abstask.py +27 -21
- mteb/abstasks/aggregate_task_metadata.py +1 -9
- mteb/abstasks/aggregated_task.py +3 -16
- mteb/abstasks/classification.py +10 -4
- mteb/abstasks/clustering.py +18 -14
- mteb/abstasks/clustering_legacy.py +8 -8
- mteb/abstasks/image/image_text_pair_classification.py +5 -3
- mteb/abstasks/multilabel_classification.py +20 -16
- mteb/abstasks/pair_classification.py +18 -9
- mteb/abstasks/regression.py +3 -3
- mteb/abstasks/retrieval.py +12 -9
- mteb/abstasks/sts.py +6 -3
- mteb/abstasks/task_metadata.py +20 -16
- mteb/abstasks/text/bitext_mining.py +36 -25
- mteb/abstasks/text/reranking.py +7 -5
- mteb/abstasks/text/summarization.py +8 -3
- mteb/abstasks/zeroshot_classification.py +5 -2
- mteb/benchmarks/benchmark.py +2 -2
- mteb/cache.py +20 -18
- mteb/cli/_display_tasks.py +2 -2
- mteb/cli/build_cli.py +5 -5
- mteb/cli/generate_model_card.py +6 -4
- mteb/deprecated_evaluator.py +56 -43
- mteb/evaluate.py +35 -29
- mteb/filter_tasks.py +25 -26
- mteb/get_tasks.py +25 -27
- mteb/languages/language_scripts.py +5 -3
- mteb/leaderboard/app.py +1 -1
- mteb/load_results.py +12 -12
- mteb/models/abs_encoder.py +2 -2
- mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +2 -1
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +30 -13
- mteb/models/cache_wrappers/cache_wrapper.py +2 -2
- mteb/models/get_model_meta.py +8 -1
- mteb/models/instruct_wrapper.py +11 -5
- mteb/models/model_implementations/andersborges.py +2 -2
- mteb/models/model_implementations/blip_models.py +8 -8
- mteb/models/model_implementations/bm25.py +1 -1
- mteb/models/model_implementations/clip_models.py +3 -3
- mteb/models/model_implementations/cohere_models.py +1 -1
- mteb/models/model_implementations/cohere_v.py +2 -2
- mteb/models/model_implementations/dino_models.py +23 -23
- mteb/models/model_implementations/emillykkejensen_models.py +3 -3
- mteb/models/model_implementations/jina_clip.py +1 -1
- mteb/models/model_implementations/jina_models.py +1 -1
- mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
- mteb/models/model_implementations/llm2clip_models.py +3 -3
- mteb/models/model_implementations/moco_models.py +2 -2
- mteb/models/model_implementations/model2vec_models.py +1 -1
- mteb/models/model_implementations/nomic_models.py +8 -8
- mteb/models/model_implementations/openclip_models.py +7 -7
- mteb/models/model_implementations/random_baseline.py +3 -3
- mteb/models/model_implementations/rasgaard_models.py +1 -1
- mteb/models/model_implementations/repllama_models.py +2 -2
- mteb/models/model_implementations/rerankers_custom.py +3 -3
- mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
- mteb/models/model_implementations/siglip_models.py +10 -10
- mteb/models/model_implementations/vlm2vec_models.py +1 -1
- mteb/models/model_implementations/voyage_v.py +4 -4
- mteb/models/model_meta.py +11 -12
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +5 -5
- mteb/models/search_wrappers.py +22 -10
- mteb/models/sentence_transformer_wrapper.py +9 -4
- mteb/py.typed +0 -0
- mteb/results/benchmark_results.py +25 -19
- mteb/results/model_result.py +49 -21
- mteb/results/task_result.py +45 -51
- mteb/similarity_functions.py +11 -7
- mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
- mteb/tasks/classification/est/estonian_valence.py +1 -1
- mteb/tasks/classification/multilingual/scala_classification.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +12 -12
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
- mteb/tasks/retrieval/nob/norquad.py +2 -2
- mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
- mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
- mteb/types/_result.py +2 -1
- mteb/types/statistics.py +9 -3
- {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/METADATA +1 -1
- {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/RECORD +102 -101
- {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/WHEEL +0 -0
- {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/entry_points.txt +0 -0
- {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/top_level.txt +0 -0
mteb/cli/generate_model_card.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import warnings
|
|
3
|
+
from collections.abc import Sequence
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
|
|
5
6
|
from huggingface_hub import ModelCard, ModelCardData, repo_exists
|
|
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|
|
13
14
|
|
|
14
15
|
def generate_model_card(
|
|
15
16
|
model_name: str,
|
|
16
|
-
tasks:
|
|
17
|
+
tasks: Sequence[AbsTask] | None = None,
|
|
17
18
|
existing_model_card_id_or_path: str | Path | None = None,
|
|
18
19
|
results_cache: ResultCache = ResultCache(),
|
|
19
20
|
output_path: Path = Path("model_card.md"),
|
|
@@ -48,8 +49,8 @@ def generate_model_card(
|
|
|
48
49
|
for task_result in models_results.task_results:
|
|
49
50
|
eval_results.extend(task_result.get_hf_eval_results())
|
|
50
51
|
|
|
51
|
-
existing_model_card_data = (
|
|
52
|
-
existing_model_card.data if existing_model_card else ModelCardData()
|
|
52
|
+
existing_model_card_data: ModelCardData = (
|
|
53
|
+
existing_model_card.data if existing_model_card else ModelCardData() # type: ignore[assignment]
|
|
53
54
|
)
|
|
54
55
|
|
|
55
56
|
if existing_model_card_data.eval_results is None:
|
|
@@ -89,7 +90,8 @@ def generate_model_card(
|
|
|
89
90
|
benchmark_results, existing_model_card
|
|
90
91
|
)
|
|
91
92
|
|
|
92
|
-
if push_to_hub:
|
|
93
|
+
if push_to_hub and existing_model_card_id_or_path:
|
|
94
|
+
existing_model_card_id_or_path = str(existing_model_card_id_or_path)
|
|
93
95
|
if repo_exists(existing_model_card_id_or_path):
|
|
94
96
|
existing_model_card.push_to_hub(existing_model_card_id_or_path, token=token)
|
|
95
97
|
else:
|
mteb/deprecated_evaluator.py
CHANGED
|
@@ -6,23 +6,23 @@ import os
|
|
|
6
6
|
import sys
|
|
7
7
|
import traceback
|
|
8
8
|
import warnings
|
|
9
|
-
from collections.abc import Iterable
|
|
9
|
+
from collections.abc import Iterable, Sequence
|
|
10
10
|
from copy import deepcopy
|
|
11
11
|
from datetime import datetime
|
|
12
12
|
from itertools import chain
|
|
13
13
|
from pathlib import Path
|
|
14
14
|
from time import time
|
|
15
|
-
from typing import TYPE_CHECKING, Any
|
|
15
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
16
16
|
|
|
17
17
|
import datasets
|
|
18
18
|
|
|
19
19
|
import mteb
|
|
20
20
|
from mteb.abstasks import AbsTask
|
|
21
|
+
from mteb.abstasks.aggregated_task import AbsTaskAggregate
|
|
21
22
|
from mteb.abstasks.task_metadata import TaskCategory, TaskType
|
|
22
23
|
from mteb.benchmarks import Benchmark
|
|
23
24
|
from mteb.models import (
|
|
24
25
|
CrossEncoderWrapper,
|
|
25
|
-
EncoderProtocol,
|
|
26
26
|
ModelMeta,
|
|
27
27
|
MTEBModels,
|
|
28
28
|
SentenceTransformerEncoderWrapper,
|
|
@@ -53,7 +53,7 @@ class MTEB:
|
|
|
53
53
|
)
|
|
54
54
|
def __init__(
|
|
55
55
|
self,
|
|
56
|
-
tasks: Iterable[AbsTask | Benchmark],
|
|
56
|
+
tasks: Iterable[AbsTask] | Iterable[Benchmark],
|
|
57
57
|
*,
|
|
58
58
|
err_logs_path: str = "error_logs.txt",
|
|
59
59
|
) -> None:
|
|
@@ -64,15 +64,14 @@ class MTEB:
|
|
|
64
64
|
`mteb.get_tasks(["task1","task2"]) or `mteb.get_benchmark("MTEB(eng, classic)").
|
|
65
65
|
err_logs_path: Path to save error logs.
|
|
66
66
|
"""
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
self.tasks = list(tasks)
|
|
70
|
-
if len(self.tasks) > 0 and isinstance(self.tasks[0], Benchmark):
|
|
67
|
+
if isinstance(next(iter(tasks)), Benchmark):
|
|
71
68
|
self.benchmarks = tasks
|
|
72
|
-
self.tasks = list(chain.from_iterable(
|
|
69
|
+
self.tasks = list(chain.from_iterable(cast(Iterable[Benchmark], tasks)))
|
|
70
|
+
elif isinstance(next(iter(tasks)), AbsTask):
|
|
71
|
+
self.tasks = list(cast(Iterable[AbsTask], tasks))
|
|
73
72
|
|
|
74
73
|
self.err_logs_path = Path(err_logs_path)
|
|
75
|
-
self.
|
|
74
|
+
self._last_evaluated_splits: dict[str, list[str]] = {}
|
|
76
75
|
|
|
77
76
|
@property
|
|
78
77
|
def available_tasks(self) -> list[str]:
|
|
@@ -85,7 +84,7 @@ class MTEB:
|
|
|
85
84
|
return sorted({x.metadata.type for x in self.tasks})
|
|
86
85
|
|
|
87
86
|
@property
|
|
88
|
-
def available_task_categories(self) -> set[TaskCategory]:
|
|
87
|
+
def available_task_categories(self) -> set[TaskCategory | None]:
|
|
89
88
|
"""Set of available task categories."""
|
|
90
89
|
return {x.metadata.category for x in self.tasks}
|
|
91
90
|
|
|
@@ -232,13 +231,14 @@ class MTEB:
|
|
|
232
231
|
merged_kg_co2_emissions = None
|
|
233
232
|
if existing_kg_co2_emissions and new_kg_co2_emissions:
|
|
234
233
|
merged_kg_co2_emissions = existing_kg_co2_emissions + new_kg_co2_emissions
|
|
234
|
+
existing_evaluation_time = existing_results.evaluation_time or 0
|
|
235
|
+
new_evaluation_time = new_results.evaluation_time or 0
|
|
235
236
|
merged_results = TaskResult(
|
|
236
237
|
dataset_revision=new_results.dataset_revision,
|
|
237
238
|
task_name=new_results.task_name,
|
|
238
239
|
mteb_version=new_results.mteb_version,
|
|
239
240
|
scores=merged_scores,
|
|
240
|
-
evaluation_time=
|
|
241
|
-
+ new_results.evaluation_time,
|
|
241
|
+
evaluation_time=existing_evaluation_time + new_evaluation_time,
|
|
242
242
|
kg_co2_emissions=merged_kg_co2_emissions,
|
|
243
243
|
)
|
|
244
244
|
|
|
@@ -307,13 +307,16 @@ class MTEB:
|
|
|
307
307
|
elif verbosity == 3:
|
|
308
308
|
datasets.logging.set_verbosity(logging.DEBUG)
|
|
309
309
|
|
|
310
|
-
|
|
311
|
-
output_path = self._create_output_folder(meta, output_folder)
|
|
312
|
-
|
|
310
|
+
mteb_model: MTEBModels
|
|
313
311
|
if isinstance(model, SentenceTransformer):
|
|
314
|
-
|
|
312
|
+
mteb_model = SentenceTransformerEncoderWrapper(model)
|
|
315
313
|
elif isinstance(model, CrossEncoder):
|
|
316
|
-
|
|
314
|
+
mteb_model = CrossEncoderWrapper(model)
|
|
315
|
+
else:
|
|
316
|
+
mteb_model = cast(MTEBModels, model)
|
|
317
|
+
|
|
318
|
+
meta = self.create_model_meta(mteb_model)
|
|
319
|
+
output_path = self._create_output_folder(meta, output_folder)
|
|
317
320
|
|
|
318
321
|
# Disable co2_tracker for API models
|
|
319
322
|
if "API" in meta.framework:
|
|
@@ -334,7 +337,7 @@ class MTEB:
|
|
|
334
337
|
) # save them in case we re-use the object (e.g. for reranking)
|
|
335
338
|
|
|
336
339
|
# To evaluate missing splits, we keep track of the task name and the corresponding splits.
|
|
337
|
-
self.
|
|
340
|
+
self._last_evaluated_splits = {}
|
|
338
341
|
|
|
339
342
|
while len(self.tasks) > 0:
|
|
340
343
|
task = self.tasks[0]
|
|
@@ -343,9 +346,10 @@ class MTEB:
|
|
|
343
346
|
)
|
|
344
347
|
|
|
345
348
|
if task.is_aggregate:
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
+
aggregated_task = cast(AbsTaskAggregate, task)
|
|
350
|
+
self_ = MTEB(tasks=aggregated_task.metadata.tasks)
|
|
351
|
+
aggregated_task_results = self_.run(
|
|
352
|
+
mteb_model,
|
|
349
353
|
verbosity=verbosity - 1,
|
|
350
354
|
output_folder=output_folder,
|
|
351
355
|
eval_splits=eval_splits,
|
|
@@ -356,12 +360,15 @@ class MTEB:
|
|
|
356
360
|
encode_kwargs=encode_kwargs,
|
|
357
361
|
**kwargs,
|
|
358
362
|
)
|
|
359
|
-
new_results =
|
|
363
|
+
new_results = aggregated_task.combine_task_results(
|
|
364
|
+
aggregated_task_results
|
|
365
|
+
)
|
|
360
366
|
evaluation_results.append(new_results)
|
|
361
367
|
|
|
362
368
|
if output_path:
|
|
363
|
-
|
|
364
|
-
|
|
369
|
+
new_results.to_disk(
|
|
370
|
+
output_path / f"{aggregated_task.metadata.name}.json"
|
|
371
|
+
)
|
|
365
372
|
del self.tasks[0]
|
|
366
373
|
continue
|
|
367
374
|
|
|
@@ -383,7 +390,7 @@ class MTEB:
|
|
|
383
390
|
task_subsets = task.hf_subsets
|
|
384
391
|
|
|
385
392
|
existing_results = None
|
|
386
|
-
save_path = None
|
|
393
|
+
save_path: Path | None = None
|
|
387
394
|
final_splits_to_run = task_eval_splits
|
|
388
395
|
missing_evaluations = self._get_missing_evaluations(
|
|
389
396
|
existing_results,
|
|
@@ -433,7 +440,7 @@ class MTEB:
|
|
|
433
440
|
logger.info(
|
|
434
441
|
f"No splits to evaluate for {task.metadata.name}. Skipping evaluation."
|
|
435
442
|
)
|
|
436
|
-
self.
|
|
443
|
+
self._last_evaluated_splits[task.metadata.name] = []
|
|
437
444
|
del self.tasks[0]
|
|
438
445
|
continue
|
|
439
446
|
|
|
@@ -441,11 +448,11 @@ class MTEB:
|
|
|
441
448
|
task.check_if_dataset_is_superseded()
|
|
442
449
|
task.load_data()
|
|
443
450
|
|
|
444
|
-
task_results = {}
|
|
451
|
+
task_results: dict[str, dict[str, dict[str, Any]]] = {}
|
|
445
452
|
evaluation_time = 0
|
|
446
453
|
kg_co2_emissions: int | None = 0 if co2_tracker else None
|
|
447
454
|
|
|
448
|
-
self.
|
|
455
|
+
self._last_evaluated_splits[task.metadata.name] = []
|
|
449
456
|
|
|
450
457
|
for split in final_splits_to_run:
|
|
451
458
|
info = missing_evaluations[split]
|
|
@@ -466,7 +473,9 @@ class MTEB:
|
|
|
466
473
|
|
|
467
474
|
if co2_tracker:
|
|
468
475
|
try:
|
|
469
|
-
from codecarbon import
|
|
476
|
+
from codecarbon import ( # type: ignore[import-untyped]
|
|
477
|
+
EmissionsTracker,
|
|
478
|
+
)
|
|
470
479
|
except ImportError:
|
|
471
480
|
raise ImportError(
|
|
472
481
|
"codecarbon is not installed. Please install it using `pip install 'mteb[codecarbon]'` to track CO₂ emissions."
|
|
@@ -482,7 +491,7 @@ class MTEB:
|
|
|
482
491
|
) as tracker:
|
|
483
492
|
results, tick, tock = self._run_eval(
|
|
484
493
|
task,
|
|
485
|
-
|
|
494
|
+
mteb_model,
|
|
486
495
|
split,
|
|
487
496
|
encode_kwargs=encode_kwargs,
|
|
488
497
|
subsets_to_run=subsets_to_run,
|
|
@@ -495,7 +504,7 @@ class MTEB:
|
|
|
495
504
|
else:
|
|
496
505
|
results, tick, tock = self._run_eval(
|
|
497
506
|
task,
|
|
498
|
-
|
|
507
|
+
mteb_model,
|
|
499
508
|
split,
|
|
500
509
|
subsets_to_run=subsets_to_run,
|
|
501
510
|
encode_kwargs=encode_kwargs,
|
|
@@ -511,25 +520,25 @@ class MTEB:
|
|
|
511
520
|
if verbosity >= 1:
|
|
512
521
|
logger.info(f"Scores: {task_results[split]}")
|
|
513
522
|
|
|
514
|
-
self.
|
|
523
|
+
self._last_evaluated_splits[task.metadata.name].append(split)
|
|
515
524
|
|
|
516
525
|
# Create new TaskResult
|
|
517
526
|
new_results = TaskResult.from_task_results(
|
|
518
527
|
task,
|
|
519
|
-
task_results,
|
|
528
|
+
task_results, # type: ignore[arg-type]
|
|
520
529
|
evaluation_time=evaluation_time,
|
|
521
530
|
kg_co2_emissions=kg_co2_emissions,
|
|
522
531
|
)
|
|
523
532
|
|
|
524
533
|
# Merge with existing if needed
|
|
525
|
-
if output_path and save_path.exists():
|
|
534
|
+
if output_path and save_path and save_path.exists():
|
|
526
535
|
existing_results = TaskResult.from_disk(save_path)
|
|
527
536
|
if existing_results:
|
|
528
537
|
merged_results = self._merge_results(existing_results, new_results)
|
|
529
538
|
else:
|
|
530
539
|
merged_results = new_results
|
|
531
540
|
|
|
532
|
-
if output_path:
|
|
541
|
+
if output_path and save_path:
|
|
533
542
|
merged_results.to_disk(save_path)
|
|
534
543
|
|
|
535
544
|
evaluation_results.append(merged_results)
|
|
@@ -556,7 +565,7 @@ class MTEB:
|
|
|
556
565
|
def create_model_meta(model: MTEBModels) -> ModelMeta:
|
|
557
566
|
"""Create a ModelMeta object for the given model."""
|
|
558
567
|
if hasattr(model, "mteb_model_meta") and model.mteb_model_meta is not None:
|
|
559
|
-
meta = model.mteb_model_meta
|
|
568
|
+
meta = model.mteb_model_meta
|
|
560
569
|
else:
|
|
561
570
|
meta = MTEB._get_model_meta(model)
|
|
562
571
|
|
|
@@ -582,7 +591,11 @@ class MTEB:
|
|
|
582
591
|
if output_folder is None:
|
|
583
592
|
return None
|
|
584
593
|
|
|
585
|
-
model_revision: str =
|
|
594
|
+
model_revision: str = (
|
|
595
|
+
model_meta.revision
|
|
596
|
+
if model_meta.revision is not None
|
|
597
|
+
else "no_revision_available"
|
|
598
|
+
)
|
|
586
599
|
model_path_name = model_meta.model_name_as_path()
|
|
587
600
|
|
|
588
601
|
output_path = Path(output_folder) / model_path_name / model_revision
|
|
@@ -604,15 +617,15 @@ class MTEB:
|
|
|
604
617
|
Tasks with empty lists indicate that results already existed and no splits were evaluated.
|
|
605
618
|
"""
|
|
606
619
|
return deepcopy(
|
|
607
|
-
{task: list(splits) for task, splits in self.
|
|
620
|
+
{task: list(splits) for task, splits in self._last_evaluated_splits.items()}
|
|
608
621
|
)
|
|
609
622
|
|
|
610
623
|
@staticmethod
|
|
611
624
|
def _get_missing_evaluations(
|
|
612
625
|
existing_results: TaskResult | None,
|
|
613
|
-
task_eval_splits:
|
|
614
|
-
task_eval_langs:
|
|
615
|
-
eval_subsets:
|
|
626
|
+
task_eval_splits: Sequence[str],
|
|
627
|
+
task_eval_langs: Sequence[str],
|
|
628
|
+
eval_subsets: Sequence[str] | None,
|
|
616
629
|
) -> dict[str, dict[str, Any]]:
|
|
617
630
|
"""Return a dictionary for each split, indicating if the whole split is missing and which subsets are missing."""
|
|
618
631
|
missing_evaluations = {
|
|
@@ -661,7 +674,7 @@ class MTEB:
|
|
|
661
674
|
return missing_evaluations
|
|
662
675
|
|
|
663
676
|
@staticmethod
|
|
664
|
-
def _get_model_meta(model:
|
|
677
|
+
def _get_model_meta(model: MTEBModels) -> ModelMeta:
|
|
665
678
|
from sentence_transformers import CrossEncoder, SentenceTransformer
|
|
666
679
|
|
|
667
680
|
if isinstance(model, CrossEncoder):
|
mteb/evaluate.py
CHANGED
|
@@ -14,11 +14,10 @@ from mteb._helpful_enum import HelpfulStrEnum
|
|
|
14
14
|
from mteb.abstasks import AbsTaskRetrieval
|
|
15
15
|
from mteb.abstasks.abstask import AbsTask
|
|
16
16
|
from mteb.abstasks.aggregated_task import AbsTaskAggregate
|
|
17
|
+
from mteb.benchmarks.benchmark import Benchmark
|
|
17
18
|
from mteb.cache import ResultCache
|
|
18
19
|
from mteb.models.model_meta import ModelMeta
|
|
19
20
|
from mteb.models.models_protocols import (
|
|
20
|
-
CrossEncoderProtocol,
|
|
21
|
-
EncoderProtocol,
|
|
22
21
|
MTEBModels,
|
|
23
22
|
)
|
|
24
23
|
from mteb.models.sentence_transformer_wrapper import (
|
|
@@ -58,27 +57,26 @@ def _sanitize_model(
|
|
|
58
57
|
) -> tuple[MTEBModels | ModelMeta, ModelMeta, ModelName, Revision]:
|
|
59
58
|
from sentence_transformers import CrossEncoder, SentenceTransformer
|
|
60
59
|
|
|
60
|
+
wrapped_model: MTEBModels | ModelMeta
|
|
61
61
|
if isinstance(model, SentenceTransformer):
|
|
62
|
-
|
|
63
|
-
meta =
|
|
64
|
-
_mdl = cast(EncoderProtocol, _mdl)
|
|
65
|
-
model = _mdl
|
|
62
|
+
wrapped_model = SentenceTransformerEncoderWrapper(model)
|
|
63
|
+
meta = wrapped_model.mteb_model_meta
|
|
66
64
|
elif isinstance(model, CrossEncoder):
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
meta = _mdl.mteb_model_meta
|
|
70
|
-
model = _mdl
|
|
65
|
+
wrapped_model = CrossEncoderWrapper(model)
|
|
66
|
+
meta = wrapped_model.mteb_model_meta
|
|
71
67
|
elif hasattr(model, "mteb_model_meta"):
|
|
72
|
-
meta = model
|
|
68
|
+
meta = getattr(model, "mteb_model_meta")
|
|
73
69
|
if not isinstance(meta, ModelMeta):
|
|
74
|
-
meta = ModelMeta.
|
|
70
|
+
meta = ModelMeta._from_hub(None)
|
|
71
|
+
wrapped_model = cast(MTEBModels | ModelMeta, model)
|
|
75
72
|
else:
|
|
76
|
-
meta = ModelMeta.
|
|
73
|
+
meta = ModelMeta._from_hub(None) if not isinstance(model, ModelMeta) else model
|
|
74
|
+
wrapped_model = meta
|
|
77
75
|
|
|
78
76
|
model_name = cast(str, meta.name)
|
|
79
77
|
model_revision = cast(str, meta.revision)
|
|
80
78
|
|
|
81
|
-
return
|
|
79
|
+
return wrapped_model, meta, model_name, model_revision
|
|
82
80
|
|
|
83
81
|
|
|
84
82
|
def _evaluate_task(
|
|
@@ -124,7 +122,8 @@ def _evaluate_task(
|
|
|
124
122
|
prediction_folder=prediction_folder,
|
|
125
123
|
public_only=public_only,
|
|
126
124
|
)
|
|
127
|
-
result
|
|
125
|
+
if isinstance(result, TaskResult):
|
|
126
|
+
result.kg_co2_emissions = tracker.final_emissions
|
|
128
127
|
return result
|
|
129
128
|
|
|
130
129
|
task_results = {}
|
|
@@ -150,7 +149,7 @@ def _evaluate_task(
|
|
|
150
149
|
if public_only is False:
|
|
151
150
|
raise e
|
|
152
151
|
|
|
153
|
-
evaluation_time = 0
|
|
152
|
+
evaluation_time = 0.0
|
|
154
153
|
|
|
155
154
|
for split, hf_subsets in splits.items():
|
|
156
155
|
tick = time()
|
|
@@ -197,12 +196,18 @@ def _check_model_modalities(
|
|
|
197
196
|
return
|
|
198
197
|
|
|
199
198
|
model_modalities = set(model.modalities)
|
|
199
|
+
check_tasks: Iterable[AbsTask] = []
|
|
200
200
|
if isinstance(tasks, AbsTask):
|
|
201
|
-
|
|
201
|
+
check_tasks = [tasks]
|
|
202
|
+
elif isinstance(tasks, Benchmark):
|
|
203
|
+
benchmark = cast(Benchmark, tasks)
|
|
204
|
+
check_tasks = benchmark.tasks
|
|
205
|
+
else:
|
|
206
|
+
check_tasks = cast(Iterable[AbsTask], tasks)
|
|
202
207
|
|
|
203
208
|
warnings, errors = [], []
|
|
204
209
|
|
|
205
|
-
for task in
|
|
210
|
+
for task in check_tasks:
|
|
206
211
|
# only retrieval tasks have different modalities for query and document and can be run with partial overlaps
|
|
207
212
|
if isinstance(task, AbsTaskRetrieval):
|
|
208
213
|
query_mods = set(task.metadata.get_modalities(PromptType.query))
|
|
@@ -335,10 +340,10 @@ def evaluate(
|
|
|
335
340
|
|
|
336
341
|
# AbsTaskAggregate is a special case where we have to run multiple tasks and combine the results
|
|
337
342
|
if isinstance(tasks, AbsTaskAggregate):
|
|
338
|
-
|
|
343
|
+
aggregated_task = cast(AbsTaskAggregate, tasks)
|
|
339
344
|
results = evaluate(
|
|
340
345
|
model,
|
|
341
|
-
|
|
346
|
+
aggregated_task.metadata.tasks,
|
|
342
347
|
co2_tracker=co2_tracker,
|
|
343
348
|
raise_error=raise_error,
|
|
344
349
|
encode_kwargs=encode_kwargs,
|
|
@@ -348,17 +353,18 @@ def evaluate(
|
|
|
348
353
|
show_progress_bar=show_progress_bar,
|
|
349
354
|
public_only=public_only,
|
|
350
355
|
)
|
|
351
|
-
|
|
356
|
+
combined_results = aggregated_task.combine_task_results(results.task_results)
|
|
352
357
|
return ModelResult(
|
|
353
358
|
model_name=results.model_name,
|
|
354
359
|
model_revision=results.model_revision,
|
|
355
|
-
task_results=[
|
|
360
|
+
task_results=[combined_results],
|
|
356
361
|
)
|
|
357
362
|
|
|
358
363
|
if isinstance(tasks, AbsTask):
|
|
359
364
|
task = tasks
|
|
360
365
|
else:
|
|
361
|
-
|
|
366
|
+
tasks = cast(Iterable[AbsTask], tasks)
|
|
367
|
+
evaluate_results = []
|
|
362
368
|
exceptions = []
|
|
363
369
|
tasks_tqdm = tqdm(
|
|
364
370
|
tasks,
|
|
@@ -379,23 +385,23 @@ def evaluate(
|
|
|
379
385
|
show_progress_bar=False,
|
|
380
386
|
public_only=public_only,
|
|
381
387
|
)
|
|
382
|
-
|
|
388
|
+
evaluate_results.extend(_res.task_results)
|
|
383
389
|
if _res.exceptions:
|
|
384
390
|
exceptions.extend(_res.exceptions)
|
|
385
391
|
return ModelResult(
|
|
386
392
|
model_name=_res.model_name,
|
|
387
393
|
model_revision=_res.model_revision,
|
|
388
|
-
task_results=
|
|
394
|
+
task_results=evaluate_results,
|
|
389
395
|
exceptions=exceptions,
|
|
390
396
|
)
|
|
391
397
|
|
|
392
398
|
overwrite_strategy = OverwriteStrategy.from_str(overwrite_strategy)
|
|
393
399
|
|
|
394
|
-
existing_results = None
|
|
400
|
+
existing_results: TaskResult | None = None
|
|
395
401
|
if cache and overwrite_strategy != OverwriteStrategy.ALWAYS:
|
|
396
|
-
|
|
397
|
-
if
|
|
398
|
-
existing_results =
|
|
402
|
+
cache_results = cache.load_task_result(task.metadata.name, meta)
|
|
403
|
+
if cache_results:
|
|
404
|
+
existing_results = cache_results
|
|
399
405
|
|
|
400
406
|
if (
|
|
401
407
|
existing_results
|
mteb/filter_tasks.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""This script contains functions that are used to get an overview of the MTEB benchmark."""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
from collections.abc import Sequence
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
5
5
|
from typing import overload
|
|
6
6
|
|
|
7
7
|
from mteb.abstasks import (
|
|
@@ -34,14 +34,14 @@ def _check_is_valid_language(lang: str) -> None:
|
|
|
34
34
|
|
|
35
35
|
@overload
|
|
36
36
|
def filter_tasks(
|
|
37
|
-
tasks:
|
|
37
|
+
tasks: Iterable[AbsTask],
|
|
38
38
|
*,
|
|
39
|
-
languages:
|
|
40
|
-
script:
|
|
41
|
-
domains:
|
|
42
|
-
task_types:
|
|
43
|
-
categories:
|
|
44
|
-
modalities:
|
|
39
|
+
languages: Sequence[str] | None = None,
|
|
40
|
+
script: Sequence[str] | None = None,
|
|
41
|
+
domains: Iterable[TaskDomain] | None = None,
|
|
42
|
+
task_types: Iterable[TaskType] | None = None,
|
|
43
|
+
categories: Iterable[TaskCategory] | None = None,
|
|
44
|
+
modalities: Iterable[Modalities] | None = None,
|
|
45
45
|
exclusive_modality_filter: bool = False,
|
|
46
46
|
exclude_superseded: bool = False,
|
|
47
47
|
exclude_aggregate: bool = False,
|
|
@@ -51,14 +51,14 @@ def filter_tasks(
|
|
|
51
51
|
|
|
52
52
|
@overload
|
|
53
53
|
def filter_tasks(
|
|
54
|
-
tasks:
|
|
54
|
+
tasks: Iterable[type[AbsTask]],
|
|
55
55
|
*,
|
|
56
|
-
languages:
|
|
57
|
-
script:
|
|
58
|
-
domains:
|
|
59
|
-
task_types:
|
|
60
|
-
categories:
|
|
61
|
-
modalities:
|
|
56
|
+
languages: Sequence[str] | None = None,
|
|
57
|
+
script: Sequence[str] | None = None,
|
|
58
|
+
domains: Iterable[TaskDomain] | None = None,
|
|
59
|
+
task_types: Iterable[TaskType] | None = None,
|
|
60
|
+
categories: Iterable[TaskCategory] | None = None,
|
|
61
|
+
modalities: Iterable[Modalities] | None = None,
|
|
62
62
|
exclusive_modality_filter: bool = False,
|
|
63
63
|
exclude_superseded: bool = False,
|
|
64
64
|
exclude_aggregate: bool = False,
|
|
@@ -67,14 +67,14 @@ def filter_tasks(
|
|
|
67
67
|
|
|
68
68
|
|
|
69
69
|
def filter_tasks(
|
|
70
|
-
tasks:
|
|
70
|
+
tasks: Iterable[AbsTask] | Iterable[type[AbsTask]],
|
|
71
71
|
*,
|
|
72
|
-
languages:
|
|
73
|
-
script:
|
|
74
|
-
domains:
|
|
75
|
-
task_types:
|
|
76
|
-
categories:
|
|
77
|
-
modalities:
|
|
72
|
+
languages: Sequence[str] | None = None,
|
|
73
|
+
script: Sequence[str] | None = None,
|
|
74
|
+
domains: Iterable[TaskDomain] | None = None,
|
|
75
|
+
task_types: Iterable[TaskType] | None = None,
|
|
76
|
+
categories: Iterable[TaskCategory] | None = None,
|
|
77
|
+
modalities: Iterable[Modalities] | None = None,
|
|
78
78
|
exclusive_modality_filter: bool = False,
|
|
79
79
|
exclude_superseded: bool = False,
|
|
80
80
|
exclude_aggregate: bool = False,
|
|
@@ -92,7 +92,6 @@ def filter_tasks(
|
|
|
92
92
|
task_types: A string specifying the type of task e.g. "Classification" or "Retrieval". If None, all tasks are included.
|
|
93
93
|
categories: A list of task categories these include "t2t" (text to text), "t2i" (text to image). See TaskMetadata for the full list.
|
|
94
94
|
exclude_superseded: A boolean flag to exclude datasets which are superseded by another.
|
|
95
|
-
eval_splits: A list of evaluation splits to include. If None, all splits are included.
|
|
96
95
|
modalities: A list of modalities to include. If None, all modalities are included.
|
|
97
96
|
exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
|
|
98
97
|
task's modalities and ALL task modalities are in filter modalities (exact match).
|
|
@@ -113,12 +112,12 @@ def filter_tasks(
|
|
|
113
112
|
"""
|
|
114
113
|
langs_to_keep = None
|
|
115
114
|
if languages:
|
|
116
|
-
[_check_is_valid_language(lang) for lang in languages]
|
|
115
|
+
[_check_is_valid_language(lang) for lang in languages] # type: ignore[func-returns-value]
|
|
117
116
|
langs_to_keep = set(languages)
|
|
118
117
|
|
|
119
118
|
script_to_keep = None
|
|
120
119
|
if script:
|
|
121
|
-
[_check_is_valid_script(s) for s in script]
|
|
120
|
+
[_check_is_valid_script(s) for s in script] # type: ignore[func-returns-value]
|
|
122
121
|
script_to_keep = set(script)
|
|
123
122
|
|
|
124
123
|
domains_to_keep = None
|
|
@@ -178,4 +177,4 @@ def filter_tasks(
|
|
|
178
177
|
|
|
179
178
|
_tasks.append(t)
|
|
180
179
|
|
|
181
|
-
return _tasks
|
|
180
|
+
return _tasks # type: ignore[return-value] # type checker cannot infer the overload return type
|