mteb 2.5.3__py3-none-any.whl → 2.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. mteb/_create_dataloaders.py +10 -15
  2. mteb/_evaluators/any_sts_evaluator.py +1 -4
  3. mteb/_evaluators/evaluator.py +2 -1
  4. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
  5. mteb/_evaluators/pair_classification_evaluator.py +3 -1
  6. mteb/_evaluators/retrieval_metrics.py +17 -16
  7. mteb/_evaluators/sklearn_evaluator.py +9 -8
  8. mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
  9. mteb/_evaluators/text/summarization_evaluator.py +20 -16
  10. mteb/abstasks/_data_filter/filters.py +1 -1
  11. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  12. mteb/abstasks/_statistics_calculation.py +18 -10
  13. mteb/abstasks/_stratification.py +18 -18
  14. mteb/abstasks/abstask.py +27 -21
  15. mteb/abstasks/aggregate_task_metadata.py +1 -9
  16. mteb/abstasks/aggregated_task.py +3 -16
  17. mteb/abstasks/classification.py +10 -4
  18. mteb/abstasks/clustering.py +18 -14
  19. mteb/abstasks/clustering_legacy.py +8 -8
  20. mteb/abstasks/image/image_text_pair_classification.py +5 -3
  21. mteb/abstasks/multilabel_classification.py +20 -16
  22. mteb/abstasks/pair_classification.py +18 -9
  23. mteb/abstasks/regression.py +3 -3
  24. mteb/abstasks/retrieval.py +12 -9
  25. mteb/abstasks/sts.py +6 -3
  26. mteb/abstasks/task_metadata.py +20 -16
  27. mteb/abstasks/text/bitext_mining.py +36 -25
  28. mteb/abstasks/text/reranking.py +7 -5
  29. mteb/abstasks/text/summarization.py +8 -3
  30. mteb/abstasks/zeroshot_classification.py +5 -2
  31. mteb/benchmarks/benchmark.py +2 -2
  32. mteb/cache.py +20 -18
  33. mteb/cli/_display_tasks.py +2 -2
  34. mteb/cli/build_cli.py +5 -5
  35. mteb/cli/generate_model_card.py +6 -4
  36. mteb/deprecated_evaluator.py +56 -43
  37. mteb/evaluate.py +35 -29
  38. mteb/filter_tasks.py +25 -26
  39. mteb/get_tasks.py +25 -27
  40. mteb/languages/language_scripts.py +5 -3
  41. mteb/leaderboard/app.py +1 -1
  42. mteb/load_results.py +12 -12
  43. mteb/models/abs_encoder.py +2 -2
  44. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  45. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  46. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +2 -1
  47. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +30 -13
  48. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  49. mteb/models/get_model_meta.py +8 -1
  50. mteb/models/instruct_wrapper.py +11 -5
  51. mteb/models/model_implementations/andersborges.py +2 -2
  52. mteb/models/model_implementations/blip_models.py +8 -8
  53. mteb/models/model_implementations/bm25.py +1 -1
  54. mteb/models/model_implementations/clip_models.py +3 -3
  55. mteb/models/model_implementations/cohere_models.py +1 -1
  56. mteb/models/model_implementations/cohere_v.py +2 -2
  57. mteb/models/model_implementations/dino_models.py +23 -23
  58. mteb/models/model_implementations/emillykkejensen_models.py +3 -3
  59. mteb/models/model_implementations/jina_clip.py +1 -1
  60. mteb/models/model_implementations/jina_models.py +1 -1
  61. mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
  62. mteb/models/model_implementations/llm2clip_models.py +3 -3
  63. mteb/models/model_implementations/moco_models.py +2 -2
  64. mteb/models/model_implementations/model2vec_models.py +1 -1
  65. mteb/models/model_implementations/nomic_models.py +8 -8
  66. mteb/models/model_implementations/openclip_models.py +7 -7
  67. mteb/models/model_implementations/random_baseline.py +3 -3
  68. mteb/models/model_implementations/rasgaard_models.py +1 -1
  69. mteb/models/model_implementations/repllama_models.py +2 -2
  70. mteb/models/model_implementations/rerankers_custom.py +3 -3
  71. mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
  72. mteb/models/model_implementations/siglip_models.py +10 -10
  73. mteb/models/model_implementations/vlm2vec_models.py +1 -1
  74. mteb/models/model_implementations/voyage_v.py +4 -4
  75. mteb/models/model_meta.py +11 -12
  76. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +5 -5
  77. mteb/models/search_wrappers.py +22 -10
  78. mteb/models/sentence_transformer_wrapper.py +9 -4
  79. mteb/py.typed +0 -0
  80. mteb/results/benchmark_results.py +25 -19
  81. mteb/results/model_result.py +49 -21
  82. mteb/results/task_result.py +45 -51
  83. mteb/similarity_functions.py +11 -7
  84. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  85. mteb/tasks/classification/est/estonian_valence.py +1 -1
  86. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  87. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  88. mteb/tasks/retrieval/code/code_rag.py +12 -12
  89. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  90. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  91. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  92. mteb/tasks/retrieval/nob/norquad.py +2 -2
  93. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  94. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  95. mteb/types/_result.py +2 -1
  96. mteb/types/statistics.py +9 -3
  97. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/METADATA +1 -1
  98. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/RECORD +102 -101
  99. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/WHEEL +0 -0
  100. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/entry_points.txt +0 -0
  101. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/licenses/LICENSE +0 -0
  102. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/top_level.txt +0 -0
mteb/cli/generate_model_card.py CHANGED
@@ -1,5 +1,6 @@
  import logging
  import warnings
+ from collections.abc import Sequence
  from pathlib import Path

  from huggingface_hub import ModelCard, ModelCardData, repo_exists
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)

  def generate_model_card(
  model_name: str,
- tasks: list[AbsTask] | None = None,
+ tasks: Sequence[AbsTask] | None = None,
  existing_model_card_id_or_path: str | Path | None = None,
  results_cache: ResultCache = ResultCache(),
  output_path: Path = Path("model_card.md"),
@@ -48,8 +49,8 @@ def generate_model_card(
  for task_result in models_results.task_results:
  eval_results.extend(task_result.get_hf_eval_results())

- existing_model_card_data = (
- existing_model_card.data if existing_model_card else ModelCardData()
+ existing_model_card_data: ModelCardData = (
+ existing_model_card.data if existing_model_card else ModelCardData()  # type: ignore[assignment]
  )

  if existing_model_card_data.eval_results is None:
@@ -89,7 +90,8 @@ def generate_model_card(
  benchmark_results, existing_model_card
  )

- if push_to_hub:
+ if push_to_hub and existing_model_card_id_or_path:
+ existing_model_card_id_or_path = str(existing_model_card_id_or_path)
  if repo_exists(existing_model_card_id_or_path):
  existing_model_card.push_to_hub(existing_model_card_id_or_path, token=token)
  else:
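Note on the change above (not part of the diff): `tasks` is now typed as Sequence[AbsTask], so any sequence such as a tuple should be accepted. A minimal, hypothetical usage sketch, assuming the function is imported from the changed module path; the model and task names are placeholders:

import mteb
from mteb.cli.generate_model_card import generate_model_card  # assumed import path

tasks = tuple(mteb.get_tasks(["STS12", "STS13"]))  # a tuple now satisfies the annotation
generate_model_card("intfloat/multilingual-e5-small", tasks=tasks)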
mteb/deprecated_evaluator.py CHANGED
@@ -6,23 +6,23 @@ import os
  import sys
  import traceback
  import warnings
- from collections.abc import Iterable
+ from collections.abc import Iterable, Sequence
  from copy import deepcopy
  from datetime import datetime
  from itertools import chain
  from pathlib import Path
  from time import time
- from typing import TYPE_CHECKING, Any
+ from typing import TYPE_CHECKING, Any, cast

  import datasets

  import mteb
  from mteb.abstasks import AbsTask
+ from mteb.abstasks.aggregated_task import AbsTaskAggregate
  from mteb.abstasks.task_metadata import TaskCategory, TaskType
  from mteb.benchmarks import Benchmark
  from mteb.models import (
  CrossEncoderWrapper,
- EncoderProtocol,
  ModelMeta,
  MTEBModels,
  SentenceTransformerEncoderWrapper,
@@ -53,7 +53,7 @@ class MTEB:
  )
  def __init__(
  self,
- tasks: Iterable[AbsTask | Benchmark],
+ tasks: Iterable[AbsTask] | Iterable[Benchmark],
  *,
  err_logs_path: str = "error_logs.txt",
  ) -> None:
@@ -64,15 +64,14 @@ class MTEB:
  `mteb.get_tasks(["task1","task2"]) or `mteb.get_benchmark("MTEB(eng, classic)").
  err_logs_path: Path to save error logs.
  """
- from mteb.benchmarks import Benchmark
-
- self.tasks = list(tasks)
- if len(self.tasks) > 0 and isinstance(self.tasks[0], Benchmark):
+ if isinstance(next(iter(tasks)), Benchmark):
  self.benchmarks = tasks
- self.tasks = list(chain.from_iterable(self.tasks))
+ self.tasks = list(chain.from_iterable(cast(Iterable[Benchmark], tasks)))
+ elif isinstance(next(iter(tasks)), AbsTask):
+ self.tasks = list(cast(Iterable[AbsTask], tasks))

  self.err_logs_path = Path(err_logs_path)
- self.last_evaluated_splits = {}
+ self._last_evaluated_splits: dict[str, list[str]] = {}

  @property
  def available_tasks(self) -> list[str]:
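Illustrative sketch (not from the diff) of the two construction paths the reworked __init__ now distinguishes, an iterable of tasks versus an iterable of benchmarks; the import path for the deprecated MTEB class is assumed from the file edited here:

import mteb
from mteb.deprecated_evaluator import MTEB  # assumed import path

evaluation = MTEB(tasks=mteb.get_tasks(["STS12", "STS13"]))          # Iterable[AbsTask]
evaluation = MTEB(tasks=[mteb.get_benchmark("MTEB(eng, classic)")])  # Iterable[Benchmark]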
@@ -85,7 +84,7 @@ class MTEB:
  return sorted({x.metadata.type for x in self.tasks})

  @property
- def available_task_categories(self) -> set[TaskCategory]:
+ def available_task_categories(self) -> set[TaskCategory | None]:
  """Set of available task categories."""
  return {x.metadata.category for x in self.tasks}

@@ -232,13 +231,14 @@ class MTEB:
  merged_kg_co2_emissions = None
  if existing_kg_co2_emissions and new_kg_co2_emissions:
  merged_kg_co2_emissions = existing_kg_co2_emissions + new_kg_co2_emissions
+ existing_evaluation_time = existing_results.evaluation_time or 0
+ new_evaluation_time = new_results.evaluation_time or 0
  merged_results = TaskResult(
  dataset_revision=new_results.dataset_revision,
  task_name=new_results.task_name,
  mteb_version=new_results.mteb_version,
  scores=merged_scores,
- evaluation_time=existing_results.evaluation_time
- + new_results.evaluation_time,
+ evaluation_time=existing_evaluation_time + new_evaluation_time,
  kg_co2_emissions=merged_kg_co2_emissions,
  )

@@ -307,13 +307,16 @@ class MTEB:
  elif verbosity == 3:
  datasets.logging.set_verbosity(logging.DEBUG)

- meta = self.create_model_meta(model)
- output_path = self._create_output_folder(meta, output_folder)
-
+ mteb_model: MTEBModels
  if isinstance(model, SentenceTransformer):
- model = SentenceTransformerEncoderWrapper(model)
+ mteb_model = SentenceTransformerEncoderWrapper(model)
  elif isinstance(model, CrossEncoder):
- model = CrossEncoderWrapper(model)
+ mteb_model = CrossEncoderWrapper(model)
+ else:
+ mteb_model = cast(MTEBModels, model)
+
+ meta = self.create_model_meta(mteb_model)
+ output_path = self._create_output_folder(meta, output_folder)

  # Disable co2_tracker for API models
  if "API" in meta.framework:
@@ -334,7 +337,7 @@ class MTEB:
  ) # save them in case we re-use the object (e.g. for reranking)

  # To evaluate missing splits, we keep track of the task name and the corresponding splits.
- self.last_evaluated_splits = {}
+ self._last_evaluated_splits = {}

  while len(self.tasks) > 0:
  task = self.tasks[0]
@@ -343,9 +346,10 @@ class MTEB:
  )

  if task.is_aggregate:
- self_ = MTEB(tasks=task.metadata.tasks)
- task_results = self_.run(
- model,
+ aggregated_task = cast(AbsTaskAggregate, task)
+ self_ = MTEB(tasks=aggregated_task.metadata.tasks)
+ aggregated_task_results = self_.run(
+ mteb_model,
  verbosity=verbosity - 1,
  output_folder=output_folder,
  eval_splits=eval_splits,
@@ -356,12 +360,15 @@ class MTEB:
  encode_kwargs=encode_kwargs,
  **kwargs,
  )
- new_results = task.combine_task_results(task_results)
+ new_results = aggregated_task.combine_task_results(
+ aggregated_task_results
+ )
  evaluation_results.append(new_results)

  if output_path:
- save_path = output_path / f"{task.metadata.name}.json"
- new_results.to_disk(save_path)
+ new_results.to_disk(
+ output_path / f"{aggregated_task.metadata.name}.json"
+ )
  del self.tasks[0]
  continue

@@ -383,7 +390,7 @@ class MTEB:
  task_subsets = task.hf_subsets

  existing_results = None
- save_path = None
+ save_path: Path | None = None
  final_splits_to_run = task_eval_splits
  missing_evaluations = self._get_missing_evaluations(
  existing_results,
@@ -433,7 +440,7 @@ class MTEB:
  logger.info(
  f"No splits to evaluate for {task.metadata.name}. Skipping evaluation."
  )
- self.last_evaluated_splits[task.metadata.name] = []
+ self._last_evaluated_splits[task.metadata.name] = []
  del self.tasks[0]
  continue

@@ -441,11 +448,11 @@ class MTEB:
  task.check_if_dataset_is_superseded()
  task.load_data()

- task_results = {}
+ task_results: dict[str, dict[str, dict[str, Any]]] = {}
  evaluation_time = 0
  kg_co2_emissions: int | None = 0 if co2_tracker else None

- self.last_evaluated_splits[task.metadata.name] = []
+ self._last_evaluated_splits[task.metadata.name] = []

  for split in final_splits_to_run:
  info = missing_evaluations[split]
@@ -466,7 +473,9 @@ class MTEB:

  if co2_tracker:
  try:
- from codecarbon import EmissionsTracker
+ from codecarbon import ( # type: ignore[import-untyped]
+ EmissionsTracker,
+ )
  except ImportError:
  raise ImportError(
  "codecarbon is not installed. Please install it using `pip install 'mteb[codecarbon]'` to track CO₂ emissions."
@@ -482,7 +491,7 @@ class MTEB:
  ) as tracker:
  results, tick, tock = self._run_eval(
  task,
- model,
+ mteb_model,
  split,
  encode_kwargs=encode_kwargs,
  subsets_to_run=subsets_to_run,
@@ -495,7 +504,7 @@ class MTEB:
  else:
  results, tick, tock = self._run_eval(
  task,
- model,
+ mteb_model,
  split,
  subsets_to_run=subsets_to_run,
  encode_kwargs=encode_kwargs,
@@ -511,25 +520,25 @@ class MTEB:
  if verbosity >= 1:
  logger.info(f"Scores: {task_results[split]}")

- self.last_evaluated_splits[task.metadata.name].append(split)
+ self._last_evaluated_splits[task.metadata.name].append(split)

  # Create new TaskResult
  new_results = TaskResult.from_task_results(
  task,
- task_results,
+ task_results, # type: ignore[arg-type]
  evaluation_time=evaluation_time,
  kg_co2_emissions=kg_co2_emissions,
  )

  # Merge with existing if needed
- if output_path and save_path.exists():
+ if output_path and save_path and save_path.exists():
  existing_results = TaskResult.from_disk(save_path)
  if existing_results:
  merged_results = self._merge_results(existing_results, new_results)
  else:
  merged_results = new_results

- if output_path:
+ if output_path and save_path:
  merged_results.to_disk(save_path)

  evaluation_results.append(merged_results)
@@ -556,7 +565,7 @@ class MTEB:
  def create_model_meta(model: MTEBModels) -> ModelMeta:
  """Create a ModelMeta object for the given model."""
  if hasattr(model, "mteb_model_meta") and model.mteb_model_meta is not None:
- meta = model.mteb_model_meta # type: ignore
+ meta = model.mteb_model_meta
  else:
  meta = MTEB._get_model_meta(model)

@@ -582,7 +591,11 @@ class MTEB:
  if output_folder is None:
  return None

- model_revision: str = model_meta.revision # type: ignore
+ model_revision: str = (
+ model_meta.revision
+ if model_meta.revision is not None
+ else "no_revision_available"
+ )
  model_path_name = model_meta.model_name_as_path()

  output_path = Path(output_folder) / model_path_name / model_revision
@@ -604,15 +617,15 @@ class MTEB:
  Tasks with empty lists indicate that results already existed and no splits were evaluated.
  """
  return deepcopy(
- {task: list(splits) for task, splits in self.last_evaluated_splits.items()}
+ {task: list(splits) for task, splits in self._last_evaluated_splits.items()}
  )

  @staticmethod
  def _get_missing_evaluations(
  existing_results: TaskResult | None,
- task_eval_splits: list[str],
- task_eval_langs: list[str],
- eval_subsets: list[str] | None,
+ task_eval_splits: Sequence[str],
+ task_eval_langs: Sequence[str],
+ eval_subsets: Sequence[str] | None,
  ) -> dict[str, dict[str, Any]]:
  """Return a dictionary for each split, indicating if the whole split is missing and which subsets are missing."""
  missing_evaluations = {
@@ -661,7 +674,7 @@ class MTEB:
  return missing_evaluations

  @staticmethod
- def _get_model_meta(model: EncoderProtocol) -> ModelMeta:
+ def _get_model_meta(model: MTEBModels) -> ModelMeta:
  from sentence_transformers import CrossEncoder, SentenceTransformer

  if isinstance(model, CrossEncoder):
mteb/evaluate.py CHANGED
@@ -14,11 +14,10 @@ from mteb._helpful_enum import HelpfulStrEnum
  from mteb.abstasks import AbsTaskRetrieval
  from mteb.abstasks.abstask import AbsTask
  from mteb.abstasks.aggregated_task import AbsTaskAggregate
+ from mteb.benchmarks.benchmark import Benchmark
  from mteb.cache import ResultCache
  from mteb.models.model_meta import ModelMeta
  from mteb.models.models_protocols import (
- CrossEncoderProtocol,
- EncoderProtocol,
  MTEBModels,
  )
  from mteb.models.sentence_transformer_wrapper import (
@@ -58,27 +57,26 @@ def _sanitize_model(
  ) -> tuple[MTEBModels | ModelMeta, ModelMeta, ModelName, Revision]:
  from sentence_transformers import CrossEncoder, SentenceTransformer

+ wrapped_model: MTEBModels | ModelMeta
  if isinstance(model, SentenceTransformer):
- _mdl = SentenceTransformerEncoderWrapper(model)
- meta = _mdl.mteb_model_meta
- _mdl = cast(EncoderProtocol, _mdl)
- model = _mdl
+ wrapped_model = SentenceTransformerEncoderWrapper(model)
+ meta = wrapped_model.mteb_model_meta
  elif isinstance(model, CrossEncoder):
- _mdl = CrossEncoderWrapper(model)
- _mdl = cast(CrossEncoderProtocol, _mdl)
- meta = _mdl.mteb_model_meta
- model = _mdl
+ wrapped_model = CrossEncoderWrapper(model)
+ meta = wrapped_model.mteb_model_meta
  elif hasattr(model, "mteb_model_meta"):
- meta = model.mteb_model_meta # type: ignore[attr-defined]
+ meta = getattr(model, "mteb_model_meta")
  if not isinstance(meta, ModelMeta):
- meta = ModelMeta.from_hub(None)
+ meta = ModelMeta._from_hub(None)
+ wrapped_model = cast(MTEBModels | ModelMeta, model)
  else:
- meta = ModelMeta.from_hub(None) if not isinstance(model, ModelMeta) else model
+ meta = ModelMeta._from_hub(None) if not isinstance(model, ModelMeta) else model
+ wrapped_model = meta

  model_name = cast(str, meta.name)
  model_revision = cast(str, meta.revision)

- return model, meta, model_name, model_revision
+ return wrapped_model, meta, model_name, model_revision


  def _evaluate_task(
@@ -124,7 +122,8 @@ def _evaluate_task(
  prediction_folder=prediction_folder,
  public_only=public_only,
  )
- result.kg_co2_emissions = tracker.final_emissions
+ if isinstance(result, TaskResult):
+ result.kg_co2_emissions = tracker.final_emissions
  return result

  task_results = {}
@@ -150,7 +149,7 @@ def _evaluate_task(
  if public_only is False:
  raise e

- evaluation_time = 0
+ evaluation_time = 0.0

  for split, hf_subsets in splits.items():
  tick = time()
@@ -197,12 +196,18 @@ def _check_model_modalities(
  return

  model_modalities = set(model.modalities)
+ check_tasks: Iterable[AbsTask] = []
  if isinstance(tasks, AbsTask):
- tasks = [tasks]
+ check_tasks = [tasks]
+ elif isinstance(tasks, Benchmark):
+ benchmark = cast(Benchmark, tasks)
+ check_tasks = benchmark.tasks
+ else:
+ check_tasks = cast(Iterable[AbsTask], tasks)

  warnings, errors = [], []

- for task in tasks:
+ for task in check_tasks:
  # only retrieval tasks have different modalities for query and document and can be run with partial overlaps
  if isinstance(task, AbsTaskRetrieval):
  query_mods = set(task.metadata.get_modalities(PromptType.query))
@@ -335,10 +340,10 @@ def evaluate(

  # AbsTaskAggregate is a special case where we have to run multiple tasks and combine the results
  if isinstance(tasks, AbsTaskAggregate):
- task = cast(AbsTaskAggregate, tasks)
+ aggregated_task = cast(AbsTaskAggregate, tasks)
  results = evaluate(
  model,
- task.metadata.tasks,
+ aggregated_task.metadata.tasks,
  co2_tracker=co2_tracker,
  raise_error=raise_error,
  encode_kwargs=encode_kwargs,
@@ -348,17 +353,18 @@ def evaluate(
  show_progress_bar=show_progress_bar,
  public_only=public_only,
  )
- result = task.combine_task_results(results.task_results)
+ combined_results = aggregated_task.combine_task_results(results.task_results)
  return ModelResult(
  model_name=results.model_name,
  model_revision=results.model_revision,
- task_results=[result],
+ task_results=[combined_results],
  )

  if isinstance(tasks, AbsTask):
  task = tasks
  else:
- results = []
+ tasks = cast(Iterable[AbsTask], tasks)
+ evaluate_results = []
  exceptions = []
  tasks_tqdm = tqdm(
  tasks,
@@ -379,23 +385,23 @@ def evaluate(
  show_progress_bar=False,
  public_only=public_only,
  )
- results.extend(_res.task_results)
+ evaluate_results.extend(_res.task_results)
  if _res.exceptions:
  exceptions.extend(_res.exceptions)
  return ModelResult(
  model_name=_res.model_name,
  model_revision=_res.model_revision,
- task_results=results,
+ task_results=evaluate_results,
  exceptions=exceptions,
  )

  overwrite_strategy = OverwriteStrategy.from_str(overwrite_strategy)

- existing_results = None
+ existing_results: TaskResult | None = None
  if cache and overwrite_strategy != OverwriteStrategy.ALWAYS:
- results = cache.load_task_result(task.metadata.name, meta)
- if results:
- existing_results = results
+ cache_results = cache.load_task_result(task.metadata.name, meta)
+ if cache_results:
+ existing_results = cache_results

  if (
  existing_results
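Hedged usage sketch of the evaluate() entry point touched above, assuming it still accepts a raw SentenceTransformer (which _sanitize_model wraps) and either a list of tasks or a Benchmark, per the branches shown in this diff; model and task names are illustrative:

import mteb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
results = mteb.evaluate(model, mteb.get_tasks(["STS12"]))
# or, with a whole benchmark:
results = mteb.evaluate(model, mteb.get_benchmark("MTEB(eng, classic)"))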
mteb/filter_tasks.py CHANGED
@@ -1,7 +1,7 @@
  """This script contains functions that are used to get an overview of the MTEB benchmark."""

  import logging
- from collections.abc import Sequence
+ from collections.abc import Iterable, Sequence
  from typing import overload

  from mteb.abstasks import (
@@ -34,14 +34,14 @@ def _check_is_valid_language(lang: str) -> None:

  @overload
  def filter_tasks(
- tasks: Sequence[AbsTask],
+ tasks: Iterable[AbsTask],
  *,
- languages: list[str] | None = None,
- script: list[str] | None = None,
- domains: list[TaskDomain] | None = None,
- task_types: list[TaskType] | None = None, # type: ignore
- categories: list[TaskCategory] | None = None,
- modalities: list[Modalities] | None = None,
+ languages: Sequence[str] | None = None,
+ script: Sequence[str] | None = None,
+ domains: Iterable[TaskDomain] | None = None,
+ task_types: Iterable[TaskType] | None = None,
+ categories: Iterable[TaskCategory] | None = None,
+ modalities: Iterable[Modalities] | None = None,
  exclusive_modality_filter: bool = False,
  exclude_superseded: bool = False,
  exclude_aggregate: bool = False,
@@ -51,14 +51,14 @@ def filter_tasks(

  @overload
  def filter_tasks(
- tasks: Sequence[type[AbsTask]],
+ tasks: Iterable[type[AbsTask]],
  *,
- languages: list[str] | None = None,
- script: list[str] | None = None,
- domains: list[TaskDomain] | None = None,
- task_types: list[TaskType] | None = None, # type: ignore
- categories: list[TaskCategory] | None = None,
- modalities: list[Modalities] | None = None,
+ languages: Sequence[str] | None = None,
+ script: Sequence[str] | None = None,
+ domains: Iterable[TaskDomain] | None = None,
+ task_types: Iterable[TaskType] | None = None,
+ categories: Iterable[TaskCategory] | None = None,
+ modalities: Iterable[Modalities] | None = None,
  exclusive_modality_filter: bool = False,
  exclude_superseded: bool = False,
  exclude_aggregate: bool = False,
@@ -67,14 +67,14 @@ def filter_tasks(


  def filter_tasks(
- tasks: Sequence[AbsTask] | Sequence[type[AbsTask]],
+ tasks: Iterable[AbsTask] | Iterable[type[AbsTask]],
  *,
- languages: list[str] | None = None,
- script: list[str] | None = None,
- domains: list[TaskDomain] | None = None,
- task_types: list[TaskType] | None = None, # type: ignore
- categories: list[TaskCategory] | None = None,
- modalities: list[Modalities] | None = None,
+ languages: Sequence[str] | None = None,
+ script: Sequence[str] | None = None,
+ domains: Iterable[TaskDomain] | None = None,
+ task_types: Iterable[TaskType] | None = None,
+ categories: Iterable[TaskCategory] | None = None,
+ modalities: Iterable[Modalities] | None = None,
  exclusive_modality_filter: bool = False,
  exclude_superseded: bool = False,
  exclude_aggregate: bool = False,
@@ -92,7 +92,6 @@ def filter_tasks(
  task_types: A string specifying the type of task e.g. "Classification" or "Retrieval". If None, all tasks are included.
  categories: A list of task categories these include "t2t" (text to text), "t2i" (text to image). See TaskMetadata for the full list.
  exclude_superseded: A boolean flag to exclude datasets which are superseded by another.
- eval_splits: A list of evaluation splits to include. If None, all splits are included.
  modalities: A list of modalities to include. If None, all modalities are included.
  exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
  task's modalities and ALL task modalities are in filter modalities (exact match).
@@ -113,12 +112,12 @@ def filter_tasks(
  """
  langs_to_keep = None
  if languages:
- [_check_is_valid_language(lang) for lang in languages]
+ [_check_is_valid_language(lang) for lang in languages] # type: ignore[func-returns-value]
  langs_to_keep = set(languages)

  script_to_keep = None
  if script:
- [_check_is_valid_script(s) for s in script]
+ [_check_is_valid_script(s) for s in script] # type: ignore[func-returns-value]
  script_to_keep = set(script)

  domains_to_keep = None
@@ -178,4 +177,4 @@ def filter_tasks(

  _tasks.append(t)

- return _tasks
+ return _tasks # type: ignore[return-value] # type checker cannot infer the overload return type
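Hedged sketch of the loosened filter_tasks() signature: the task argument may now be any iterable (a generator works) and the filter arguments any Sequence/Iterable such as tuples; the filter values below are illustrative:

import mteb
from mteb.filter_tasks import filter_tasks

all_tasks = (t for t in mteb.get_tasks())  # a generator is accepted by the Iterable overloads
sts_eng = filter_tasks(all_tasks, task_types=("STS",), languages=("eng",))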