mteb 2.5.2__py3-none-any.whl → 2.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. mteb/_create_dataloaders.py +10 -15
  2. mteb/_evaluators/any_sts_evaluator.py +1 -4
  3. mteb/_evaluators/evaluator.py +2 -1
  4. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
  5. mteb/_evaluators/pair_classification_evaluator.py +3 -1
  6. mteb/_evaluators/retrieval_metrics.py +17 -16
  7. mteb/_evaluators/sklearn_evaluator.py +9 -8
  8. mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
  9. mteb/_evaluators/text/summarization_evaluator.py +20 -16
  10. mteb/abstasks/_data_filter/filters.py +1 -1
  11. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  12. mteb/abstasks/_statistics_calculation.py +18 -10
  13. mteb/abstasks/_stratification.py +18 -18
  14. mteb/abstasks/abstask.py +33 -27
  15. mteb/abstasks/aggregate_task_metadata.py +1 -9
  16. mteb/abstasks/aggregated_task.py +7 -26
  17. mteb/abstasks/classification.py +10 -4
  18. mteb/abstasks/clustering.py +18 -14
  19. mteb/abstasks/clustering_legacy.py +8 -8
  20. mteb/abstasks/image/image_text_pair_classification.py +5 -3
  21. mteb/abstasks/multilabel_classification.py +20 -16
  22. mteb/abstasks/pair_classification.py +18 -9
  23. mteb/abstasks/regression.py +3 -3
  24. mteb/abstasks/retrieval.py +12 -9
  25. mteb/abstasks/sts.py +6 -3
  26. mteb/abstasks/task_metadata.py +22 -19
  27. mteb/abstasks/text/bitext_mining.py +36 -25
  28. mteb/abstasks/text/reranking.py +7 -5
  29. mteb/abstasks/text/summarization.py +8 -3
  30. mteb/abstasks/zeroshot_classification.py +5 -2
  31. mteb/benchmarks/benchmark.py +2 -2
  32. mteb/cache.py +27 -22
  33. mteb/cli/_display_tasks.py +2 -2
  34. mteb/cli/build_cli.py +15 -10
  35. mteb/cli/generate_model_card.py +10 -7
  36. mteb/deprecated_evaluator.py +60 -46
  37. mteb/evaluate.py +39 -30
  38. mteb/filter_tasks.py +25 -26
  39. mteb/get_tasks.py +29 -30
  40. mteb/languages/language_scripts.py +5 -3
  41. mteb/leaderboard/app.py +1 -1
  42. mteb/load_results.py +12 -12
  43. mteb/models/abs_encoder.py +7 -5
  44. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  45. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  46. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
  47. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
  48. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  49. mteb/models/get_model_meta.py +8 -1
  50. mteb/models/instruct_wrapper.py +11 -5
  51. mteb/models/model_implementations/andersborges.py +2 -2
  52. mteb/models/model_implementations/blip_models.py +8 -8
  53. mteb/models/model_implementations/bm25.py +1 -1
  54. mteb/models/model_implementations/clip_models.py +3 -3
  55. mteb/models/model_implementations/cohere_models.py +1 -1
  56. mteb/models/model_implementations/cohere_v.py +2 -2
  57. mteb/models/model_implementations/dino_models.py +23 -23
  58. mteb/models/model_implementations/emillykkejensen_models.py +3 -3
  59. mteb/models/model_implementations/gme_v_models.py +4 -3
  60. mteb/models/model_implementations/jina_clip.py +1 -1
  61. mteb/models/model_implementations/jina_models.py +1 -1
  62. mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
  63. mteb/models/model_implementations/llm2clip_models.py +3 -3
  64. mteb/models/model_implementations/mcinext_models.py +4 -1
  65. mteb/models/model_implementations/moco_models.py +2 -2
  66. mteb/models/model_implementations/model2vec_models.py +1 -1
  67. mteb/models/model_implementations/nomic_models.py +8 -8
  68. mteb/models/model_implementations/openclip_models.py +7 -7
  69. mteb/models/model_implementations/random_baseline.py +3 -3
  70. mteb/models/model_implementations/rasgaard_models.py +1 -1
  71. mteb/models/model_implementations/repllama_models.py +2 -2
  72. mteb/models/model_implementations/rerankers_custom.py +3 -3
  73. mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
  74. mteb/models/model_implementations/siglip_models.py +10 -10
  75. mteb/models/model_implementations/vlm2vec_models.py +1 -1
  76. mteb/models/model_implementations/voyage_v.py +4 -4
  77. mteb/models/model_meta.py +14 -13
  78. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
  79. mteb/models/search_wrappers.py +26 -12
  80. mteb/models/sentence_transformer_wrapper.py +19 -14
  81. mteb/py.typed +0 -0
  82. mteb/results/benchmark_results.py +28 -20
  83. mteb/results/model_result.py +52 -22
  84. mteb/results/task_result.py +55 -58
  85. mteb/similarity_functions.py +11 -7
  86. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  87. mteb/tasks/classification/est/estonian_valence.py +1 -1
  88. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  89. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  90. mteb/tasks/retrieval/code/code_rag.py +12 -12
  91. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  92. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  93. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  94. mteb/tasks/retrieval/nob/norquad.py +2 -2
  95. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  96. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  97. mteb/types/_result.py +2 -1
  98. mteb/types/statistics.py +9 -3
  99. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/METADATA +1 -1
  100. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/RECORD +104 -103
  101. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/WHEEL +0 -0
  102. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/entry_points.txt +0 -0
  103. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/licenses/LICENSE +0 -0
  104. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ from mteb.abstasks._statistics_calculation import (
  )
  from mteb.abstasks.abstask import AbsTask
  from mteb.models.model_meta import ScoringFunction
- from mteb.models.models_protocols import EncoderProtocol
+ from mteb.models.models_protocols import EncoderProtocol, MTEBModels
  from mteb.types import PromptType
  from mteb.types.statistics import (
  ImageStatistics,
@@ -44,8 +44,8 @@ class PairClassificationDescriptiveStatistics(SplitDescriptiveStatistics):
  """

  num_samples: int
- number_of_characters: int
- unique_pairs: int
+ number_of_characters: int | None
+ unique_pairs: int | None

  text1_statistics: TextStatistics | None
  image1_statistics: ImageStatistics | None
@@ -79,7 +79,7 @@ class AbsTaskPairClassification(AbsTask):

  def _evaluate_subset(
  self,
- model: EncoderProtocol,
+ model: MTEBModels,
  data_split: Dataset,
  *,
  hf_split: str,
@@ -88,6 +88,9 @@ class AbsTaskPairClassification(AbsTask):
  prediction_folder: Path | None = None,
  **kwargs,
  ) -> dict[str, float]:
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  if self.metadata.modalities == ["text"]:
  # for compatibility with v1 version where datasets were stored in a single row
  data_split = data_split[0] if len(data_split) == 1 else data_split
@@ -120,7 +123,7 @@ class AbsTaskPairClassification(AbsTask):
  self, similarity_scores: PairClassificationDistances, labels: list[int]
  ) -> dict[str, float]:
  logger.info("Computing metrics...")
- labels = np.asarray(labels)
+ np_labels = np.asarray(labels)
  output_scores = {}
  max_scores = defaultdict(list)
  for short_name, scores, reverse in [
@@ -142,7 +145,7 @@ class AbsTaskPairClassification(AbsTask):
  ],
  [ScoringFunction.DOT_PRODUCT.value, similarity_scores["dot_scores"], True],
  ]:
- metrics = self._compute_metrics_values(scores, labels, reverse)
+ metrics = self._compute_metrics_values(scores, np_labels, reverse) # type: ignore[arg-type]
  for metric_name, metric_value in metrics.items():
  output_scores[f"{short_name}_{metric_name}"] = metric_value
  max_scores[metric_name].append(metric_value)
@@ -237,6 +240,12 @@ class AbsTaskPairClassification(AbsTask):

  def _push_dataset_to_hub(self, repo_name: str) -> None:
  # previously pair classification datasets were stored in a single row
+ if self.dataset is None:
+ # overall this shouldn't happen as we check for dataset before pushing to hub
+ # added here for type checking purposes
+ raise RuntimeError(
+ "Dataset not loaded. To load dataset run `task.load_data()`."
+ )
  if self.metadata.is_multilingual:
  for subset in self.dataset:
  for split in self.dataset[subset]:
@@ -290,13 +299,13 @@ class AbsTaskPairClassification(AbsTask):
  )

  def _find_best_acc_and_threshold(
- self, scores: np.ndarray, labels: np.ndarray, high_score_more_similar: bool
+ self, scores: list[float], labels: np.ndarray, high_score_more_similar: bool
  ) -> tuple[float, float]:
  rows = list(zip(scores, labels))
  rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)

  max_acc = 0
- best_threshold = -1
+ best_threshold = -1.0
  positive_so_far = 0
  remaining_negatives = sum(np.array(labels) == 0)

@@ -323,7 +332,7 @@ class AbsTaskPairClassification(AbsTask):

  rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)

- best_f1 = best_precision = best_recall = 0
+ best_f1 = best_precision = best_recall = 0.0
  threshold = 0
  nextract = 0
  ncorrect = 0
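The pair-classification hunks above follow a pattern repeated across 2.5.4: `_evaluate_subset` now accepts the broader `MTEBModels` union and immediately narrows it back with an `isinstance` guard against `EncoderProtocol`. A minimal, self-contained sketch of that narrowing pattern; `EncoderLike`, `CrossEncoderLike`, `AnyModel`, and `ToyEncoder` are illustrative stand-ins, not mteb APIs:

```python
from typing import Protocol, Union, runtime_checkable


@runtime_checkable
class EncoderLike(Protocol):
    """Stand-in for mteb's EncoderProtocol: anything with an encode method."""

    def encode(self, sentences: list[str]) -> list[list[float]]: ...


@runtime_checkable
class CrossEncoderLike(Protocol):
    """Stand-in for a cross-encoder: scores sentence pairs instead of encoding."""

    def predict(self, pairs: list[tuple[str, str]]) -> list[float]: ...


# Mirrors the MTEBModels union: a task may be handed either kind of model.
AnyModel = Union[EncoderLike, CrossEncoderLike]


def evaluate_subset(model: AnyModel, sentences: list[str]) -> int:
    # Encoder-only tasks reject other model kinds up front; the guard also lets
    # the type checker narrow `model` to EncoderLike for the rest of the body.
    if not isinstance(model, EncoderLike):
        raise TypeError("Expected model to be an instance of EncoderProtocol")
    return len(model.encode(sentences))


class ToyEncoder:
    def encode(self, sentences: list[str]) -> list[list[float]]:
        return [[float(len(s))] for s in sentences]


print(evaluate_subset(ToyEncoder(), ["hello", "world"]))  # -> 2
```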
@@ -87,7 +87,7 @@ class AbsTaskRegression(AbsTaskClassification):
  Full details of api in [`SklearnModelProtocol`][mteb._evaluators.sklearn_evaluator.SklearnModelProtocol].
  """

- evaluator: type[SklearnModelProtocol] = SklearnEvaluator
+ evaluator: type[SklearnEvaluator] = SklearnEvaluator
  evaluator_model: SklearnModelProtocol = LinearRegression(n_jobs=-1)

  train_split: str = "train"
@@ -113,7 +113,7 @@ class AbsTaskRegression(AbsTaskClassification):
  )["train"]
  return train_split_sampled, []

- def _calculate_scores(
+ def _calculate_scores( # type: ignore[override]
  self,
  y_test: np.ndarray | list[int],
  y_pred: np.ndarray,
@@ -183,7 +183,7 @@ class AbsTaskRegression(AbsTaskClassification):

  return dataset_dict

- def _calculate_descriptive_statistics_from_split(
+ def _calculate_descriptive_statistics_from_split( # type: ignore[override]
  self, split: str, hf_subset: str | None = None, compute_overall: bool = False
  ) -> RegressionDescriptiveStatistics:
  train_text = []
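The regression hunks mostly add `# type: ignore[override]` where a subclass intentionally changes an inherited signature. A toy illustration of when mypy raises that error and why the comment sits on the `def` line (the classes below are hypothetical, not mteb code):

```python
class BaseTask:
    def _calculate_scores(
        self, y_test: list[int], y_pred: list[int]
    ) -> dict[str, float]:
        return {"accuracy": sum(t == p for t, p in zip(y_test, y_pred)) / len(y_test)}


class RegressionTask(BaseTask):
    # The parameter types deliberately differ from the base class, which mypy
    # reports as an unsafe override; the comment acknowledges that on purpose.
    def _calculate_scores(  # type: ignore[override]
        self, y_test: list[float], y_pred: list[float]
    ) -> dict[str, float]:
        mae = sum(abs(t - p) for t, p in zip(y_test, y_pred)) / len(y_test)
        return {"mae": mae}


print(RegressionTask()._calculate_scores([1.0, 2.0], [1.5, 2.0]))  # {'mae': 0.25}
```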
@@ -1,7 +1,7 @@
  import json
  import logging
  from collections import defaultdict
- from collections.abc import Callable, Sequence
+ from collections.abc import Callable, Mapping, Sequence
  from pathlib import Path
  from time import time
  from typing import Any, Literal
@@ -286,7 +286,7 @@ class AbsTaskRetrieval(AbsTask):
  encode_kwargs: dict[str, Any],
  prediction_folder: Path | None = None,
  **kwargs,
- ) -> dict[HFSubset, ScoresDict]:
+ ) -> Mapping[HFSubset, ScoresDict]:
  """Evaluate the model on the retrieval task.

  Args:
@@ -357,6 +357,8 @@ class AbsTaskRetrieval(AbsTask):
  **kwargs,
  )

+ search_model: SearchProtocol
+
  if isinstance(model, EncoderProtocol) and not isinstance(model, SearchProtocol):
  search_model = SearchEncoderWrapper(model)
  elif isinstance(model, CrossEncoderProtocol):
@@ -578,11 +580,12 @@ class AbsTaskRetrieval(AbsTask):
  if isinstance(data[split][subset_item], Dataset):
  sections[split] = data[split][subset_item]
  elif converter is not None:
+ subset_data = data[split][subset_item]
+ if subset_data is None:
+ continue
+
  sections[split] = Dataset.from_list(
- [
- converter(idx, item)
- for idx, item in data[split][subset_item].items()
- ]
+ [converter(idx, item) for idx, item in subset_data.items()]
  )
  else:
  raise ValueError(
@@ -680,7 +683,7 @@ class AbsTaskRetrieval(AbsTask):

  top_k_sorted = defaultdict(list)
  for query_id, values in top_ranked.items():
- sorted_keys = sorted(values, key=values.get, reverse=True)
+ sorted_keys = sorted(values, key=lambda k: values[k], reverse=True)
  top_k_sorted[query_id] = sorted_keys[: self._top_k]

  self.dataset[subset][split]["top_ranked"] = top_k_sorted
@@ -688,10 +691,10 @@ class AbsTaskRetrieval(AbsTask):


  def _process_relevant_docs(
- collection: dict[str, dict[str, float]],
+ collection: Mapping[str, Mapping[str, int]],
  hf_subset: str,
  split: str,
- ) -> dict[str, dict[str, float]]:
+ ) -> dict[str, dict[str, int]]:
  """Collections can contain overlapping ids in different splits. Prepend split and subset to avoid this

  Returns:
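The retrieval changes are typing clean-ups; the only behavioural-looking one, replacing `key=values.get` with a lambda, sorts identically and just makes the key function's types explicit. A standalone sketch of that top-k sort over per-query scores (the ids and scores are made up):

```python
from collections import defaultdict

# query_id -> {doc_id: score}, as produced by a search model (invented numbers).
top_ranked = {
    "q1": {"d1": 0.2, "d2": 0.9, "d3": 0.5},
    "q2": {"d1": 0.7, "d4": 0.1},
}

top_k = 2
top_k_sorted: dict[str, list[str]] = defaultdict(list)
for query_id, values in top_ranked.items():
    # Equivalent to sorted(values, key=values.get, reverse=True), but the lambda
    # keeps the key function's argument type explicit for the type checker.
    sorted_keys = sorted(values, key=lambda k: values[k], reverse=True)
    top_k_sorted[query_id] = sorted_keys[:top_k]

print(dict(top_k_sorted))  # {'q1': ['d2', 'd3'], 'q2': ['d1', 'd4']}
```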
mteb/abstasks/sts.py CHANGED
@@ -7,7 +7,7 @@ from scipy.stats import pearsonr, spearmanr

  from mteb._evaluators import AnySTSEvaluator
  from mteb._evaluators.any_sts_evaluator import STSEvaluatorScores
- from mteb.models import EncoderProtocol
+ from mteb.models import EncoderProtocol, MTEBModels
  from mteb.types import PromptType
  from mteb.types.statistics import (
  ImageStatistics,
@@ -103,7 +103,7 @@ class AbsTaskSTS(AbsTask):

  def _evaluate_subset(
  self,
- model: EncoderProtocol,
+ model: MTEBModels,
  data_split: Dataset,
  encode_kwargs: dict[str, Any],
  hf_split: str,
@@ -111,6 +111,9 @@ class AbsTaskSTS(AbsTask):
  prediction_folder: Path | None = None,
  **kwargs: Any,
  ) -> STSMetrics:
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  normalized_scores = list(map(self._normalize, data_split["score"]))
  data_split = data_split.select_columns(list(self.column_names))

@@ -142,7 +145,7 @@ class AbsTaskSTS(AbsTask):
  ) -> STSMetrics:
  def compute_corr(x: list[float], y: list[float]) -> tuple[float, float]:
  """Return (pearson, spearman) correlations between x and y."""
- return pearsonr(x, y)[0], spearmanr(x, y)[0]
+ return float(pearsonr(x, y)[0]), float(spearmanr(x, y)[0])

  cosine_pearson, cosine_spearman = compute_corr(
  normalized_scores, scores["cosine_scores"]
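The STS hunk wraps the Pearson and Spearman statistics in `float(...)` so plain Python floats, not numpy scalars, end up in the metrics dict. A minimal reproduction of that pattern (requires scipy; the score lists are invented):

```python
from scipy.stats import pearsonr, spearmanr

gold = [0.1, 0.4, 0.35, 0.8]
model_scores = [0.2, 0.3, 0.4, 0.9]

# Index [0] picks the correlation statistic; float() drops the numpy scalar type
# so the value serialises cleanly into the JSON result files.
pearson = float(pearsonr(gold, model_scores)[0])
spearman = float(spearmanr(gold, model_scores)[0])
print(round(pearson, 3), round(spearman, 3))
```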
@@ -2,9 +2,10 @@ import json
  import logging
  from collections.abc import Sequence
  from pathlib import Path
- from typing import Any, Literal
+ from typing import Any, Literal, cast

  from huggingface_hub import (
+ CardData,
  DatasetCard,
  DatasetCardData,
  constants,
@@ -150,7 +151,7 @@ _TASK_TYPE = (
  "InstructionReranking",
  ) + MIEB_TASK_TYPE

- TaskType = Literal[_TASK_TYPE]
+ TaskType = Literal[_TASK_TYPE] # type: ignore[valid-type]
  """The type of the task. E.g. includes "Classification", "Retrieval" and "Clustering"."""


@@ -192,8 +193,10 @@ AnnotatorType = Literal[
  """The type of the annotators. Is often important for understanding the quality of a dataset."""


- PromptDict = TypedDict(
- "PromptDict", {prompt_type.value: str for prompt_type in PromptType}, total=False
+ PromptDict = TypedDict( # type: ignore[misc]
+ "PromptDict",
+ {prompt_type.value: str for prompt_type in PromptType},
+ total=False,
  )
  """A dictionary containing the prompt used for the task.

@@ -365,7 +368,7 @@ class TaskMetadata(BaseModel):
  """Return a dictionary mapping huggingface subsets to languages."""
  if isinstance(self.eval_langs, dict):
  return self.eval_langs
- return {"default": self.eval_langs} # type: ignore
+ return {"default": cast(list[str], self.eval_langs)}

  @property
  def intext_citation(self, include_cite: bool = True) -> str:
@@ -376,9 +379,8 @@ class TaskMetadata(BaseModel):
  if include_cite and cite:
  # check for whitespace in the citation
  if " " in cite:
- logger.warning(
- "Citation contains whitespace. Please ensure that the citation is correctly formatted."
- )
+ msg = "Citation contains whitespace. Please ensure that the citation is correctly formatted."
+ logger.warning(msg)
  return f"\\cite{{{cite}}}"
  return cite

@@ -414,7 +416,7 @@ class TaskMetadata(BaseModel):
  for subset, subset_value in stats.items():
  if subset == "hf_subset_descriptive_stats":
  continue
- n_samples[subset] = subset_value["num_samples"] # type: ignore
+ n_samples[subset] = subset_value["num_samples"]
  return n_samples

  @property
@@ -447,7 +449,7 @@ class TaskMetadata(BaseModel):
  Raises:
  ValueError: If the prompt type is not recognized.
  """
- if prompt_type is None:
+ if prompt_type is None or self.category is None:
  return self.modalities
  query_modalities, doc_modalities = self.category.split("2")
  category_to_modality: dict[str, Modalities] = {
@@ -467,7 +469,7 @@ class TaskMetadata(BaseModel):

  def _create_dataset_card_data(
  self,
- existing_dataset_card_data: DatasetCardData | None = None,
+ existing_dataset_card_data: CardData | None = None,
  ) -> tuple[DatasetCardData, dict[str, Any]]:
  """Create a DatasetCardData object from the task metadata.

@@ -502,12 +504,13 @@ class TaskMetadata(BaseModel):


  tags = ["mteb"] + self.modalities
- descriptive_stats = self.descriptive_stats
- if descriptive_stats is not None:
- for split, split_stat in descriptive_stats.items():
+ descriptive_stats = ""
+ if self.descriptive_stats is not None:
+ descriptive_stats_ = self.descriptive_stats
+ for split, split_stat in descriptive_stats_.items():
  if len(split_stat.get("hf_subset_descriptive_stats", {})) > 10:
  split_stat.pop("hf_subset_descriptive_stats", {})
- descriptive_stats = json.dumps(descriptive_stats, indent=4)
+ descriptive_stats = json.dumps(descriptive_stats_, indent=4)

  dataset_card_data_params = existing_dataset_card_data.to_dict()
  # override the existing values
@@ -695,11 +698,11 @@ class TaskMetadata(BaseModel):

  def _hf_languages(self) -> list[str]:
  languages: list[str] = []
- if self.is_multilingual:
- for val in list(self.eval_langs.values()):
+ if self.is_multilingual and isinstance(self.eval_langs, dict):
+ for val in self.eval_langs.values():
  languages.extend(val)
  else:
- languages = self.eval_langs
+ languages = cast(list[str], self.eval_langs)
  # value "python" is not valid. It must be an ISO 639-1, 639-2 or 639-3 code (two/three letters),
  # or a special value like "code", "multilingual".
  readme_langs = []
@@ -711,7 +714,7 @@ class TaskMetadata(BaseModel):
  readme_langs.append(lang_name)
  return sorted(set(readme_langs))

- def _hf_license(self) -> str:
+ def _hf_license(self) -> str | None:
  dataset_license = self.license
  if dataset_license:
  license_mapping = {
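`PromptDict` has to use the functional `TypedDict(...)` form because its keys are generated from an enum at import time, which static checkers cannot verify, hence the `# type: ignore[misc]`. A self-contained sketch of the same construction; the `PromptType` enum below is a stand-in rather than mteb's own:

```python
from enum import Enum
from typing import TypedDict


class PromptType(str, Enum):
    """Stand-in for mteb's PromptType enum."""

    query = "query"
    document = "document"


# Functional syntax is required because the field names are computed; static
# checkers flag this as [misc] since they cannot see the keys at analysis time.
PromptDict = TypedDict(  # type: ignore[misc]
    "PromptDict",
    {prompt_type.value: str for prompt_type in PromptType},
    total=False,
)

prompts: PromptDict = {"query": "Represent the question for retrieval:"}
print(prompts)
```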
@@ -1,7 +1,7 @@
  import logging
  from collections import defaultdict
  from pathlib import Path
- from typing import Any, ClassVar, TypedDict
+ from typing import Any, ClassVar, TypedDict, cast

  from datasets import Dataset, DatasetDict
  from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
@@ -78,6 +78,9 @@ class AbsTaskBitextMining(AbsTask):
  **kwargs: Any,
  ) -> dict[HFSubset, ScoresDict]:
  """Added load for "parallel" datasets"""
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  if not self.data_loaded:
  self.load_data()

@@ -87,11 +90,16 @@ class AbsTaskBitextMining(AbsTask):

  if subsets_to_run is not None:
  hf_subsets = [s for s in hf_subsets if s in subsets_to_run]
- scores = {}
+ encoder_model = cast(EncoderProtocol, model)
+
+ if self.dataset is None:
+ raise ValueError("Dataset is not loaded.")
+
+ scores: dict[str, BitextMiningMetrics] = {}
  if self.parallel_subsets:
- scores = self._evaluate_subset(
- model,
- self.dataset[split], # type: ignore
+ scores = self._evaluate_subset( # type: ignore[assignment]
+ encoder_model,
+ self.dataset[split],
  parallel=True,
  hf_split=split,
  hf_subset="parallel",
@@ -109,8 +117,8 @@ class AbsTaskBitextMining(AbsTask):
  data_split = self.dataset[split]
  else:
  data_split = self.dataset[hf_subset][split]
- scores[hf_subset] = self._evaluate_subset(
- model,
+ scores[hf_subset] = self._evaluate_subset( # type: ignore[assignment]
+ encoder_model,
  data_split,
  hf_split=split,
  hf_subset=hf_subset,
@@ -119,32 +127,32 @@ class AbsTaskBitextMining(AbsTask):
  **kwargs,
  )

- return scores
+ return cast(dict[HFSubset, ScoresDict], scores)

  def _get_pairs(self, parallel: bool) -> list[tuple[str, str]]:
  pairs = self._DEFAULT_PAIR
  if parallel:
- pairs = [langpair.split("-") for langpair in self.hf_subsets]
+ pairs = [langpair.split("-") for langpair in self.hf_subsets] # type: ignore[misc]
  return pairs

- def _evaluate_subset(
+ def _evaluate_subset( # type: ignore[override]
  self,
  model: EncoderProtocol,
  data_split: Dataset,
  *,
  hf_split: str,
  hf_subset: str,
- parallel: bool = False,
  encode_kwargs: dict[str, Any],
  prediction_folder: Path | None = None,
+ parallel: bool = False,
  **kwargs,
- ) -> ScoresDict:
+ ) -> BitextMiningMetrics | dict[str, BitextMiningMetrics]:
  pairs = self._get_pairs(parallel)

  evaluator = BitextMiningEvaluator(
  data_split,
  task_metadata=self.metadata,
- pair_columns=pairs, # type: ignore
+ pair_columns=pairs,
  hf_split=hf_split,
  hf_subset=hf_subset,
  **kwargs,
@@ -168,16 +176,16 @@ class AbsTaskBitextMining(AbsTask):
  )

  if parallel:
- metrics = {}
+ parallel_metrics = {}
  for keys, nearest_neighbors in neighbours.items():
- metrics[keys] = self._compute_metrics(nearest_neighbors, gold)
+ parallel_metrics[keys] = self._compute_metrics(nearest_neighbors, gold)

- for v in metrics.values():
+ for v in parallel_metrics.values():
  self._add_main_score(v)
- else:
- def_pair_str = "-".join(self._DEFAULT_PAIR[0])
- metrics = self._compute_metrics(neighbours[def_pair_str], gold)
- self._add_main_score(metrics)
+ return parallel_metrics
+ def_pair_str = "-".join(self._DEFAULT_PAIR[0])
+ metrics = self._compute_metrics(neighbours[def_pair_str], gold)
+ self._add_main_score(metrics)
  return metrics

  def _compute_metrics(
@@ -250,8 +258,11 @@ class AbsTaskBitextMining(AbsTask):
  )

  def _push_dataset_to_hub(self, repo_name: str) -> None:
+ if self.dataset is None:
+ raise ValueError("Dataset is not loaded.")
+
  if self.metadata.is_multilingual:
- dataset = defaultdict(dict)
+ dataset: dict[str, dict[str, list[str]]] = defaultdict(dict)
  for config in self.metadata.eval_langs:
  logger.info(f"Converting {config} of {self.metadata.name}")

@@ -266,10 +277,10 @@ class AbsTaskBitextMining(AbsTask):
  for split in self.dataset[config]:
  dataset[split][lang_1] = self.dataset[config][split][sent_1]
  dataset[split][lang_2] = self.dataset[config][split][sent_2]
- for split in dataset:
- dataset[split] = Dataset.from_dict(dataset[split])
- dataset = DatasetDict(dataset)
- dataset.push_to_hub(repo_name)
+ dataset_dict = DatasetDict(
+ {split: Dataset.from_dict(dataset[split]) for split in dataset}
+ )
+ dataset_dict.push_to_hub(repo_name)
  else:
  sentences = {}
  for split in self.dataset:
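The bitext-mining push logic now builds a `DatasetDict` in a single expression instead of overwriting entries of the `defaultdict` in place, so the variable keeps one type throughout. A reduced sketch with toy data (requires the `datasets` package; the split and column names are illustrative):

```python
from collections import defaultdict

from datasets import Dataset, DatasetDict

# split -> {column_name: list_of_sentences}, as assembled from the task's subsets.
dataset: dict[str, dict[str, list[str]]] = defaultdict(dict)
dataset["test"]["eng"] = ["Hello", "Good morning"]
dataset["test"]["deu"] = ["Hallo", "Guten Morgen"]

# Build the DatasetDict in one go instead of replacing dataset[split] with a
# Dataset object, so the dict keeps a single value type throughout.
dataset_dict = DatasetDict(
    {split: Dataset.from_dict(dataset[split]) for split in dataset}
)
print(dataset_dict)
# dataset_dict.push_to_hub("org/repo")  # upload step, omitted here
```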
@@ -16,7 +16,7 @@ else:

  logger = logging.getLogger(__name__)

- OLD_FORMAT_RERANKING_TASKS = []
+ OLD_FORMAT_RERANKING_TASKS: list[str] = []


  @deprecated(
@@ -105,7 +105,9 @@ class AbsTaskReranking(AbsTaskRetrieval):
  )

  given_dataset = copy(given_dataset)
- self.dataset = defaultdict(lambda: defaultdict(dict))
+ self.dataset: dict[str, dict[str, RetrievalSplitData]] = defaultdict(
+ lambda: defaultdict(dict) # type: ignore[arg-type]
+ )

  hf_subsets = self.hf_subsets

@@ -115,19 +117,19 @@ class AbsTaskReranking(AbsTaskRetrieval):
  if hf_subset in cur_dataset:
  cur_dataset = cur_dataset[hf_subset]
  elif "name" in self.metadata.dataset:
- cur_dataset = datasets.load_dataset(**self.metadata.dataset) # type: ignore
+ cur_dataset = datasets.load_dataset(**self.metadata.dataset)
  assert hf_subset == "default", (
  f"Only default subset is supported for {self.metadata.name} since `name` is given in the metadata."
  )
  else:
  cur_dataset = datasets.load_dataset(
  **self.metadata.dataset, name=hf_subset
- ) # type: ignore
+ )

  for split in cur_dataset:
  corpus = []
  queries = []
- relevant_docs = defaultdict(dict)
+ relevant_docs: dict[str, dict[str, int]] = defaultdict(dict)
  top_ranked = defaultdict(list)

  # Create an enumerated dataset to pass indices
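The reranking loader mainly gains explicit annotations on its `defaultdict`s; without them the checker infers `defaultdict[Any, Any]` and the nested value types are lost. A small, hypothetical illustration of the annotated qrels-style mapping:

```python
from collections import defaultdict

# Annotating the defaultdict keeps the nested value type (qrels-style integer
# relevance judgements) visible to the type checker.
relevant_docs: dict[str, dict[str, int]] = defaultdict(dict)
relevant_docs["query-0"]["doc-3"] = 1
relevant_docs["query-0"]["doc-7"] = 0

top_ranked: dict[str, list[str]] = defaultdict(list)
top_ranked["query-0"].extend(["doc-3", "doc-7"])

print(dict(relevant_docs), dict(top_ranked))
```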
@@ -12,7 +12,7 @@ from mteb.abstasks._statistics_calculation import (
  calculate_text_statistics,
  )
  from mteb.abstasks.abstask import AbsTask
- from mteb.models import EncoderProtocol
+ from mteb.models import EncoderProtocol, MTEBModels
  from mteb.types.statistics import (
  ScoreStatistics,
  SplitDescriptiveStatistics,
@@ -77,7 +77,7 @@ class AbsTaskSummarization(AbsTask):

  def _evaluate_subset(
  self,
- model: EncoderProtocol,
+ model: MTEBModels,
  data_split: Dataset,
  *,
  hf_split: str,
@@ -86,8 +86,13 @@ class AbsTaskSummarization(AbsTask):
  prediction_folder: Path | None = None,
  **kwargs,
  ) -> SummarizationMetrics:
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  normalized_scores = [
- (np.array(x) - self.min_score) / (self.max_score - self.min_score)
+ (
+ (np.array(x) - self.min_score) / (self.max_score - self.min_score)
+ ).tolist()
  for x in data_split[self.relevancy_column_name]
  ]
  evaluator = self.evaluator(
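The summarization hunk appends `.tolist()` so the min-max-normalised relevance scores are handed to the evaluator as plain lists rather than numpy arrays. A standalone sketch of that normalisation, assuming a 1-5 human rating scale (the scores are made up):

```python
import numpy as np

min_score, max_score = 1, 5
# One list of human relevance judgements per machine summary (invented values).
relevance = [[1, 3, 5], [2, 2, 4]]

normalized_scores = [
    ((np.array(x) - min_score) / (max_score - min_score)).tolist()
    for x in relevance
]
print(normalized_scores)  # [[0.0, 0.5, 1.0], [0.25, 0.25, 0.75]]
```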
@@ -7,7 +7,7 @@ from datasets import Dataset
  from sklearn import metrics

  from mteb._evaluators import ZeroShotClassificationEvaluator
- from mteb.models import EncoderProtocol
+ from mteb.models import EncoderProtocol, MTEBModels
  from mteb.types.statistics import (
  ImageStatistics,
  LabelStatistics,
@@ -111,7 +111,7 @@ class AbsTaskZeroShotClassification(AbsTask):

  def _evaluate_subset(
  self,
- model: EncoderProtocol,
+ model: MTEBModels,
  data_split: Dataset,
  *,
  hf_split: str,
@@ -120,6 +120,9 @@ class AbsTaskZeroShotClassification(AbsTask):
  prediction_folder: Path | None = None,
  **kwargs,
  ) -> ZeroShotClassificationMetrics:
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  candidate_labels = self.get_candidate_labels()
  data_split = data_split.select_columns(
  [self.input_column_name, self.label_column_name]
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from collections.abc import Iterable, Sequence
+ from collections.abc import Iterator, Sequence
  from dataclasses import dataclass, field
  from typing import TYPE_CHECKING, Literal

@@ -47,7 +47,7 @@ class Benchmark:
  display_name: str | None = None
  language_view: list[str] | Literal["all"] = field(default_factory=list)

- def __iter__(self) -> Iterable[AbsTask]:
+ def __iter__(self) -> Iterator[AbsTask]:
  return iter(self.tasks)

  def __len__(self) -> int:
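The `Benchmark.__iter__` annotation changes from `Iterable` to `Iterator`, which is what `iter(...)` actually returns and what the `__iter__` protocol expects. A minimal dataclass mirroring that shape; the string task names stand in for mteb's `AbsTask` objects:

```python
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field


@dataclass
class Benchmark:
    name: str
    tasks: Sequence[str] = field(default_factory=list)

    def __iter__(self) -> Iterator[str]:
        # iter() returns an Iterator; annotating it as Iterable would be too
        # loose for callers that rely on next().
        return iter(self.tasks)

    def __len__(self) -> int:
        return len(self.tasks)


bench = Benchmark(name="toy-benchmark", tasks=["STS12", "Banking77Classification"])
print(len(bench), list(bench))
```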