mteb 2.1.4__py3-none-any.whl → 2.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/__init__.py +6 -0
- mteb/_create_dataloaders.py +22 -20
- mteb/_evaluators/any_sts_evaluator.py +23 -14
- mteb/_evaluators/classification_metrics.py +54 -0
- mteb/_evaluators/clustering_evaluator.py +3 -3
- mteb/_evaluators/evaluator.py +4 -2
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +18 -11
- mteb/_evaluators/pair_classification_evaluator.py +34 -40
- mteb/_evaluators/retrieval_evaluator.py +2 -2
- mteb/_evaluators/retrieval_metrics.py +18 -17
- mteb/_evaluators/sklearn_evaluator.py +25 -37
- mteb/_evaluators/text/bitext_mining_evaluator.py +31 -19
- mteb/_evaluators/text/summarization_evaluator.py +27 -20
- mteb/_evaluators/zeroshot_classification_evaluator.py +7 -5
- mteb/abstasks/_data_filter/__init__.py +0 -0
- mteb/abstasks/_data_filter/filters.py +125 -0
- mteb/abstasks/_data_filter/task_pipelines.py +105 -0
- mteb/abstasks/_statistics_calculation.py +23 -11
- mteb/abstasks/_stratification.py +18 -18
- mteb/abstasks/abstask.py +35 -28
- mteb/abstasks/aggregate_task_metadata.py +1 -9
- mteb/abstasks/aggregated_task.py +10 -29
- mteb/abstasks/classification.py +15 -12
- mteb/abstasks/clustering.py +20 -16
- mteb/abstasks/clustering_legacy.py +13 -10
- mteb/abstasks/image/image_text_pair_classification.py +7 -4
- mteb/abstasks/multilabel_classification.py +33 -22
- mteb/abstasks/pair_classification.py +27 -11
- mteb/abstasks/regression.py +4 -4
- mteb/abstasks/retrieval.py +28 -24
- mteb/abstasks/retrieval_dataset_loaders.py +2 -2
- mteb/abstasks/sts.py +14 -4
- mteb/abstasks/task_metadata.py +32 -33
- mteb/abstasks/text/bitext_mining.py +39 -28
- mteb/abstasks/text/reranking.py +8 -6
- mteb/abstasks/text/summarization.py +10 -5
- mteb/abstasks/zeroshot_classification.py +8 -4
- mteb/benchmarks/_create_table.py +84 -37
- mteb/benchmarks/benchmark.py +77 -16
- mteb/benchmarks/benchmarks/__init__.py +12 -0
- mteb/benchmarks/benchmarks/benchmarks.py +361 -16
- mteb/benchmarks/get_benchmark.py +14 -53
- mteb/cache.py +227 -37
- mteb/cli/_display_tasks.py +2 -2
- mteb/cli/build_cli.py +110 -14
- mteb/cli/generate_model_card.py +43 -23
- mteb/deprecated_evaluator.py +71 -62
- mteb/descriptive_stats/BitextMining/RuSciBenchBitextMining.v2.json +61 -0
- mteb/descriptive_stats/Classification/HebrewSentimentAnalysis.v3.json +60 -0
- mteb/descriptive_stats/Classification/TurkishConstitutionalCourtViolation.json +54 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3ComputerScienceRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3EnergyRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3FinanceEnRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3FinanceFrRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3HrRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3IndustrialRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3NuclearRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3PharmaceuticalsRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3PhysicsRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3TelecomRetrieval.json +214 -0
- mteb/descriptive_stats/PairClassification/TERRa.V2.json +35 -0
- mteb/descriptive_stats/Reranking/JQaRARerankingLite.json +35 -0
- mteb/descriptive_stats/Reranking/JaCWIRRerankingLite.json +35 -0
- mteb/descriptive_stats/Reranking/MultiLongDocReranking.json +466 -0
- mteb/descriptive_stats/Retrieval/ArguAna-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
- mteb/descriptive_stats/Retrieval/JaCWIRRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/JaqketRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/MIRACLJaRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/MrTyDiJaRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/NFCorpus-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/SCIDOCS-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/SQuADKorV1Retrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/SciFact-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
- mteb/evaluate.py +106 -75
- mteb/filter_tasks.py +25 -26
- mteb/get_tasks.py +29 -30
- mteb/languages/language_scripts.py +5 -3
- mteb/leaderboard/app.py +414 -151
- mteb/leaderboard/benchmark_selector.py +14 -5
- mteb/leaderboard/figures.py +13 -15
- mteb/leaderboard/table.py +82 -17
- mteb/load_results.py +12 -12
- mteb/models/__init__.py +4 -1
- mteb/models/abs_encoder.py +31 -23
- mteb/models/cache_wrappers/__init__.py +2 -1
- mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +7 -6
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
- mteb/models/cache_wrappers/cache_wrapper.py +3 -3
- mteb/models/get_model_meta.py +25 -118
- mteb/models/instruct_wrapper.py +33 -9
- mteb/models/model_implementations/align_models.py +8 -1
- mteb/models/model_implementations/amazon_models.py +1 -0
- mteb/models/model_implementations/andersborges.py +65 -0
- mteb/models/model_implementations/ara_models.py +9 -1
- mteb/models/model_implementations/arctic_models.py +16 -8
- mteb/models/model_implementations/b1ade_models.py +2 -1
- mteb/models/model_implementations/bedrock_models.py +4 -0
- mteb/models/model_implementations/bge_models.py +101 -17
- mteb/models/model_implementations/bica_model.py +35 -0
- mteb/models/model_implementations/blip2_models.py +13 -2
- mteb/models/model_implementations/blip_models.py +43 -16
- mteb/models/model_implementations/bm25.py +5 -4
- mteb/models/model_implementations/bmretriever_models.py +10 -4
- mteb/models/model_implementations/cadet_models.py +10 -1
- mteb/models/model_implementations/cde_models.py +25 -4
- mteb/models/model_implementations/clip_models.py +9 -6
- mteb/models/model_implementations/clips_models.py +100 -0
- mteb/models/model_implementations/codefuse_models.py +165 -3
- mteb/models/model_implementations/codesage_models.py +18 -3
- mteb/models/model_implementations/cohere_models.py +13 -6
- mteb/models/model_implementations/cohere_v.py +7 -2
- mteb/models/model_implementations/colpali_models.py +17 -9
- mteb/models/model_implementations/colqwen_models.py +275 -5
- mteb/models/model_implementations/colsmol_models.py +4 -2
- mteb/models/model_implementations/conan_models.py +2 -1
- mteb/models/model_implementations/dino_models.py +194 -23
- mteb/models/model_implementations/e5_instruct.py +27 -4
- mteb/models/model_implementations/e5_models.py +21 -110
- mteb/models/model_implementations/e5_v.py +7 -6
- mteb/models/model_implementations/eagerworks_models.py +164 -0
- mteb/models/model_implementations/emillykkejensen_models.py +91 -0
- mteb/models/model_implementations/en_code_retriever.py +2 -1
- mteb/models/model_implementations/euler_models.py +32 -0
- mteb/models/model_implementations/evaclip_models.py +4 -0
- mteb/models/model_implementations/fa_models.py +67 -9
- mteb/models/model_implementations/facebookai.py +205 -0
- mteb/models/model_implementations/geogpt_models.py +2 -1
- mteb/models/model_implementations/gme_v_models.py +17 -10
- mteb/models/model_implementations/google_models.py +17 -6
- mteb/models/model_implementations/granite_vision_embedding_models.py +8 -3
- mteb/models/model_implementations/gritlm_models.py +4 -2
- mteb/models/model_implementations/gte_models.py +99 -9
- mteb/models/model_implementations/hinvec_models.py +2 -1
- mteb/models/model_implementations/human.py +1 -0
- mteb/models/model_implementations/ibm_granite_models.py +36 -6
- mteb/models/model_implementations/inf_models.py +4 -2
- mteb/models/model_implementations/jasper_models.py +256 -3
- mteb/models/model_implementations/jina_clip.py +49 -10
- mteb/models/model_implementations/jina_models.py +222 -11
- mteb/models/model_implementations/kalm_models.py +203 -25
- mteb/models/model_implementations/kblab.py +37 -0
- mteb/models/model_implementations/kennethenevoldsen_models.py +74 -0
- mteb/models/model_implementations/kfst.py +25 -0
- mteb/models/model_implementations/kowshik24_models.py +32 -0
- mteb/models/model_implementations/lens_models.py +2 -0
- mteb/models/model_implementations/lgai_embedding_models.py +2 -1
- mteb/models/model_implementations/linq_models.py +4 -3
- mteb/models/model_implementations/listconranker.py +2 -2
- mteb/models/model_implementations/llm2clip_models.py +9 -6
- mteb/models/model_implementations/llm2vec_models.py +16 -8
- mteb/models/model_implementations/mcinext_models.py +7 -1
- mteb/models/model_implementations/mdbr_models.py +19 -3
- mteb/models/model_implementations/misc_models.py +422 -60
- mteb/models/model_implementations/mixedbread_ai_models.py +332 -0
- mteb/models/model_implementations/mme5_models.py +2 -1
- mteb/models/model_implementations/moco_models.py +15 -4
- mteb/models/model_implementations/mod_models.py +191 -0
- mteb/models/model_implementations/model2vec_models.py +27 -14
- mteb/models/model_implementations/moka_models.py +4 -1
- mteb/models/model_implementations/nbailab.py +70 -0
- mteb/models/model_implementations/no_instruct_sentence_models.py +3 -2
- mteb/models/model_implementations/nomic_models.py +173 -6
- mteb/models/model_implementations/nomic_models_vision.py +8 -3
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +32 -19
- mteb/models/model_implementations/nvidia_models.py +155 -20
- mteb/models/model_implementations/octen_models.py +254 -0
- mteb/models/model_implementations/openai_models.py +20 -16
- mteb/models/model_implementations/openclip_models.py +37 -13
- mteb/models/model_implementations/opensearch_neural_sparse_models.py +10 -5
- mteb/models/model_implementations/ops_moa_models.py +5 -3
- mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -1
- mteb/models/model_implementations/pawan_models.py +39 -0
- mteb/models/model_implementations/piccolo_models.py +9 -1
- mteb/models/model_implementations/pixie_models.py +56 -0
- mteb/models/model_implementations/promptriever_models.py +12 -8
- mteb/models/model_implementations/pylate_models.py +46 -12
- mteb/models/model_implementations/qodo_models.py +4 -2
- mteb/models/model_implementations/qtack_models.py +2 -1
- mteb/models/model_implementations/qwen3_models.py +9 -6
- mteb/models/model_implementations/qzhou_models.py +5 -3
- mteb/models/model_implementations/random_baseline.py +19 -24
- mteb/models/model_implementations/rasgaard_models.py +34 -0
- mteb/models/model_implementations/reasonir_model.py +2 -1
- mteb/models/model_implementations/repllama_models.py +5 -3
- mteb/models/model_implementations/rerankers_custom.py +15 -9
- mteb/models/model_implementations/rerankers_monot5_based.py +31 -31
- mteb/models/model_implementations/richinfoai_models.py +2 -1
- mteb/models/model_implementations/ru_sentence_models.py +71 -20
- mteb/models/model_implementations/ruri_models.py +322 -0
- mteb/models/model_implementations/salesforce_models.py +6 -3
- mteb/models/model_implementations/samilpwc_models.py +2 -1
- mteb/models/model_implementations/sarashina_embedding_models.py +168 -0
- mteb/models/model_implementations/searchmap_models.py +2 -1
- mteb/models/model_implementations/seed_1_6_embedding_models.py +8 -2
- mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +625 -0
- mteb/models/model_implementations/seed_models.py +1 -0
- mteb/models/model_implementations/sentence_transformers_models.py +177 -18
- mteb/models/model_implementations/shuu_model.py +32 -31
- mteb/models/model_implementations/siglip_models.py +30 -20
- mteb/models/model_implementations/slm_models.py +416 -0
- mteb/models/model_implementations/sonar_models.py +1 -0
- mteb/models/model_implementations/spartan8806_atles_champion.py +34 -0
- mteb/models/model_implementations/stella_models.py +23 -4
- mteb/models/model_implementations/tarka_models.py +376 -0
- mteb/models/model_implementations/text2vec_models.py +9 -3
- mteb/models/model_implementations/ua_sentence_models.py +11 -1
- mteb/models/model_implementations/uae_models.py +8 -1
- mteb/models/model_implementations/vdr_models.py +3 -1
- mteb/models/model_implementations/vi_vn_models.py +45 -6
- mteb/models/model_implementations/vista_models.py +2 -0
- mteb/models/model_implementations/vlm2vec_models.py +5 -3
- mteb/models/model_implementations/voyage_models.py +99 -0
- mteb/models/model_implementations/voyage_v.py +17 -9
- mteb/models/model_implementations/xyz_models.py +1 -0
- mteb/models/model_implementations/youtu_models.py +2 -1
- mteb/models/model_implementations/yuan_models.py +34 -0
- mteb/models/model_implementations/yuan_models_en.py +58 -0
- mteb/models/model_meta.py +498 -29
- mteb/models/models_protocols.py +22 -6
- mteb/models/search_encoder_index/__init__.py +7 -0
- mteb/models/search_encoder_index/search_backend_protocol.py +50 -0
- mteb/models/search_encoder_index/search_indexes/__init__.py +5 -0
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +160 -0
- mteb/models/search_wrappers.py +197 -65
- mteb/models/sentence_transformer_wrapper.py +52 -32
- mteb/models/vllm_wrapper.py +327 -0
- mteb/py.typed +0 -0
- mteb/results/benchmark_results.py +114 -65
- mteb/results/model_result.py +63 -26
- mteb/results/task_result.py +117 -77
- mteb/similarity_functions.py +60 -7
- mteb/tasks/bitext_mining/multilingual/__init__.py +2 -1
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +4 -2
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining_fast.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ru_sci_bench_bitext_mining.py +47 -5
- mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -6
- mteb/tasks/classification/ara/ajgt.py +1 -2
- mteb/tasks/classification/ara/hotel_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/tweet_emotion_classification.py +1 -2
- mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_document_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -2
- mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py +1 -2
- mteb/tasks/classification/dan/angry_tweets_classification.py +1 -2
- mteb/tasks/classification/dan/danish_political_comments_classification.py +1 -2
- mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -2
- mteb/tasks/classification/dan/dk_hate_classification.py +2 -3
- mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -2
- mteb/tasks/classification/deu/ten_k_gnad_classification.py +1 -2
- mteb/tasks/classification/eng/amazon_polarity_classification.py +1 -2
- mteb/tasks/classification/eng/arxiv_classification.py +1 -2
- mteb/tasks/classification/eng/banking77_classification.py +1 -2
- mteb/tasks/classification/eng/dbpedia_classification.py +1 -2
- mteb/tasks/classification/eng/emotion_classification.py +1 -2
- mteb/tasks/classification/eng/financial_phrasebank_classification.py +1 -2
- mteb/tasks/classification/eng/frenk_en_classification.py +1 -2
- mteb/tasks/classification/eng/gtsrb_classification.py +1 -1
- mteb/tasks/classification/eng/imdb_classification.py +1 -2
- mteb/tasks/classification/eng/legal_bench_classification.py +14 -120
- mteb/tasks/classification/eng/news_classification.py +1 -2
- mteb/tasks/classification/eng/patch_camelyon_classification.py +1 -1
- mteb/tasks/classification/eng/patent_classification.py +1 -2
- mteb/tasks/classification/eng/poem_sentiment_classification.py +1 -2
- mteb/tasks/classification/eng/sds_eye_protection_classification.py +1 -2
- mteb/tasks/classification/eng/sds_gloves_classification.py +1 -2
- mteb/tasks/classification/eng/toxic_chat_classification.py +2 -19
- mteb/tasks/classification/eng/toxic_conversations_classification.py +1 -2
- mteb/tasks/classification/eng/tweet_sentiment_extraction_classification.py +1 -2
- mteb/tasks/classification/eng/tweet_topic_single_classification.py +2 -13
- mteb/tasks/classification/eng/ucf101_classification.py +1 -5
- mteb/tasks/classification/eng/wikipedia_bio_met_chem_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_chem_fields_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_comp_chem_spectroscopy_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_crystallography_analytical_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_theoretical_applied_classification.py +1 -2
- mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -2
- mteb/tasks/classification/eng/yelp_review_full_classification.py +1 -2
- mteb/tasks/classification/est/estonian_valence.py +2 -3
- mteb/tasks/classification/fas/fa_mteb_classification.py +7 -14
- mteb/tasks/classification/fil/filipino_hate_speech_classification.py +1 -2
- mteb/tasks/classification/fin/fin_toxicity_classification.py +2 -11
- mteb/tasks/classification/fra/french_book_reviews.py +1 -2
- mteb/tasks/classification/fra/movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/guj/gujarati_news_classification.py +1 -2
- mteb/tasks/classification/heb/__init__.py +6 -1
- mteb/tasks/classification/heb/hebrew_sentiment_analysis.py +62 -4
- mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -2
- mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -2
- mteb/tasks/classification/hrv/frenk_hr_classification.py +1 -2
- mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +1 -2
- mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -2
- mteb/tasks/classification/ita/italian_linguist_acceptability_classification.py +1 -2
- mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -2
- mteb/tasks/classification/jpn/wrime_classification.py +1 -2
- mteb/tasks/classification/kan/kannada_news_classification.py +1 -2
- mteb/tasks/classification/kor/klue_tc.py +1 -2
- mteb/tasks/classification/kor/kor_hate_classification.py +2 -17
- mteb/tasks/classification/kor/kor_sarcasm_classification.py +2 -19
- mteb/tasks/classification/kur/kurdish_sentiment_classification.py +3 -4
- mteb/tasks/classification/mal/malayalam_news_classification.py +1 -2
- mteb/tasks/classification/mar/marathi_news_classification.py +1 -2
- mteb/tasks/classification/mkd/macedonian_tweet_sentiment_classification.py +1 -2
- mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -6
- mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -4
- mteb/tasks/classification/multilingual/ru_sci_bench_classification.py +4 -23
- mteb/tasks/classification/multilingual/scala_classification.py +2 -3
- mteb/tasks/classification/multilingual/sib200_classification.py +1 -6
- mteb/tasks/classification/mya/myanmar_news.py +1 -2
- mteb/tasks/classification/nep/nepali_news_classification.py +1 -2
- mteb/tasks/classification/nld/dutch_book_review_sentiment_classification.py +4 -2
- mteb/tasks/classification/nld/dutch_cola_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_government_bias_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_news_articles_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +3 -0
- mteb/tasks/classification/nld/iconclass_classification.py +3 -0
- mteb/tasks/classification/nld/open_tender_classification.py +3 -0
- mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +3 -0
- mteb/tasks/classification/nob/no_rec_classification.py +1 -2
- mteb/tasks/classification/nob/norwegian_parliament_classification.py +1 -2
- mteb/tasks/classification/ory/odia_news_classification.py +1 -2
- mteb/tasks/classification/pol/polish_classification.py +3 -6
- mteb/tasks/classification/ron/moroco.py +1 -2
- mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -2
- mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -2
- mteb/tasks/classification/rus/georeview_classification.py +1 -2
- mteb/tasks/classification/rus/headline_classification.py +1 -2
- mteb/tasks/classification/rus/inappropriateness_classification.py +1 -2
- mteb/tasks/classification/rus/ru_reviews_classification.py +1 -2
- mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -2
- mteb/tasks/classification/rus/senti_ru_eval.py +1 -2
- mteb/tasks/classification/sin/sinhala_news_classification.py +1 -2
- mteb/tasks/classification/sin/sinhala_news_source_classification.py +1 -2
- mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/slk/slovak_hate_speech_classification.py +1 -2
- mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/slv/frenk_sl_classification.py +1 -2
- mteb/tasks/classification/spa/spanish_news_classification.py +1 -2
- mteb/tasks/classification/spa/spanish_sentiment_classification.py +1 -2
- mteb/tasks/classification/ssw/siswati_news_classification.py +1 -2
- mteb/tasks/classification/swa/swahili_news_classification.py +1 -2
- mteb/tasks/classification/swe/dalaj_classification.py +1 -2
- mteb/tasks/classification/swe/swe_rec_classification.py +1 -2
- mteb/tasks/classification/swe/swedish_sentiment_classification.py +1 -2
- mteb/tasks/classification/tam/tamil_news_classification.py +1 -2
- mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +1 -2
- mteb/tasks/classification/tha/wisesight_sentiment_classification.py +1 -2
- mteb/tasks/classification/tsn/tswana_news_classification.py +1 -2
- mteb/tasks/classification/tur/__init__.py +4 -0
- mteb/tasks/classification/tur/turkish_constitutional_court.py +41 -0
- mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +1 -2
- mteb/tasks/classification/tur/turkish_product_sentiment_classification.py +1 -2
- mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -15
- mteb/tasks/classification/urd/urdu_roman_sentiment_classification.py +1 -2
- mteb/tasks/classification/vie/amazon_counterfactual_vn_classification.py +1 -6
- mteb/tasks/classification/vie/amazon_polarity_vn_classification.py +1 -6
- mteb/tasks/classification/vie/amazon_reviews_vn_classification.py +1 -5
- mteb/tasks/classification/vie/banking77_vn_classification.py +1 -5
- mteb/tasks/classification/vie/emotion_vn_classification.py +1 -5
- mteb/tasks/classification/vie/imdb_vn_classification.py +1 -5
- mteb/tasks/classification/vie/massive_intent_vn_classification.py +1 -5
- mteb/tasks/classification/vie/massive_scenario_vn_classification.py +1 -5
- mteb/tasks/classification/vie/mtop_domain_vn_classification.py +1 -5
- mteb/tasks/classification/vie/mtop_intent_vn_classification.py +1 -5
- mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -5
- mteb/tasks/classification/vie/tweet_sentiment_extraction_vn_classification.py +1 -5
- mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -2
- mteb/tasks/classification/zho/cmteb_classification.py +5 -10
- mteb/tasks/classification/zho/yue_openrice_review_classification.py +1 -2
- mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -2
- mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/jpn/mews_c16_ja_clustering.py +1 -3
- mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -6
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/open_tender_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/vabb_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/vabb_clustering_s2s.py +3 -0
- mteb/tasks/clustering/vie/reddit_clustering_p2p_vn.py +1 -5
- mteb/tasks/clustering/vie/reddit_clustering_vn.py +1 -5
- mteb/tasks/clustering/vie/stack_exchange_clustering_p2p_vn.py +1 -5
- mteb/tasks/clustering/vie/stack_exchange_clustering_vn.py +1 -5
- mteb/tasks/clustering/vie/twenty_newsgroups_clustering_vn.py +1 -5
- mteb/tasks/clustering/zho/cmteb_clustering.py +2 -2
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -5
- mteb/tasks/multilabel_classification/kor/kor_hate_speech_ml_classification.py +1 -9
- mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -6
- mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py +3 -0
- mteb/tasks/multilabel_classification/nld/vabb_multi_label_classification.py +3 -0
- mteb/tasks/multilabel_classification/por/brazilian_toxic_tweets_classification.py +1 -6
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -2
- mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -6
- mteb/tasks/pair_classification/eng/legal_bench_pc.py +1 -9
- mteb/tasks/pair_classification/nld/sick_nl_pair_classification.py +3 -0
- mteb/tasks/pair_classification/nld/xlwic_nl_pair_classification.py +3 -0
- mteb/tasks/pair_classification/rus/__init__.py +2 -2
- mteb/tasks/pair_classification/rus/terra.py +51 -25
- mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -5
- mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -5
- mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -5
- mteb/tasks/regression/multilingual/ru_sci_bench_regression.py +2 -6
- mteb/tasks/reranking/jpn/__init__.py +9 -1
- mteb/tasks/reranking/jpn/j_qa_ra_reranking_lite.py +49 -0
- mteb/tasks/reranking/jpn/ja_cwir_reranking_lite.py +47 -0
- mteb/tasks/reranking/multilingual/__init__.py +2 -0
- mteb/tasks/reranking/multilingual/multi_long_doc_reranking.py +70 -0
- mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
- mteb/tasks/reranking/multilingual/x_glue_wpr_reranking.py +1 -2
- mteb/tasks/reranking/vie/ask_ubuntu_dup_questions_vn.py +1 -5
- mteb/tasks/reranking/vie/sci_docs_reranking_vn.py +1 -5
- mteb/tasks/reranking/vie/stack_overflow_dup_questions_vn.py +1 -5
- mteb/tasks/retrieval/code/code_rag.py +12 -12
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +8 -5
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
- mteb/tasks/retrieval/eng/__init__.py +2 -0
- mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
- mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -8
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +4 -0
- mteb/tasks/retrieval/jpn/__init__.py +8 -0
- mteb/tasks/retrieval/jpn/ja_cwir_retrieval.py +1 -4
- mteb/tasks/retrieval/jpn/ja_cwir_retrieval_lite.py +47 -0
- mteb/tasks/retrieval/jpn/jaqket_retrieval_lite.py +50 -0
- mteb/tasks/retrieval/jpn/miracl_ja_retrieval_lite.py +52 -0
- mteb/tasks/retrieval/jpn/mr_tydi_ja_retrieval_lite.py +48 -0
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +11 -4
- mteb/tasks/retrieval/kor/__init__.py +16 -1
- mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
- mteb/tasks/retrieval/kor/squad_kor_v1_retrieval.py +47 -0
- mteb/tasks/retrieval/multilingual/__init__.py +24 -0
- mteb/tasks/retrieval/multilingual/belebele_retrieval.py +5 -4
- mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +56 -42
- mteb/tasks/retrieval/multilingual/mkqa_retrieval.py +1 -2
- mteb/tasks/retrieval/multilingual/mlqa_retrieval.py +1 -4
- mteb/tasks/retrieval/multilingual/multi_long_doc_retrieval.py +1 -2
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +9 -4
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -12
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +4 -2
- mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +389 -0
- mteb/tasks/retrieval/nld/__init__.py +8 -4
- mteb/tasks/retrieval/nld/argu_ana_nl_retrieval.py +46 -27
- mteb/tasks/retrieval/nld/bbsard_nl_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/dutch_news_articles_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/legal_qa_nl_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/nf_corpus_nl_retrieval.py +42 -25
- mteb/tasks/retrieval/nld/open_tender_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/sci_fact_nl_retrieval.py +42 -24
- mteb/tasks/retrieval/nld/scidocsnl_retrieval.py +44 -27
- mteb/tasks/retrieval/nld/vabb_retrieval.py +3 -0
- mteb/tasks/retrieval/nob/norquad.py +2 -2
- mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -7
- mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
- mteb/tasks/retrieval/vie/__init__.py +14 -6
- mteb/tasks/retrieval/vie/argu_ana_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +40 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_android_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_gis_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_mathematica_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_physics_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_programmers_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_stats_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_tex_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_unix_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_webmasters_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_wordpress_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +40 -5
- mteb/tasks/retrieval/vie/fevervn_retrieval.py +40 -7
- mteb/tasks/retrieval/vie/fi_qa2018_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/green_node_table_markdown_retrieval.py +16 -1
- mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +40 -6
- mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +49 -5
- mteb/tasks/retrieval/vie/nf_corpus_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/nqvn_retrieval.py +40 -5
- mteb/tasks/retrieval/vie/quora_vn_retrieval.py +1 -6
- mteb/tasks/retrieval/vie/sci_fact_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/scidocsvn_retrieval.py +1 -6
- mteb/tasks/retrieval/vie/touche2020_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/treccovidvn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
- mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
- mteb/tasks/sts/nld/sick_nl_sts.py +1 -0
- mteb/tasks/sts/vie/biosses_stsvn.py +1 -5
- mteb/tasks/sts/vie/sickr_stsvn.py +1 -5
- mteb/tasks/sts/vie/sts_benchmark_stsvn.py +1 -5
- mteb/tasks/zeroshot_classification/eng/gtsrb.py +1 -1
- mteb/tasks/zeroshot_classification/eng/patch_camelyon.py +1 -1
- mteb/tasks/zeroshot_classification/eng/ucf101.py +1 -5
- mteb/types/__init__.py +2 -0
- mteb/types/_encoder_io.py +19 -2
- mteb/types/_result.py +2 -1
- mteb/types/statistics.py +9 -3
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/METADATA +25 -8
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/RECORD +525 -438
- mteb/models/model_implementations/mxbai_models.py +0 -102
- mteb/models/model_implementations/nb_sbert.py +0 -25
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/WHEEL +0 -0
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/entry_points.txt +0 -0
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/top_level.txt +0 -0
mteb/__init__.py
CHANGED
@@ -3,14 +3,17 @@ from importlib.metadata import version
 from mteb import types
 from mteb.abstasks import AbsTask
 from mteb.abstasks.task_metadata import TaskMetadata
+from mteb.cache import ResultCache
 from mteb.deprecated_evaluator import MTEB
 from mteb.evaluate import evaluate
 from mteb.filter_tasks import filter_tasks
 from mteb.get_tasks import get_task, get_tasks
 from mteb.load_results import load_results
 from mteb.models import (
+    CacheBackendProtocol,
     CrossEncoderProtocol,
     EncoderProtocol,
+    IndexEncoderSearchProtocol,
     SearchProtocol,
     SentenceTransformerEncoderWrapper,
 )
@@ -27,8 +30,11 @@ __all__ = [
     "AbsTask",
     "Benchmark",
     "BenchmarkResults",
+    "CacheBackendProtocol",
     "CrossEncoderProtocol",
     "EncoderProtocol",
+    "IndexEncoderSearchProtocol",
+    "ResultCache",
     "SearchProtocol",
     "SentenceTransformerEncoderWrapper",
     "TaskMetadata",
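The three newly exported names are importable straight from the package root once 2.7.x is installed. A minimal, illustrative sketch (the describe_backend function below is hypothetical, and ResultCache's constructor arguments are not shown in this diff):

from mteb import CacheBackendProtocol, IndexEncoderSearchProtocol, ResultCache


def describe_backend(backend: CacheBackendProtocol) -> str:
    # The exported protocols are typing constructs: any conforming cache backend
    # or search index can be accepted by code annotated with them.
    return type(backend).__name__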
mteb/_create_dataloaders.py
CHANGED
@@ -1,9 +1,10 @@
 import logging
+import warnings
 from collections.abc import Callable
 from typing import Any, cast

 import torch
-from datasets import Dataset
+from datasets import Dataset, Image
 from torch.utils.data import DataLoader, default_collate

 from mteb.abstasks.task_metadata import TaskMetadata
@@ -22,12 +23,14 @@ logger = logging.getLogger(__name__)
 def _create_dataloader_from_texts(
     text: list[str],
     batch_size: int = 32,
+    **kwargs: Any,
 ) -> DataLoader[TextInput]:
     """Create a dataloader from a list of text.

     Args:
         text: A list of text to create a dataloader from.
         batch_size: Batch size for the dataloader.
+        kwargs: Not used, present catching extra arguments.

     Returns:
         A dataloader with the text.
@@ -111,11 +114,8 @@ def _create_text_dataloader_for_queries(
     )


-_warned_about_user_role = False
-
-
 def _convert_conv_history_to_query(
-    row: dict[str, list[str] | Conversation],
+    row: dict[str, str | list[str] | Conversation],
 ) -> dict[str, str | Conversation]:
     """Convert a conversation history to a single query string.

@@ -125,21 +125,18 @@ def _convert_conv_history_to_query(
     Returns:
         The updated row with the "query" and "text" fields set to the conversation string, and the "conversation" field set to the list of ConversationTurn.
     """
-    global _warned_about_user_role
-
     conversation = row["text"]
     # if it's a list of strings, just join them
     if isinstance(conversation, list) and isinstance(conversation[0], str):
-
-        conv_str = "; ".join(
+        conversation_ = cast(list[str], conversation)
+        conv_str = "; ".join(conversation_)
         current_conversation = [
-            ConversationTurn(role="user", content=message) for message in
+            ConversationTurn(role="user", content=message) for message in conversation_
         ]
-
-
-
-
-        _warned_about_user_role = True
+        warnings.warn(
+            "Conversations are a list of strings. Used 'user' role for all turns.",
+            category=UserWarning,
+        )
     # otherwise, it's a list of dictionaries, which we need to convert to strings
     elif isinstance(conversation, list) and isinstance(conversation[0], dict):
         conv = []
@@ -176,7 +173,7 @@ def _convert_conv_history_to_query(

     row["text"] = conv_str
     row["conversation"] = current_conversation
-    return row
+    return cast(dict[str, str | list[ConversationTurn]], row)


 def _create_dataloader_for_queries_conversation(
@@ -194,7 +191,8 @@ def _create_dataloader_for_queries_conversation(
     """
     return DataLoader(
         queries.map(
-            _convert_conv_history_to_query,
+            _convert_conv_history_to_query,
+            desc="Converting conversations to queries",
         ),
         collate_fn=_custom_collate_fn,
         batch_size=batch_size,
@@ -244,14 +242,15 @@ def _prepare_image_dataset(
     transform: Callable[[Any], Any] | None = None,
 ) -> Dataset:
     """Prepare the image dataset by converting images to RGB and applying transformations."""
-    # If the dataset uses a different column name for images, rename it to "image".
     if (
         image_column_name
         and image_column_name in dataset.column_names
         and "image" not in dataset.column_names
     ):
         dataset = dataset.rename_column(image_column_name, "image")
-    #
+    # don't process image if it's already in the correct format
+    if isinstance(dataset.features["image"], Image):
+        return dataset
     return dataset.map(
         _convert_images_to_rgb,
         fn_kwargs={"image_col_name": "image", "transform": transform},
@@ -363,6 +362,9 @@ def _create_document_dataloader(
         task_metadata: Metadata of the task to determine the document type.
         input_column: The column to use as input. If None, it will use the first column that matches the modality.
         batch_size: Batch size for the dataloader.
+
+    Returns:
+        A dataloader for the documents.
     """
     document_type = task_metadata.get_modalities(PromptType.document)
     if document_type == ["text"]:  # text only
@@ -385,7 +387,7 @@ def create_dataloader(
     prompt_type: PromptType | None = None,
     input_column: str | None = None,
     batch_size: int = 32,
-    **kwargs:
+    **kwargs: Any,
 ) -> DataLoader[BatchedInput]:
     """Create a dataloader from a dataset.


mteb/_evaluators/any_sts_evaluator.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import
+from typing import TypedDict

 from datasets import Dataset
 from sklearn.metrics.pairwise import (
@@ -12,6 +12,7 @@ from mteb._create_dataloaders import create_dataloader
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
 from mteb.similarity_functions import compute_pairwise_similarity
+from mteb.types import EncodeKwargs, PromptType

 from .evaluator import Evaluator

@@ -42,44 +43,52 @@ class AnySTSEvaluator(Evaluator):
         task_metadata: TaskMetadata,
         hf_split: str,
         hf_subset: str,
+        input1_prompt_type: PromptType | None,
+        input2_prompt_type: PromptType | None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        self.
-
-            task_metadata,
-            input_column=sentences_column_names[0],
-        )
-        self.second_column = create_dataloader(
-            dataset,
-            task_metadata,
-            input_column=sentences_column_names[1],
-        )
+        self.dataset = dataset
+        self.input_columns = sentences_column_names
         self.task_metadata = task_metadata
         self.hf_split = hf_split
         self.hf_subset = hf_subset
+        self.input1_prompt_type = input1_prompt_type
+        self.input2_prompt_type = input2_prompt_type

     def __call__(
         self,
         model: EncoderProtocol,
         *,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> STSEvaluatorScores:
         logger.info("Running semantic similarity - Encoding samples (1/2)")
         embeddings1 = model.encode(
-
+            create_dataloader(
+                self.dataset,
+                self.task_metadata,
+                input_column=self.input_columns[0],
+                **encode_kwargs,
+            ),
             task_metadata=self.task_metadata,
             hf_split=self.hf_split,
             hf_subset=self.hf_subset,
+            prompt_type=self.input1_prompt_type,
             **encode_kwargs,
         )

         logger.info("Running semantic similarity - Encoding samples (2/2)...")
         embeddings2 = model.encode(
-
+            create_dataloader(
+                self.dataset,
+                self.task_metadata,
+                input_column=self.input_columns[1],
+                **encode_kwargs,
+            ),
             task_metadata=self.task_metadata,
             hf_split=self.hf_split,
             hf_subset=self.hf_subset,
+            prompt_type=self.input2_prompt_type,
             **encode_kwargs,
         )


mteb/_evaluators/classification_metrics.py
ADDED
@@ -0,0 +1,54 @@
+import numpy as np
+
+
+def hamming_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+    """Compute the Hamming score (a.k.a. label-based accuracy) for multilabel classification.
+
+    The Hamming score is the fraction of labels that are correctly predicted for each sample,
+    averaged over all samples. For samples where both y_true and y_pred have no labels,
+    the score is 1.0 (perfect agreement).
+
+    Args:
+        y_true: Binary matrix of true labels with shape (n_samples, n_labels)
+        y_pred: Binary matrix of predicted labels with shape (n_samples, n_labels)
+
+    Returns:
+        float: Hamming score between 0.0 and 1.0
+
+    Raises:
+        ValueError: If inputs are invalid or have incompatible shapes
+        TypeError: If inputs cannot be converted to numpy arrays
+    """
+    y_true = np.asarray(y_true)
+    y_pred = np.asarray(y_pred)
+
+    # Check shapes
+    if y_true.shape != y_pred.shape:
+        raise ValueError(
+            f"Shape mismatch: y_true {y_true.shape} != y_pred {y_pred.shape}"
+        )
+
+    # Check if arrays are empty
+    if y_true.size == 0:
+        raise ValueError("Input arrays cannot be empty")
+
+    # Ensure 2D arrays
+    if y_true.ndim != 2:
+        raise ValueError(f"Arrays must be 2D, got {y_true.ndim}D")
+
+    # Check for binary values
+    if not (np.all(np.isin(y_true, [0, 1])) and np.all(np.isin(y_pred, [0, 1]))):
+        raise ValueError("Arrays must contain only binary values (0 and 1)")
+
+    # Convert to boolean for bitwise operations
+    y_true_bool = y_true.astype(bool)
+    y_pred_bool = y_pred.astype(bool)
+
+    # Calculate intersection and union for each sample
+    intersection = (y_true_bool & y_pred_bool).sum(axis=1)
+    union = (y_true_bool | y_pred_bool).sum(axis=1)
+
+    # Handle division by zero: when union is 0, both are all zeros, so score is 1.0
+    scores = np.where(union == 0, 1.0, intersection / union)
+
+    return float(scores.mean())

mteb/_evaluators/clustering_evaluator.py
CHANGED
@@ -1,5 +1,4 @@
 import logging
-from typing import Any

 from datasets import Dataset
 from sklearn import cluster
@@ -7,6 +6,7 @@ from sklearn import cluster
 from mteb._create_dataloaders import create_dataloader
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
+from mteb.types import EncodeKwargs

 from .evaluator import Evaluator

@@ -38,13 +38,13 @@ class ClusteringEvaluator(Evaluator):
         self,
         model: EncoderProtocol,
         *,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> list[int]:
         data_loader = create_dataloader(
             self.dataset,
             self.task_metadata,
             input_column=self.input_column_name,
-
+            **encode_kwargs,
         )

         logger.info("Running clustering - Encoding samples...")
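The hamming_score helper introduced above (new file mteb/_evaluators/classification_metrics.py) is self-contained, so it can be exercised directly. A minimal usage sketch, assuming the import path matches the file location shown in this diff:

import numpy as np

from mteb._evaluators.classification_metrics import hamming_score

y_true = np.array([[1, 0, 1], [0, 0, 0]])
y_pred = np.array([[1, 0, 0], [0, 0, 0]])

# Row 1: |intersection| / |union| = 1 / 2 = 0.5; row 2: both all-zero -> 1.0 by convention.
print(hamming_score(y_true, y_pred))  # 0.75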
mteb/_evaluators/evaluator.py
CHANGED
@@ -1,8 +1,10 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping
 from typing import Any

 from mteb.abstasks.abstask import _set_seed
 from mteb.models import EncoderProtocol
+from mteb.types import EncodeKwargs


 class Evaluator(ABC):
@@ -17,8 +19,8 @@ class Evaluator(ABC):

     @abstractmethod
     def __call__(
-        self, model: EncoderProtocol, *, encode_kwargs:
-    ) ->
+        self, model: EncoderProtocol, *, encode_kwargs: EncodeKwargs
+    ) -> Mapping[str, float] | Iterable[Any]:
         """This is called during training to evaluate the model.

         It returns scores.

mteb/_evaluators/image/imagetext_pairclassification_evaluator.py
CHANGED
@@ -1,19 +1,26 @@
+from __future__ import annotations
+
 import logging
-from
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any

 import torch
 import torch.nn.functional as F
-from datasets import Dataset
-from PIL.Image import Image
 from torch.utils.data import DataLoader

 from mteb._create_dataloaders import (
+    _create_dataloader_from_texts,
     _transform_image_to_rgb,
 )
 from mteb._evaluators.evaluator import Evaluator
 from mteb._requires_package import requires_image_dependencies
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models.models_protocols import EncoderProtocol
+from mteb.types import EncodeKwargs
+
+if TYPE_CHECKING:
+    from PIL.Image import Image
+

 logger = logging.getLogger(__name__)

@@ -56,8 +63,8 @@ class ImageTextPairClassificationEvaluator(Evaluator):
     def __init__(
         self,
         dataset,
-        images_column_names: str |
-        texts_column_names: str |
+        images_column_names: str | Sequence[str],
+        texts_column_names: str | Sequence[str],
         num_images_per_sample: int,
         num_texts_per_sample: int,
         task_metadata: TaskMetadata,
@@ -77,10 +84,11 @@ class ImageTextPairClassificationEvaluator(Evaluator):
         self.hf_split = hf_split
         self.hf_subset = hf_subset

-    def __call__(
+    def __call__(  # type: ignore[override]
         self,
         model: EncoderProtocol,
-
+        *,
+        encode_kwargs: EncodeKwargs,
     ) -> list[torch.Tensor]:
         images = []
         if isinstance(self.images_column_names, str):
@@ -101,9 +109,9 @@ class ImageTextPairClassificationEvaluator(Evaluator):
                 texts.append(row[col])

         text_embeddings = model.encode(
-
-
-
+            _create_dataloader_from_texts(
+                texts,
+                **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
             hf_subset=self.hf_subset,
@@ -122,7 +130,6 @@ class ImageTextPairClassificationEvaluator(Evaluator):
         image_embeddings = model.encode(
             DataLoader(
                 CustomImageDataset(images),
-                batch_size=encode_kwargs["batch_size"],
                 collate_fn=lambda x: {"image": [item["image"] for item in x]},
             ),
             task_metadata=self.task_metadata,

mteb/_evaluators/pair_classification_evaluator.py
CHANGED
@@ -14,6 +14,7 @@ from mteb._evaluators.evaluator import Evaluator
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
 from mteb.similarity_functions import compute_pairwise_similarity
+from mteb.types import EncodeKwargs, PromptType

 logger = logging.getLogger(__name__)

@@ -60,6 +61,8 @@ class PairClassificationEvaluator(Evaluator):
         task_metadata: TaskMetadata,
         hf_split: str,
         hf_subset: str,
+        input1_prompt_type: PromptType | None,
+        input2_prompt_type: PromptType | None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -69,6 +72,8 @@ class PairClassificationEvaluator(Evaluator):
         self.task_metadata = task_metadata
         self.hf_split = hf_split
         self.hf_subset = hf_subset
+        self.input1_prompt_type = input1_prompt_type
+        self.input2_prompt_type = input2_prompt_type

         if len(self.dataset[self.input1_column_name]) != len(
             self.dataset[self.input2_column_name]
@@ -80,49 +85,36 @@ class PairClassificationEvaluator(Evaluator):
     def __call__(
         self,
         model: EncoderProtocol,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> PairClassificationDistances:
-        logger.info("Running pair classification - Encoding
-
-
-
-                self.dataset[self.input1_column_name][:]
-                + self.dataset[self.input2_column_name][:]
-            )
-            len_sentences1 = len(self.dataset[self.input1_column_name])
-            embeddings = self._encode_unique_texts(
-                all_sentences,
-                model,
-                task_metadata=self.task_metadata,
-                hf_split=self.hf_split,
-                hf_subset=self.hf_subset,
-                **encode_kwargs,
-            )
-            embeddings1 = embeddings[:len_sentences1]
-            embeddings2 = embeddings[len_sentences1:]
-        else:
-            embeddings1 = model.encode(
-                create_dataloader(
-                    self.dataset,
-                    task_metadata=self.task_metadata,
-                    input_column=self.input1_column_name,
-                ),
+        logger.info("Running pair classification - Encoding samples (1/2)")
+        embeddings1 = model.encode(
+            create_dataloader(
+                self.dataset,
                 task_metadata=self.task_metadata,
-
-                hf_subset=self.hf_subset,
+                input_column=self.input1_column_name,
                 **encode_kwargs,
-            )
-
-
-
-
-
-
+            ),
+            task_metadata=self.task_metadata,
+            hf_split=self.hf_split,
+            hf_subset=self.hf_subset,
+            prompt_type=self.input1_prompt_type,
+            **encode_kwargs,
+        )
+        logger.info("Running pair classification - Encoding samples (2/2)")
+        embeddings2 = model.encode(
+            create_dataloader(
+                self.dataset,
                 task_metadata=self.task_metadata,
-
-                hf_subset=self.hf_subset,
+                input_column=self.input2_column_name,
                 **encode_kwargs,
-            )
+            ),
+            task_metadata=self.task_metadata,
+            hf_split=self.hf_split,
+            hf_subset=self.hf_subset,
+            prompt_type=self.input2_prompt_type,
+            **encode_kwargs,
+        )

         logger.info("Running pair classification - Evaluating pair similarity...")
         cosine_scores = 1 - paired_cosine_distances(embeddings1, embeddings2)
@@ -156,7 +148,9 @@ class PairClassificationEvaluator(Evaluator):
         hf_subset: str,
         **encode_kwargs: Any,
     ) -> np.ndarray:
-        index_map
+        index_map = {}
+        all_unique_texts: list[str] = []
+        all_texts_indexes = []
         for text in all_texts:
             text_hash = hash(text)
             if text_hash not in index_map:
@@ -168,7 +162,7 @@ class PairClassificationEvaluator(Evaluator):
             )
         all_unique_texts_embs = np.asarray(
             model.encode(
-                _create_dataloader_from_texts(all_unique_texts),
+                _create_dataloader_from_texts(all_unique_texts, **encode_kwargs),
                 task_metadata=task_metadata,
                 hf_split=hf_split,
                 hf_subset=hf_subset,

mteb/_evaluators/retrieval_evaluator.py
CHANGED
@@ -1,11 +1,11 @@
 import logging
 from collections.abc import Sequence
-from typing import Any

 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import SearchProtocol
 from mteb.types import (
     CorpusDatasetType,
+    EncodeKwargs,
     QueryDatasetType,
     RelevantDocumentsType,
     RetrievalEvaluationResult,
@@ -48,7 +48,7 @@ class RetrievalEvaluator(Evaluator):
     def __call__(  # type: ignore[override]
         self,
         search_model: SearchProtocol,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> RetrievalOutputType:
         logger.info("Running retrieval task - Indexing corpus...")
         search_model.index(

mteb/_evaluators/retrieval_metrics.py
CHANGED
@@ -1,5 +1,6 @@
 import logging
 from collections import defaultdict
+from collections.abc import Mapping
 from typing import Any

 import numpy as np
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)

 def mrr(
     qrels: RelevantDocumentsType,
-    results:
+    results: Mapping[str, Mapping[str, float]],
     k_values: list[int],
 ) -> dict[str, list[float]]:
     mrr_metrics = defaultdict(list)
@@ -32,7 +33,7 @@
             doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0
         }
         for k in k_values:
-            rr = 0
+            rr = 0.0
             for rank, hit in enumerate(top_hits[query_id][0:k]):
                 if hit[0] in query_relevant_docs:
                     rr = 1.0 / (rank + 1)
@@ -45,8 +46,8 @@ def recall_cap(
     qrels: RelevantDocumentsType,
     results: dict[str, dict[str, float]],
     k_values: list[int],
-) -> dict[str, list[float]]:
-    capped_recall = defaultdict(list)
+) -> dict[str, list[float | None]]:
+    capped_recall: dict[str, list[float | None]] = defaultdict(list)

     k_max = max(k_values)

@@ -139,7 +140,7 @@ def calculate_pmrr(original_run, new_run, changed_qrels):
     changes = []
     for qid in changed_qrels.keys():
         if qid + "-og" not in original_run or qid + "-changed" not in new_run:
-
+            logger.warning(f"Query {qid} not found in the runs for calculating p-MRR")
             continue
         original_qid_run = original_run[qid + "-og"]
         new_qid_run = new_run[qid + "-changed"]
@@ -188,7 +189,7 @@ def evaluate_p_mrr_change(
     Returns:
         A dictionary with the scores, including "p-MRR", "og" and "changed" keys.
     """
-    followir_scores = defaultdict(dict)
+    followir_scores: dict[str, float | dict[str, float]] = defaultdict(dict)

     qrels_sep = {
         "og": {k: v for k, v in qrels.items() if k.endswith("-og")},
@@ -227,7 +228,7 @@
             ndcg, _map, recall, precision, naucs, avg_mrr, naucs_mrr, cv_recall, {}
         )
         for key, value in scores_dict.items():
-            followir_scores[name][key] = value
+            followir_scores[name][key] = value  # type: ignore[index]

     return followir_scores

@@ -254,8 +255,8 @@ def confidence_scores(sim_scores: list[float]) -> dict[str, float]:
     sim_scores_sorted = sorted(sim_scores)[::-1]

     cs_max = sim_scores_sorted[0]
-    cs_std = np.std(sim_scores)
-    cs_diff1 =
+    cs_std = float(np.std(sim_scores))
+    cs_diff1 = 0.0
     if len(sim_scores) > 1:
         cs_diff1 = sim_scores_sorted[0] - sim_scores_sorted[1]
     elif len(sim_scores) == 1:
@@ -410,7 +411,7 @@ def make_score_dict(
     cv_recall: dict[str, float],
     task_scores: dict[str, float],
     previous_results_model_meta: dict[str, Any] | None = None,
-) -> dict[str,
+) -> dict[str, Any]:
     return {
         **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
         **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
@@ -528,7 +529,7 @@


 def calculate_retrieval_scores(
-    results:
+    results: Mapping[str, Mapping[str, float]],
     qrels: RelevantDocumentsType,
     k_values: list[int],
     skip_first_result: bool = False,
@@ -576,7 +577,7 @@


 def evaluate_abstention(
-    results:
+    results: Mapping[str, Mapping[str, float]],
     metric_scores: dict[str, list[float]],
 ) -> dict[str, float]:
     """Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997
@@ -591,21 +592,21 @@
     all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())]
     all_conf_scores = [confidence_scores(sim_scores) for sim_scores in all_sim_scores]
     conf_fcts = list(all_conf_scores[0].keys())
-
+    all_conf_scores_ = {
         fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts
     }
-
+    metric_scores_ = {k: np.array(v) for k, v in metric_scores.items()}
     naucs = {}

-    for metric_name, scores in
-        for fct, conf_scores in
+    for metric_name, scores in metric_scores_.items():
+        for fct, conf_scores in all_conf_scores_.items():
             naucs[f"nAUC_{metric_name}_{fct}"] = nauc(conf_scores, scores)

     return naucs


 def calculate_cv_recall(
-    results:
+    results: Mapping[str, Mapping[str, float]],
     qrels: RelevantDocumentsType,
     k_values: list[int],
     skip_first_result: bool = False,
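The mrr change above only adjusts typing (rr = 0 becomes rr = 0.0); the reciprocal-rank logic itself is unchanged: 1.0 / (rank + 1) for the first relevant hit within the top k, otherwise 0. A standalone sketch of that computation, independent of mteb's own helpers:

def reciprocal_rank(ranked_doc_ids: list[str], relevant_doc_ids: set[str], k: int) -> float:
    # The first relevant document within the top-k ranking determines the score.
    for rank, doc_id in enumerate(ranked_doc_ids[:k]):
        if doc_id in relevant_doc_ids:
            return 1.0 / (rank + 1)
    return 0.0


# The single relevant document sits at 0-indexed rank 1, so the score is 1 / 2.
print(reciprocal_rank(["d3", "d7", "d1"], {"d7"}, k=10))  # 0.5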
|