mteb 2.7.16__py3-none-any.whl → 2.7.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/_create_dataloaders.py +16 -16
- mteb/_evaluators/any_sts_evaluator.py +1 -1
- mteb/_evaluators/classification_metrics.py +10 -1
- mteb/_evaluators/clustering_evaluator.py +1 -1
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +2 -2
- mteb/_evaluators/pair_classification_evaluator.py +3 -2
- mteb/_evaluators/retrieval_evaluator.py +1 -1
- mteb/_evaluators/retrieval_metrics.py +9 -7
- mteb/_evaluators/sklearn_evaluator.py +13 -6
- mteb/_evaluators/text/bitext_mining_evaluator.py +1 -1
- mteb/_evaluators/text/summarization_evaluator.py +1 -1
- mteb/_evaluators/zeroshot_classification_evaluator.py +1 -1
- mteb/abstasks/_stratification.py +13 -8
- mteb/abstasks/abstask.py +4 -4
- mteb/abstasks/classification.py +6 -4
- mteb/abstasks/clustering.py +1 -1
- mteb/abstasks/clustering_legacy.py +1 -1
- mteb/abstasks/image/image_text_pair_classification.py +1 -1
- mteb/abstasks/multilabel_classification.py +7 -5
- mteb/abstasks/pair_classification.py +1 -1
- mteb/abstasks/regression.py +3 -2
- mteb/abstasks/retrieval.py +8 -5
- mteb/abstasks/retrieval_dataset_loaders.py +27 -8
- mteb/abstasks/sts.py +1 -1
- mteb/abstasks/text/bitext_mining.py +2 -2
- mteb/abstasks/text/reranking.py +1 -1
- mteb/abstasks/text/summarization.py +1 -1
- mteb/abstasks/zeroshot_classification.py +1 -1
- mteb/benchmarks/benchmark.py +131 -3
- mteb/evaluate.py +2 -2
- mteb/leaderboard/figures.py +2 -1
- mteb/leaderboard/table.py +10 -2
- mteb/models/cache_wrappers/cache_backend_protocol.py +3 -3
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +3 -3
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +8 -3
- mteb/models/cache_wrappers/cache_wrapper.py +2 -2
- mteb/models/model_implementations/bedrock_models.py +4 -4
- mteb/models/model_implementations/bm25.py +2 -2
- mteb/models/model_implementations/mcinext_models.py +2 -2
- mteb/models/model_implementations/openai_models.py +2 -1
- mteb/models/model_implementations/pylate_models.py +4 -4
- mteb/models/model_implementations/random_baseline.py +4 -3
- mteb/models/model_implementations/seed_models.py +7 -2
- mteb/models/model_implementations/voyage_models.py +1 -1
- mteb/models/models_protocols.py +2 -2
- mteb/models/search_wrappers.py +4 -4
- mteb/tasks/bitext_mining/multilingual/bible_nlp_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/flores_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_conv_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_gen_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ntrex_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/roma_tales_bitext_mining.py +1 -1
- mteb/tasks/classification/ben/bengali_document_classification.py +2 -2
- mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/hin_dialect_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/language_classification.py +1 -1
- mteb/tasks/classification/multilingual/south_african_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/turkic_classification.py +1 -1
- mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/swa/swahili_news_classification.py +2 -2
- mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py +1 -1
- mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py +1 -1
- mteb/tasks/clustering/multilingual/mlsum_clustering_p2p.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_s2s.py +2 -2
- mteb/tasks/clustering/nob/vg_hierarchical_clustering.py +2 -2
- mteb/tasks/image_text_pair_classification/eng/image_co_de.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/instruction_reranking/multilingual/m_follow_ir.py +2 -2
- mteb/tasks/multichoice/eng/cv_bench.py +4 -4
- mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_smilespc.py +1 -1
- mteb/tasks/pair_classification/multilingual/pub_chem_wiki_pair_classification.py +1 -1
- mteb/tasks/pair_classification/multilingual/rte3.py +1 -1
- mteb/tasks/retrieval/ara/sadeem_question_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_edit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +8 -8
- mteb/tasks/retrieval/code/code_search_net_cc_retrieval.py +1 -1
- mteb/tasks/retrieval/code/coir_code_search_net_retrieval.py +1 -1
- mteb/tasks/retrieval/code/ds1000_retrieval.py +1 -1
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +1 -1
- mteb/tasks/retrieval/code/human_eval_retrieval.py +1 -1
- mteb/tasks/retrieval/code/mbpp_retrieval.py +1 -1
- mteb/tasks/retrieval/code/wiki_sql_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +2 -2
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
- mteb/tasks/retrieval/deu/german_gov_service_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/ell/greek_civics_qa.py +1 -1
- mteb/tasks/retrieval/eng/bright_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/chat_doctor_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/fin_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/finance_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hc3_finance_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_needle_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_passkey_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_summ_screen_fd_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_wikim_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lembqm_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/ml_questions.py +1 -1
- mteb/tasks/retrieval/eng/nano_argu_ana_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_climate_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_db_pedia_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fi_qa2018_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_hotpot_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_msmarco_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nf_corpus_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nq_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_quora_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_sci_fact_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_scidocs_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_touche2020_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/r2_med_retrieval.py +8 -8
- mteb/tasks/retrieval/eng/sci_mmir_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/sci_mmir_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +10 -10
- mteb/tasks/retrieval/fra/f_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/fra/syntec_retrieval.py +1 -1
- mteb/tasks/retrieval/hun/hun_sum2.py +1 -1
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt19.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt21.py +1 -1
- mteb/tasks/retrieval/multilingual/cur_ev1_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/miracl_vision_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/mr_tidy_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -2
- mteb/tasks/retrieval/multilingual/statcan_dialogue_dataset_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vdr_multilingual_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +5 -5
- mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +1 -0
- mteb/tasks/retrieval/multilingual/wit_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_flickr30k_co_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/xm3600_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_android_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_english_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gaming_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gis_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_mathematica_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_physics_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_programmers_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_stats_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_tex_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_unix_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_webmasters_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_wordpress_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nob/norquad.py +2 -2
- mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/vie/vie_qu_ad_retrieval.py +1 -1
- mteb/tasks/sts/multilingual/sem_rel24_sts.py +1 -1
- mteb/tasks/sts/multilingual/sts_benchmark_multilingual_sts.py +1 -1
- mteb/tasks/sts/por/assin2_sts.py +1 -1
- mteb/types/_encoder_io.py +3 -2
- {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/METADATA +1 -1
- {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/RECORD +173 -173
- {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/WHEEL +0 -0
- {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/entry_points.txt +0 -0
- {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/top_level.txt +0 -0
mteb/abstasks/retrieval_dataset_loaders.py
CHANGED

@@ -78,7 +78,7 @@ class RetrievalDatasetLoader:
     def load(
         self,
-        num_proc: int =
+        num_proc: int | None = None,
     ) -> RetrievalSplitData:
         """Loads the dataset split for the specified configuration.

@@ -128,7 +128,11 @@ class RetrievalDatasetLoader:
             f"Split {self.split} not found in {splits}. Please specify a valid split."
         )

-    def _load_dataset_split(
+    def _load_dataset_split(
+        self,
+        config: str,
+        num_proc: int | None,
+    ) -> Dataset:
         return load_dataset(
             self.hf_repo,
             config,
@@ -138,7 +142,10 @@ class RetrievalDatasetLoader:
             num_proc=num_proc,
         )

-    def _load_corpus(
+    def _load_corpus(
+        self,
+        num_proc: int | None,
+    ) -> CorpusDatasetType:
         config = f"{self.config}-corpus" if self.config is not None else "corpus"
         logger.info("Loading corpus subset: %s", config)

@@ -151,7 +158,10 @@ class RetrievalDatasetLoader:
         logger.debug("Doc Example: %s", corpus_ds[0])
         return corpus_ds

-    def _load_queries(
+    def _load_queries(
+        self,
+        num_proc: int | None,
+    ) -> QueryDatasetType:
         config = f"{self.config}-queries" if self.config is not None else "queries"
         logger.info("Loading queries subset: %s", config)

@@ -168,7 +178,10 @@ class RetrievalDatasetLoader:

         return queries_ds

-    def _load_qrels(
+    def _load_qrels(
+        self,
+        num_proc: int | None,
+    ) -> RelevantDocumentsType:
         config = f"{self.config}-qrels" if self.config is not None else "default"

         logger.info("Loading qrels subset: %s", config)
@@ -203,7 +216,10 @@ class RetrievalDatasetLoader:
         logger.info("Loaded %d %s qrels.", len(qrels_dict), self.split.upper())
         return qrels_dict

-    def _load_top_ranked(
+    def _load_top_ranked(
+        self,
+        num_proc: int | None,
+    ) -> TopRankedDocumentsType:
         config = (
             f"{self.config}-top_ranked" if self.config is not None else "top_ranked"
         )
@@ -226,7 +242,10 @@ class RetrievalDatasetLoader:
         logger.info(f"Top ranked loaded: {len(top_ranked_ds)}")
         return top_ranked_dict

-    def _load_instructions(
+    def _load_instructions(
+        self,
+        num_proc: int | None,
+    ) -> InstructionDatasetType:
         config = (
             f"{self.config}-instruction" if self.config is not None else "instruction"
         )
@@ -246,7 +265,7 @@ class RetrievalDatasetLoader:
     def _combine_queries_with_instructions_datasets(
         queries_dataset: QueryDatasetType,
         instruction_dataset: InstructionDatasetType | dict[str, str],
-        num_proc: int,
+        num_proc: int | None,
     ) -> Dataset:
         if isinstance(instruction_dataset, Dataset):
            instruction_to_query_idx = {
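The recurring change in this release relaxes `num_proc` from `int` to `int | None` with a `None` default, which is then forwarded to Hugging Face `datasets`. A minimal sketch of the forwarding pattern, assuming a standalone loader function rather than mteb's actual `RetrievalDatasetLoader` class:

    from datasets import Dataset, load_dataset

    def load_split(
        hf_repo: str, config: str, split: str, num_proc: int | None = None
    ) -> Dataset:
        # None (the new default) keeps loading single-process, matching the
        # default of datasets.load_dataset; an int fans work out to that
        # many worker processes.
        return load_dataset(hf_repo, config, split=split, num_proc=num_proc)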
mteb/abstasks/text/bitext_mining.py
CHANGED

@@ -82,7 +82,7 @@ class AbsTaskBitextMining(AbsTask):
         *,
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
-        num_proc: int =
+        num_proc: int | None = None,
         **kwargs: Any,
     ) -> dict[HFSubset, ScoresDict]:
         """Added load for "parallel" datasets"""
@@ -155,7 +155,7 @@ class AbsTaskBitextMining(AbsTask):
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
         parallel: bool = False,
-        num_proc: int =
+        num_proc: int | None = None,
         **kwargs,
     ) -> BitextMiningMetrics | dict[str, BitextMiningMetrics]:
         pairs = self._get_pairs(parallel)
mteb/abstasks/text/reranking.py
CHANGED

@@ -34,7 +34,7 @@ class AbsTaskReranking(AbsTaskRetrieval):
     For dataformat and other information, see [AbsTaskRetrieval][mteb.abstasks.retrieval.AbsTaskRetrieval].
     """

-    def load_data(self, num_proc: int =
+    def load_data(self, num_proc: int | None = None, **kwargs) -> None:
         """Load the dataset."""
         if self.data_loaded:
             return

mteb/abstasks/text/summarization.py
CHANGED

@@ -94,7 +94,7 @@ class AbsTaskSummarization(AbsTask):
         hf_subset: str,
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
-        num_proc: int =
+        num_proc: int | None = None,
         **kwargs,
     ) -> SummarizationMetrics:
         if not isinstance(model, EncoderProtocol):

mteb/abstasks/zeroshot_classification.py
CHANGED

@@ -127,7 +127,7 @@ class AbsTaskZeroShotClassification(AbsTask):
         hf_subset: str,
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
-        num_proc: int =
+        num_proc: int | None = None,
         **kwargs,
     ) -> ZeroShotClassificationMetrics:
         if not isinstance(model, EncoderProtocol):
mteb/benchmarks/benchmark.py
CHANGED

@@ -164,14 +164,142 @@ class MIEBBenchmark(Benchmark):
 class VidoreBenchmark(Benchmark):
     """Wrapper for Vidore3 benchmark."""

-    def
+    def _create_vidore_summary_table(
         self, benchmark_results: BenchmarkResults
     ) -> pd.DataFrame:
+        """Create summary table from BenchmarkResults.
+
+        Returns a DataFrame with one row per model containing summary statistics
+        and task type averages. Customized for Vidore benchmark.
+
+        Args:
+            benchmark_results: BenchmarkResults object containing model results
+
+        Returns:
+            DataFrame with model summaries, ready for styling in the leaderboard
+        """
+        import mteb
         from mteb.benchmarks._create_table import (
-
+            _format_max_tokens,
+            _format_n_parameters,
+            _get_means_per_types,
+            _split_on_capital,
+        )
+        from mteb.get_tasks import get_task
+
+        data = benchmark_results.to_dataframe(format="long")
+
+        if data.empty:
+            no_results_frame = pd.DataFrame(
+                {"No results": ["You can try relaxing your criteria"]}
+            )
+            return no_results_frame
+        public_task_name = benchmark_results._filter_tasks(is_public=True).task_names
+        private_task_name = benchmark_results._filter_tasks(is_public=False).task_names
+        # Convert to DataFrame and pivot
+        per_task = data.pivot(index="model_name", columns="task_name", values="score")
+
+        # Remove models with no scores
+        to_remove = per_task.isna().all(axis="columns")
+        if to_remove.all():
+            no_results_frame = pd.DataFrame(
+                {"No results": ["You can try relaxing your criteria"]}
+            )
+            return no_results_frame
+
+        models_to_remove = list(per_task[to_remove].index)
+        per_task = per_task.drop(models_to_remove, axis=0)
+
+        # Calculate means by task type
+        mean_per_type = _get_means_per_types(per_task)
+        mean_per_type = mean_per_type.pivot(
+            index="model_name", columns="task_type", values="score"
+        )
+        mean_per_type.columns = [
+            _split_on_capital(column) for column in mean_per_type.columns
+        ]
+
+        # Calculate overall means
+        public_mean = per_task[public_task_name].mean(skipna=False, axis=1)
+        private_mean = per_task[private_task_name].mean(skipna=False, axis=1)
+
+        # Build joint table
+        joint_table = mean_per_type.copy()
+        joint_table.insert(1, "mean(public)", public_mean)
+        joint_table.insert(2, "mean(private)", private_mean)
+        task_type = get_task(
+            per_task.columns[0]
+        ).metadata.type  # "DocumentUnderstanding"
+        joint_table = joint_table.sort_values(
+            [_split_on_capital(task_type), "mean(public)", "mean(private)"],
+            ascending=False,
+        )
+
+        joint_table = joint_table.reset_index()
+
+        # Add model metadata
+        model_metas = joint_table["model_name"].map(mteb.get_model_meta)
+        joint_table = joint_table[model_metas.notna()]
+        joint_table["model_link"] = model_metas.map(lambda m: m.reference)
+
+        # Insert model metadata columns
+        joint_table.insert(
+            1,
+            "Max Tokens",
+            model_metas.map(lambda m: _format_max_tokens(m.max_tokens)),
+        )
+        joint_table.insert(
+            1,
+            "Embedding Dimensions",
+            model_metas.map(lambda m: int(m.embed_dim) if m.embed_dim else None),
+        )
+        joint_table.insert(
+            1,
+            "Number of Parameters (B)",
+            model_metas.map(lambda m: _format_n_parameters(m.n_parameters)),
+        )
+        joint_table.insert(
+            1,
+            "Memory Usage (MB)",
+            model_metas.map(
+                lambda m: int(m.memory_usage_mb) if m.memory_usage_mb else None
+            ),
+        )
+
+        # Clean up model names (remove HF organization)
+        joint_table["model_name"] = joint_table["model_name"].map(
+            lambda name: name.split("/")[-1]
+        )
+
+        # Add markdown links to model names
+        name_w_link = (
+            "[" + joint_table["model_name"] + "](" + joint_table["model_link"] + ")"
+        )
+        joint_table["model_name"] = joint_table["model_name"].mask(
+            joint_table["model_link"].notna(), name_w_link
+        )
+        joint_table = joint_table.drop(columns=["model_link"])
+
+        # Rename columns
+        rename_dict = {
+            "model_name": "Model",
+            "mean(public)": "Mean (Public)",
+            "mean(private)": "Mean (Private)",
+        }
+
+        joint_table = joint_table.rename(columns=rename_dict)
+
+        # Add Rank column
+        joint_table.insert(
+            0, "Rank (Mean Task)", [i + 1 for i in range(len(joint_table))]
         )

-        joint_table
+        return joint_table
+
+    def _create_summary_table(
+        self, benchmark_results: BenchmarkResults
+    ) -> pd.DataFrame:
+        joint_table = self._create_vidore_summary_table(benchmark_results)
         # For ViDoRe (V1, V2, V3): all tasks are Document Understanding type, so Document Understanding column = Mean (Task)
         joint_table = joint_table.rename(
             columns={"Document Understanding": "Mean (Task)"}
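The refactor splits the table-building logic out of `_create_summary_table` into `_create_vidore_summary_table`, so benchmark variants can reuse the shared builder and only adjust columns afterwards. A hedged sketch of that reuse; the subclass name below is illustrative, not a class that exists in mteb:

    class Vidore2Benchmark(VidoreBenchmark):  # hypothetical subclass
        def _create_summary_table(
            self, benchmark_results: BenchmarkResults
        ) -> pd.DataFrame:
            # Reuse the shared Vidore builder, then apply variant-specific renames.
            table = self._create_vidore_summary_table(benchmark_results)
            return table.rename(columns={"Document Understanding": "Mean (Task)"})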
mteb/evaluate.py
CHANGED

@@ -91,7 +91,7 @@ def _evaluate_task(
     encode_kwargs: EncodeKwargs,
     prediction_folder: Path | None,
     public_only: bool | None,
-    num_proc: int =
+    num_proc: int | None = None,
 ) -> TaskResult | TaskError:
     """The core logic to run a model on a given task. See `evaluate` for more details.

@@ -282,7 +282,7 @@ def evaluate(
     prediction_folder: Path | str | None = None,
     show_progress_bar: bool = True,
     public_only: bool | None = None,
-    num_proc: int =
+    num_proc: int | None = None,
 ) -> ModelResult:
     """This function runs a model on a given task and returns the results.

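With the new signature, callers of `mteb.evaluate` can omit `num_proc` entirely or pass an explicit worker count; `None` flows unchanged down to the dataset loaders above. A small usage sketch, assuming a model and task available in mteb's registry:

    import mteb

    model = mteb.get_model("sentence-transformers/all-MiniLM-L6-v2")
    tasks = mteb.get_tasks(tasks=["NFCorpus"])

    # num_proc=None (the new default) loads datasets in a single process;
    # pass an int to parallelize dataset preparation.
    results = mteb.evaluate(model, tasks, num_proc=None)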
mteb/leaderboard/figures.py
CHANGED

@@ -125,6 +125,7 @@ def _performance_size_plot(df: pd.DataFrame) -> go.Figure:
     min_score, max_score = df["Mean (Task)"].min(), df["Mean (Task)"].max()
     df["sqrt(dim)"] = np.sqrt(df["Embedding Dimensions"])
     df["Max Tokens"] = df["Max Tokens"].apply(lambda x: _process_max_tokens(x))
+    rank_column = "Rank (Borda)" if "Rank (Borda)" in df.columns else "Rank (Mean Task)"
     fig = px.scatter(
         df,
         x="Number of Parameters",
@@ -141,7 +142,7 @@ def _performance_size_plot(df: pd.DataFrame) -> go.Figure:
             "Embedding Dimensions": True,
             "Number of Parameters": True,
             "Mean (Task)": True,
-
+            rank_column: True,
             "Log(Tokens)": False,
             "sqrt(dim)": False,
             "model_text": False,
mteb/leaderboard/table.py
CHANGED

@@ -156,6 +156,7 @@ def _apply_summary_table_styling(joint_table: pd.DataFrame) -> gr.DataFrame:
     """
     excluded_columns = [
         "Rank (Borda)",
+        "Rank (Mean Task)",
         "Rank",
         "Model",
         "Number of Parameters (B)",
@@ -183,10 +184,17 @@ def _apply_summary_table_styling(joint_table: pd.DataFrame) -> gr.DataFrame:
     joint_table["Zero-shot"] = joint_table["Zero-shot"].apply(_format_zero_shot)
     joint_table[score_columns] = joint_table[score_columns].map(_format_scores)

+    if "Rank (Borda)" in joint_table.columns:
+        rank_column = "Rank (Borda)"
+    elif "Rank (Mean Task)" in joint_table.columns:
+        rank_column = "Rank (Mean Task)"
+    else:
+        raise ValueError("No rank column found in the result table.")
+
     joint_table_style = joint_table.style.format(
         {
             **dict.fromkeys(score_columns, "{:.2f}"),
-
+            rank_column: "{:.0f}",
             "Memory Usage (MB)": "{:.0f}",
             "Embedding Dimensions": "{:.0f}",
             "Max Tokens": "{:.0f}",
@@ -195,7 +203,7 @@ def _apply_summary_table_styling(joint_table: pd.DataFrame) -> gr.DataFrame:
         na_rep="",
     )
     joint_table_style = joint_table_style.highlight_min(
-
+        rank_column, props="font-weight: bold"
     ).highlight_max(subset=score_columns, props="font-weight: bold")

     # Apply background gradients for each selected column
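The same fallback appears in both leaderboard modules: prefer the `Rank (Borda)` column when the table has one, otherwise use the new `Rank (Mean Task)` column. A self-contained sketch of the pattern against a toy frame:

    import pandas as pd

    df = pd.DataFrame({"Rank (Mean Task)": [1, 2], "Mean (Task)": [71.3, 69.8]})

    # Prefer the Borda rank when present, else fall back to the mean-task rank.
    rank_column = "Rank (Borda)" if "Rank (Borda)" in df.columns else "Rank (Mean Task)"
    styled = df.style.format({rank_column: "{:.0f}"}).highlight_min(
        rank_column, props="font-weight: bold"  # bold the best (lowest) rank
    )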
mteb/models/cache_wrappers/cache_backend_protocol.py
CHANGED

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
 if TYPE_CHECKING:
     from pathlib import Path

-
+    from mteb.types import Array


 @runtime_checkable
@@ -26,7 +26,7 @@ class CacheBackendProtocol(Protocol):
         **kwargs: Additional backend-specific arguments.
     """

-    def add(self, item: list[dict[str, Any]], vectors:
+    def add(self, item: list[dict[str, Any]], vectors: Array) -> None:
         """Add a vector to the cache.

         Args:
@@ -34,7 +34,7 @@ class CacheBackendProtocol(Protocol):
             vectors: Embedding vector of shape (dim,) or (1, dim).
         """

-    def get_vector(self, item: dict[str, Any]) ->
+    def get_vector(self, item: dict[str, Any]) -> Array | None:
         """Retrieve the cached vector for the given item.

         Args:
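Since `CacheBackendProtocol` is `@runtime_checkable`, any object exposing `add` and `get_vector` methods passes an `isinstance` check (runtime protocol checks look at method presence, not signatures). A minimal illustrative backend, not part of mteb, that would satisfy such a check:

    from typing import Any

    import numpy as np

    class InMemoryCache:  # hypothetical backend for illustration
        def __init__(self) -> None:
            self._store: dict[str, np.ndarray] = {}

        def add(self, item: list[dict[str, Any]], vectors: np.ndarray) -> None:
            # Key each entry by a stable string rendering of its fields.
            for entry, vec in zip(item, vectors):
                self._store[repr(sorted(entry.items()))] = vec

        def get_vector(self, item: dict[str, Any]) -> np.ndarray | None:
            return self._store.get(repr(sorted(item.items())))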
mteb/models/cache_wrappers/cache_backends/faiss_cache.py
CHANGED

@@ -15,7 +15,7 @@ from ._hash_utils import _hash_item
 if TYPE_CHECKING:
     import faiss

-    from mteb.types import BatchedInput
+    from mteb.types import Array, BatchedInput

 logger = logging.getLogger(__name__)

@@ -43,7 +43,7 @@ class FaissCache:
         logger.info(f"Initialized FAISS VectorCacheMap in {self.directory}")
         self.load()

-    def add(self, items: list[dict[str, Any]], vectors:
+    def add(self, items: list[dict[str, Any]], vectors: Array) -> None:
         """Add vector to FAISS index."""
         import faiss

@@ -67,7 +67,7 @@ class FaissCache:
         vectors_array = np.vstack(vectors_to_add).astype(np.float32)
         self.index.add(vectors_array)

-    def get_vector(self, item:
+    def get_vector(self, item: dict[str, Any]) -> Array | None:
         """Retrieve vector from index by hash."""
         if self.index is None:
             return None
mteb/models/cache_wrappers/cache_backends/numpy_cache.py
CHANGED

@@ -1,13 +1,18 @@
+from __future__ import annotations
+
 import json
 import logging
 import warnings
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import numpy as np

 from ._hash_utils import _hash_item

+if TYPE_CHECKING:
+    from mteb.types import Array
+
 logger = logging.getLogger(__name__)


@@ -27,7 +32,7 @@ class NumpyCache:
         logger.info(f"Initialized VectorCacheMap in directory: {self.directory}")
         self._initialize_vectors_file()

-    def add(self, items: list[dict[str, Any]], vectors:
+    def add(self, items: list[dict[str, Any]], vectors: Array) -> None:
         """Add a vector to the cache."""
         try:
             if self.vector_dim is None:
@@ -178,7 +183,7 @@ class NumpyCache:
             logger.error(f"Error loading VectorCacheMap: {str(e)}")
             raise

-    def get_vector(self, item: dict[str, Any]) ->
+    def get_vector(self, item: dict[str, Any]) -> Array | None:
         """Retrieve vector from index by hash."""
         if self.vectors is None:
             return None
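Both cache modules adopt the same typing idiom: `from __future__ import annotations` postpones annotation evaluation, so `Array` can be imported only under `TYPE_CHECKING` and costs nothing at runtime. A condensed sketch of the idiom in isolation:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only type checkers resolve this import; it never executes at runtime.
        from mteb.types import Array

    def get_vector() -> Array | None:
        # The annotation above is stored as a string, so no runtime import is needed.
        return None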
mteb/models/cache_wrappers/cache_wrapper.py
CHANGED

@@ -98,7 +98,7 @@ class CachedEmbeddingWrapper:
         uncached_items: list[dict[str, Any]] = []
         uncached_indices: list[int] = []
         all_items: Dataset = inputs.dataset
-        cached_vectors: dict[int,
+        cached_vectors: dict[int, Array] = {}

         for i, item in enumerate(all_items):
             vector = cache.get_vector(item)
@@ -108,7 +108,7 @@ class CachedEmbeddingWrapper:
                 uncached_items.append(item)
                 uncached_indices.append(i)

-        newly_encoded: dict[int,
+        newly_encoded: dict[int, Array] = {}
         if uncached_items:
             logger.info(f"Encoding {len(uncached_items)} new items")
             # Build a simple DataLoader with only uncached items
mteb/models/model_implementations/bedrock_models.py
CHANGED

@@ -86,7 +86,7 @@ class BedrockModel(AbsEncoder):

     def _encode_amazon(
         self, sentences: list[str], show_progress_bar: bool = False
-    ) ->
+    ) -> Array:
         from botocore.exceptions import ValidationError

         all_embeddings = []
@@ -125,7 +125,7 @@ class BedrockModel(AbsEncoder):
         sentences: list[str],
         cohere_task_type: str,
         show_progress_bar: bool = False,
-    ) ->
+    ) -> Array:
         batches = [
             sentences[i : i + self._max_batch_size]
             for i in range(0, len(sentences), self._max_batch_size)
@@ -149,7 +149,7 @@ class BedrockModel(AbsEncoder):

         return np.array(all_embeddings)

-    def _embed_amazon(self, sentence: str) ->
+    def _embed_amazon(self, sentence: str) -> Array:
         response = self._client.invoke_model(
             body=json.dumps({"inputText": sentence}),
             modelId=self._model_id,
@@ -158,7 +158,7 @@ class BedrockModel(AbsEncoder):
         )
         return self._to_numpy(response)

-    def _to_numpy(self, embedding_response) ->
+    def _to_numpy(self, embedding_response) -> Array:
         response = json.loads(embedding_response.get("body").read())
         key = "embedding" if self._provider == "amazon" else "embeddings"
         return np.array(response[key])
mteb/models/model_implementations/bm25.py
CHANGED

@@ -54,7 +54,7 @@ def bm25_loader(model_name, **kwargs) -> SearchProtocol:
         hf_split: str,
         hf_subset: str,
         encode_kwargs: EncodeKwargs,
-        num_proc: int =
+        num_proc: int | None = None,
     ) -> None:
         logger.info("Encoding Corpus...")
         corpus_texts = [
@@ -81,7 +81,7 @@ def bm25_loader(model_name, **kwargs) -> SearchProtocol:
         top_k: int,
         encode_kwargs: EncodeKwargs,
         top_ranked: TopRankedDocumentsType | None = None,
-        num_proc: int =
+        num_proc: int | None = None,
     ) -> RetrievalOutputType:
         logger.info("Encoding Queries...")
         query_ids = list(queries["id"])
mteb/models/model_implementations/mcinext_models.py
CHANGED

@@ -13,7 +13,7 @@ from mteb.models.abs_encoder import AbsEncoder
 from mteb.models.model_meta import ModelMeta

 if TYPE_CHECKING:
-    from mteb.types import PromptType
+    from mteb.types import Array, PromptType
 logger = logging.getLogger(__name__)

 HAKIM_CITATION = """@article{sarmadi2025hakim,
@@ -302,7 +302,7 @@ class HakimModelWrapper(AbsEncoder):
         prompt_type: PromptType | None = None,
         batch_size: int = 32,
         **kwargs: Any,
-    ) ->
+    ) -> Array:
         """Encodes sentences using the API.

         Returns:
mteb/models/model_implementations/openai_models.py
CHANGED

@@ -11,6 +11,7 @@ from mteb.models.abs_encoder import AbsEncoder
 from mteb.models.model_meta import ModelMeta, ScoringFunction

 if TYPE_CHECKING:
+    from numpy.typing import NDArray
     from torch.utils.data import DataLoader

     from mteb.abstasks.task_metadata import TaskMetadata
@@ -166,7 +167,7 @@ class OpenAIModel(AbsEncoder):
         all_embeddings[mask] = no_empty_embeddings
         return all_embeddings

-    def _to_numpy(self, embedding_response) -> np.
+    def _to_numpy(self, embedding_response) -> NDArray[np.floating]:
         return np.array([e.embedding for e in embedding_response.data])

mteb/models/model_implementations/pylate_models.py
CHANGED

@@ -53,7 +53,7 @@ class PylateSearchEncoder:
         hf_split: str,
         hf_subset: str,
         encode_kwargs: EncodeKwargs,
-        num_proc: int,
+        num_proc: int | None,
     ) -> None:
         """Index the corpus for retrieval.

@@ -89,7 +89,7 @@ class PylateSearchEncoder:
         top_k: int,
         encode_kwargs: EncodeKwargs,
         top_ranked: TopRankedDocumentsType | None = None,
-        num_proc: int,
+        num_proc: int | None,
     ) -> RetrievalOutputType:
         queries_dataloader = create_dataloader(
             queries,
@@ -150,7 +150,7 @@ class PylateSearchEncoder:
         hf_split: str,
         top_k: int,
         encode_kwargs: EncodeKwargs,
-        num_proc: int,
+        num_proc: int | None,
     ) -> dict[str, list[tuple[float, str]]]:
         from pylate import indexes, retrieve

@@ -216,7 +216,7 @@ class PylateSearchEncoder:
         hf_subset: str,
         hf_split: str,
         encode_kwargs: EncodeKwargs,
-        num_proc: int =
+        num_proc: int | None = None,
     ) -> dict[str, list[tuple[float, str]]]:
         """Rerank with PyLate's rank.rerank using per-query candidates.

mteb/models/model_implementations/random_baseline.py
CHANGED

@@ -13,6 +13,7 @@ from mteb.similarity_functions import (
 )

 if TYPE_CHECKING:
+    from numpy.typing import NDArray
     from PIL import Image
     from torch.utils.data import DataLoader

@@ -20,7 +21,7 @@ if TYPE_CHECKING:
     from mteb.types._encoder_io import Array, BatchedInput, PromptType


-def _string_to_vector(text: str | None, size: int) -> np.ndarray:
+def _string_to_vector(text: str | None, size: int) -> NDArray[np.floating]:
     """Generate a deterministic random vector based on a string.

     Args:
@@ -39,7 +40,7 @@ def _string_to_vector(text: str | None, size: int) -> np.ndarray:
     return rng.random(size, dtype=np.float32)


-def _image_to_vector(image: Image.Image, size: int) -> np.
+def _image_to_vector(image: Image.Image, size: int) -> NDArray[np.floating]:
     """Generate a deterministic random vector based on image content.

     Args:
@@ -80,7 +81,7 @@ _common_mock_metadata = dict(

 def _batch_to_embeddings(
     inputs: DataLoader[BatchedInput], embedding_dim: int
-) -> np.
+) -> NDArray[np.floating]:
     """Convert batched text/image inputs into embeddings.

     Args:
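The random baseline's `_string_to_vector` derives a deterministic vector from the input text, so identical strings always embed identically across calls. A hedged sketch of that idea; the byte-based seeding shown here is illustrative, and mteb's actual seeding may differ:

    import numpy as np

    def string_to_vector(text: str | None, size: int) -> np.ndarray:
        # Seed the generator from the text's bytes: same text, same vector.
        seed = list((text or " ").encode("utf-8"))
        rng = np.random.default_rng(seed)
        return rng.random(size, dtype=np.float32)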
mteb/models/model_implementations/seed_models.py
CHANGED

@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import logging
 import time
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import numpy as np
 import torch
@@ -14,6 +16,9 @@ from mteb.types import PromptType
 from .bge_models import bge_chinese_training_data
 from .nvidia_models import nvidia_training_datasets

+if TYPE_CHECKING:
+    from mteb.types import Array
+
 logger = logging.getLogger(__name__)


@@ -110,7 +115,7 @@ class SeedTextEmbeddingModel(AbsEncoder):
         prompt_type: PromptType | None = None,
         retries: int = 5,
         **kwargs: Any,
-    ) ->
+    ) -> Array:
         trimmed_sentences = []
         for sentence in sentences:
             encoded_sentence = self._encoding.encode(sentence)