mteb 2.5.2__py3-none-any.whl → 2.7.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/__init__.py +2 -0
- mteb/_create_dataloaders.py +78 -30
- mteb/_evaluators/any_sts_evaluator.py +13 -6
- mteb/_evaluators/clustering_evaluator.py +13 -5
- mteb/_evaluators/evaluator.py +12 -4
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +22 -11
- mteb/_evaluators/pair_classification_evaluator.py +17 -7
- mteb/_evaluators/retrieval_evaluator.py +23 -14
- mteb/_evaluators/retrieval_metrics.py +26 -19
- mteb/_evaluators/sklearn_evaluator.py +27 -17
- mteb/_evaluators/text/bitext_mining_evaluator.py +36 -20
- mteb/_evaluators/text/summarization_evaluator.py +31 -20
- mteb/_evaluators/zeroshot_classification_evaluator.py +16 -5
- mteb/_helpful_enum.py +5 -1
- mteb/abstasks/_data_filter/filters.py +9 -3
- mteb/abstasks/_data_filter/task_pipelines.py +10 -2
- mteb/abstasks/_statistics_calculation.py +21 -11
- mteb/abstasks/_stratification.py +18 -18
- mteb/abstasks/abstask.py +78 -44
- mteb/abstasks/aggregate_task_metadata.py +21 -18
- mteb/abstasks/aggregated_task.py +23 -35
- mteb/abstasks/classification.py +39 -18
- mteb/abstasks/clustering.py +37 -20
- mteb/abstasks/clustering_legacy.py +30 -16
- mteb/abstasks/image/image_text_pair_classification.py +26 -9
- mteb/abstasks/multilabel_classification.py +33 -21
- mteb/abstasks/pair_classification.py +44 -19
- mteb/abstasks/regression.py +18 -10
- mteb/abstasks/retrieval.py +82 -52
- mteb/abstasks/retrieval_dataset_loaders.py +50 -39
- mteb/abstasks/sts.py +34 -15
- mteb/abstasks/task_metadata.py +44 -37
- mteb/abstasks/text/bitext_mining.py +57 -35
- mteb/abstasks/text/reranking.py +10 -8
- mteb/abstasks/text/summarization.py +26 -10
- mteb/abstasks/zeroshot_classification.py +27 -9
- mteb/benchmarks/_create_table.py +13 -7
- mteb/benchmarks/benchmark.py +15 -3
- mteb/benchmarks/benchmarks/__init__.py +6 -0
- mteb/benchmarks/benchmarks/benchmarks.py +153 -13
- mteb/benchmarks/benchmarks/rteb_benchmarks.py +20 -9
- mteb/benchmarks/get_benchmark.py +14 -55
- mteb/cache.py +189 -31
- mteb/cli/_display_tasks.py +10 -4
- mteb/cli/build_cli.py +112 -13
- mteb/cli/generate_model_card.py +50 -23
- mteb/deprecated_evaluator.py +72 -54
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
- mteb/descriptive_stats/Retrieval/BrightAopsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightBiologyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightBiologyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEarthScienceLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEarthScienceRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEconomicsLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEconomicsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightLeetcodeRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPonyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPonyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPsychologyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPsychologyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightRoboticsLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightRoboticsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightStackoverflowLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightStackoverflowRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightSustainableLivingLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightSustainableLivingRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightTheoremQAQuestionsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightTheoremQATheoremsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
- mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
- mteb/evaluate.py +71 -47
- mteb/filter_tasks.py +36 -32
- mteb/get_tasks.py +37 -33
- mteb/languages/language_scripts.py +11 -4
- mteb/leaderboard/app.py +172 -37
- mteb/leaderboard/table.py +7 -2
- mteb/load_results.py +20 -14
- mteb/models/abs_encoder.py +30 -16
- mteb/models/cache_wrappers/cache_backend_protocol.py +7 -7
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +10 -5
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +13 -4
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
- mteb/models/cache_wrappers/cache_wrapper.py +16 -11
- mteb/models/get_model_meta.py +53 -9
- mteb/models/instruct_wrapper.py +41 -13
- mteb/models/model_implementations/align_models.py +11 -5
- mteb/models/model_implementations/amazon_models.py +1 -0
- mteb/models/model_implementations/andersborges.py +6 -4
- mteb/models/model_implementations/ara_models.py +2 -1
- mteb/models/model_implementations/arctic_models.py +16 -8
- mteb/models/model_implementations/b1ade_models.py +2 -1
- mteb/models/model_implementations/bedrock_models.py +20 -6
- mteb/models/model_implementations/bge_models.py +85 -22
- mteb/models/model_implementations/bica_model.py +4 -3
- mteb/models/model_implementations/blip2_models.py +13 -6
- mteb/models/model_implementations/blip_models.py +33 -20
- mteb/models/model_implementations/bm25.py +27 -17
- mteb/models/model_implementations/bmretriever_models.py +16 -6
- mteb/models/model_implementations/cadet_models.py +2 -1
- mteb/models/model_implementations/cde_models.py +22 -9
- mteb/models/model_implementations/clip_models.py +18 -10
- mteb/models/model_implementations/clips_models.py +6 -3
- mteb/models/model_implementations/codefuse_models.py +10 -5
- mteb/models/model_implementations/codesage_models.py +6 -3
- mteb/models/model_implementations/cohere_models.py +19 -9
- mteb/models/model_implementations/cohere_v.py +16 -6
- mteb/models/model_implementations/colpali_models.py +10 -6
- mteb/models/model_implementations/colqwen_models.py +24 -38
- mteb/models/model_implementations/colsmol_models.py +5 -3
- mteb/models/model_implementations/conan_models.py +12 -5
- mteb/models/model_implementations/dino_models.py +70 -46
- mteb/models/model_implementations/e5_instruct.py +27 -4
- mteb/models/model_implementations/e5_models.py +18 -9
- mteb/models/model_implementations/e5_v.py +16 -10
- mteb/models/model_implementations/eagerworks_models.py +12 -5
- mteb/models/model_implementations/emillykkejensen_models.py +9 -6
- mteb/models/model_implementations/en_code_retriever.py +2 -1
- mteb/models/model_implementations/euler_models.py +3 -2
- mteb/models/model_implementations/evaclip_models.py +13 -4
- mteb/models/model_implementations/fa_models.py +18 -9
- mteb/models/model_implementations/facebookai.py +16 -2
- mteb/models/model_implementations/geogpt_models.py +2 -1
- mteb/models/model_implementations/gme_v_models.py +13 -8
- mteb/models/model_implementations/google_models.py +16 -5
- mteb/models/model_implementations/granite_vision_embedding_models.py +8 -6
- mteb/models/model_implementations/gritlm_models.py +5 -2
- mteb/models/model_implementations/gte_models.py +34 -13
- mteb/models/model_implementations/hinvec_models.py +7 -2
- mteb/models/model_implementations/human.py +1 -0
- mteb/models/model_implementations/ibm_granite_models.py +36 -6
- mteb/models/model_implementations/inf_models.py +4 -2
- mteb/models/model_implementations/jasper_models.py +16 -7
- mteb/models/model_implementations/jina_clip.py +58 -14
- mteb/models/model_implementations/jina_models.py +35 -16
- mteb/models/model_implementations/kalm_models.py +24 -12
- mteb/models/model_implementations/kblab.py +13 -6
- mteb/models/model_implementations/kennethenevoldsen_models.py +6 -4
- mteb/models/model_implementations/kfst.py +2 -1
- mteb/models/model_implementations/kowshik24_models.py +2 -1
- mteb/models/model_implementations/lens_models.py +2 -0
- mteb/models/model_implementations/lgai_embedding_models.py +2 -1
- mteb/models/model_implementations/linq_models.py +8 -2
- mteb/models/model_implementations/listconranker.py +11 -5
- mteb/models/model_implementations/llm2clip_models.py +18 -10
- mteb/models/model_implementations/llm2vec_models.py +28 -14
- mteb/models/model_implementations/mcinext_models.py +12 -3
- mteb/models/model_implementations/mdbr_models.py +19 -3
- mteb/models/model_implementations/misc_models.py +131 -68
- mteb/models/model_implementations/mixedbread_ai_models.py +335 -0
- mteb/models/model_implementations/mme5_models.py +3 -2
- mteb/models/model_implementations/moco_models.py +15 -8
- mteb/models/model_implementations/mod_models.py +3 -2
- mteb/models/model_implementations/model2vec_models.py +37 -18
- mteb/models/model_implementations/moka_models.py +4 -1
- mteb/models/model_implementations/nbailab.py +6 -3
- mteb/models/model_implementations/no_instruct_sentence_models.py +15 -7
- mteb/models/model_implementations/nomic_models.py +47 -19
- mteb/models/model_implementations/nomic_models_vision.py +6 -4
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +20 -8
- mteb/models/model_implementations/nvidia_models.py +165 -22
- mteb/models/model_implementations/octen_models.py +64 -3
- mteb/models/model_implementations/openai_models.py +14 -4
- mteb/models/model_implementations/openclip_models.py +30 -17
- mteb/models/model_implementations/opensearch_neural_sparse_models.py +20 -9
- mteb/models/model_implementations/ops_moa_models.py +10 -3
- mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +2 -1
- mteb/models/model_implementations/pawan_models.py +2 -1
- mteb/models/model_implementations/piccolo_models.py +3 -1
- mteb/models/model_implementations/pixie_models.py +56 -0
- mteb/models/model_implementations/promptriever_models.py +20 -10
- mteb/models/model_implementations/pylate_models.py +41 -21
- mteb/models/model_implementations/qodo_models.py +4 -2
- mteb/models/model_implementations/qtack_models.py +2 -1
- mteb/models/model_implementations/qwen3_models.py +14 -4
- mteb/models/model_implementations/qzhou_models.py +4 -2
- mteb/models/model_implementations/random_baseline.py +7 -6
- mteb/models/model_implementations/rasgaard_models.py +3 -2
- mteb/models/model_implementations/reasonir_model.py +66 -1
- mteb/models/model_implementations/repllama_models.py +18 -9
- mteb/models/model_implementations/rerankers_custom.py +25 -10
- mteb/models/model_implementations/rerankers_monot5_based.py +41 -21
- mteb/models/model_implementations/richinfoai_models.py +2 -1
- mteb/models/model_implementations/ru_sentence_models.py +40 -20
- mteb/models/model_implementations/ruri_models.py +20 -10
- mteb/models/model_implementations/salesforce_models.py +13 -4
- mteb/models/model_implementations/samilpwc_models.py +2 -1
- mteb/models/model_implementations/sarashina_embedding_models.py +4 -2
- mteb/models/model_implementations/searchmap_models.py +2 -1
- mteb/models/model_implementations/seed_1_6_embedding_models.py +5 -2
- mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +119 -148
- mteb/models/model_implementations/seed_models.py +2 -1
- mteb/models/model_implementations/sentence_transformers_models.py +142 -22
- mteb/models/model_implementations/shuu_model.py +2 -1
- mteb/models/model_implementations/siglip_models.py +39 -24
- mteb/models/model_implementations/slm_models.py +419 -0
- mteb/models/model_implementations/sonar_models.py +2 -1
- mteb/models/model_implementations/spartan8806_atles_champion.py +2 -1
- mteb/models/model_implementations/stella_models.py +23 -4
- mteb/models/model_implementations/tarka_models.py +4 -2
- mteb/models/model_implementations/text2vec_models.py +12 -3
- mteb/models/model_implementations/ua_sentence_models.py +2 -1
- mteb/models/model_implementations/uae_models.py +17 -5
- mteb/models/model_implementations/vdr_models.py +9 -2
- mteb/models/model_implementations/vi_vn_models.py +12 -6
- mteb/models/model_implementations/vista_models.py +11 -4
- mteb/models/model_implementations/vlm2vec_models.py +14 -7
- mteb/models/model_implementations/voyage_models.py +136 -4
- mteb/models/model_implementations/voyage_v.py +17 -10
- mteb/models/model_implementations/xyz_models.py +1 -0
- mteb/models/model_implementations/youtu_models.py +2 -1
- mteb/models/model_implementations/yuan_models.py +2 -1
- mteb/models/model_implementations/yuan_models_en.py +3 -2
- mteb/models/model_meta.py +127 -40
- mteb/models/models_protocols.py +43 -22
- mteb/models/search_encoder_index/search_backend_protocol.py +7 -3
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +21 -10
- mteb/models/search_wrappers.py +63 -29
- mteb/models/sentence_transformer_wrapper.py +52 -26
- mteb/models/vllm_wrapper.py +329 -0
- mteb/py.typed +0 -0
- mteb/results/benchmark_results.py +48 -35
- mteb/results/model_result.py +68 -32
- mteb/results/task_result.py +110 -72
- mteb/similarity_functions.py +19 -9
- mteb/tasks/aggregated_tasks/eng/cqadupstack_retrieval.py +3 -3
- mteb/tasks/aggregated_tasks/eng/sts17_multilingual_visual_sts_eng.py +3 -3
- mteb/tasks/aggregated_tasks/eng/sts_benchmark_multilingual_visual_sts_eng.py +3 -3
- mteb/tasks/aggregated_tasks/fas/cqadupstack_retrieval_fa.py +3 -3
- mteb/tasks/aggregated_tasks/fas/syn_per_chatbot_conv_sa_classification.py +3 -3
- mteb/tasks/aggregated_tasks/multilingual/sts17_multilingual_vision_sts.py +3 -3
- mteb/tasks/aggregated_tasks/multilingual/sts_benchmark_multilingual_visual_sts.py +3 -3
- mteb/tasks/aggregated_tasks/nld/cqadupstack_nl_retrieval.py +3 -3
- mteb/tasks/aggregated_tasks/pol/cqadupstack_retrieval_pl.py +3 -3
- mteb/tasks/bitext_mining/eng/pub_chem_smiles_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/fas/fa_mteb_summary_retrieval.py +3 -3
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/flores_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_conv_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_gen_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/norwegian_courts_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ntrex_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/roma_tales_bitext_mining.py +2 -2
- mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -2
- mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -1
- mteb/tasks/classification/bul/bulgarian_store_review_sentiment_classfication.py +1 -1
- mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -1
- mteb/tasks/classification/dan/dk_hate_classification.py +2 -2
- mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -1
- mteb/tasks/classification/ell/greek_legal_code_classification.py +1 -1
- mteb/tasks/classification/eng/dbpedia_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_chat_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_conversations_classification.py +2 -2
- mteb/tasks/classification/eng/tweet_topic_single_classification.py +1 -1
- mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -1
- mteb/tasks/classification/eng/yelp_review_full_classification.py +2 -2
- mteb/tasks/classification/est/estonian_valence.py +2 -2
- mteb/tasks/classification/fas/fa_mteb_classification.py +6 -6
- mteb/tasks/classification/fas/persian_food_sentiment_classification.py +1 -1
- mteb/tasks/classification/fil/filipino_shopee_reviews_classification.py +1 -1
- mteb/tasks/classification/fin/fin_toxicity_classification.py +1 -1
- mteb/tasks/classification/fra/french_book_reviews.py +2 -2
- mteb/tasks/classification/fra/movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/guj/gujarati_news_classification.py +1 -1
- mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -1
- mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -1
- mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +2 -2
- mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -1
- mteb/tasks/classification/ita/dado_eval_coarse_classification.py +1 -1
- mteb/tasks/classification/ita/ita_casehold_classification.py +1 -1
- mteb/tasks/classification/ita/sardi_stance_classification.py +1 -1
- mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -1
- mteb/tasks/classification/jpn/wrime_classification.py +1 -1
- mteb/tasks/classification/kan/kannada_news_classification.py +2 -2
- mteb/tasks/classification/kor/klue_tc.py +2 -2
- mteb/tasks/classification/kor/kor_fin.py +1 -1
- mteb/tasks/classification/kor/kor_hate_classification.py +1 -1
- mteb/tasks/classification/kor/kor_sarcasm_classification.py +1 -1
- mteb/tasks/classification/kur/kurdish_sentiment_classification.py +2 -2
- mteb/tasks/classification/mal/malayalam_news_classification.py +1 -1
- mteb/tasks/classification/mar/marathi_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/afri_senti_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -1
- mteb/tasks/classification/multilingual/cyrillic_turkic_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_nlp_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/masakha_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -1
- mteb/tasks/classification/multilingual/multilingual_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/scala_classification.py +2 -2
- mteb/tasks/classification/multilingual/sib200_classification.py +1 -1
- mteb/tasks/classification/multilingual/turkic_classification.py +1 -1
- mteb/tasks/classification/multilingual/tweet_sentiment_classification.py +1 -1
- mteb/tasks/classification/nep/nepali_news_classification.py +2 -2
- mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +1 -1
- mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +1 -1
- mteb/tasks/classification/ory/odia_news_classification.py +2 -2
- mteb/tasks/classification/pan/punjabi_news_classification.py +1 -1
- mteb/tasks/classification/ron/moroco.py +1 -1
- mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -1
- mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -1
- mteb/tasks/classification/rus/georeview_classification.py +1 -1
- mteb/tasks/classification/rus/headline_classification.py +2 -2
- mteb/tasks/classification/rus/inappropriateness_classification.py +2 -2
- mteb/tasks/classification/rus/ru_reviews_classification.py +2 -2
- mteb/tasks/classification/rus/ru_sci_bench_grnti_classification.py +1 -1
- mteb/tasks/classification/rus/ru_sci_bench_oecd_classification.py +1 -1
- mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -1
- mteb/tasks/classification/san/sanskrit_shlokas_classification.py +1 -1
- mteb/tasks/classification/sin/sinhala_news_classification.py +2 -2
- mteb/tasks/classification/sin/sinhala_news_source_classification.py +2 -2
- mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/slv/frenk_sl_classification.py +1 -1
- mteb/tasks/classification/spa/spanish_news_classification.py +2 -2
- mteb/tasks/classification/ssw/siswati_news_classification.py +1 -1
- mteb/tasks/classification/tam/tamil_news_classification.py +2 -2
- mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +2 -2
- mteb/tasks/classification/tha/wongnai_reviews_classification.py +1 -1
- mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +2 -2
- mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -2
- mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -1
- mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -1
- mteb/tasks/classification/zho/yue_openrice_review_classification.py +2 -2
- mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_p2p.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/arxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/arxiv_hierarchical_clustering.py +2 -2
- mteb/tasks/clustering/eng/big_patent_clustering.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/twenty_newsgroups_clustering.py +1 -1
- mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/fas/fa_mteb_clustering.py +4 -4
- mteb/tasks/clustering/fra/hal_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_p2p.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -1
- mteb/tasks/clustering/multilingual/wiki_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nob/snl_clustering.py +8 -3
- mteb/tasks/clustering/nob/vg_clustering.py +8 -3
- mteb/tasks/clustering/pol/polish_clustering.py +3 -3
- mteb/tasks/clustering/rus/ru_sci_bench_grnti_clustering_p2p.py +1 -1
- mteb/tasks/clustering/rus/ru_sci_bench_oecd_clustering_p2p.py +1 -1
- mteb/tasks/clustering/zho/cmteb_clustering.py +6 -6
- mteb/tasks/image_text_pair_classification/eng/image_co_de.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +2 -2
- mteb/tasks/instruction_reranking/multilingual/m_follow_ir.py +2 -2
- mteb/tasks/multichoice/eng/cv_bench.py +4 -4
- mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -1
- mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -1
- mteb/tasks/multilabel_classification/rus/ru_toixic_multilabelclassification_okmlcup.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -1
- mteb/tasks/pair_classification/ara/ar_entail.py +1 -1
- mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -1
- mteb/tasks/pair_classification/deu/false_friends_de_en_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_ai_sentence_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_smilespc.py +4 -3
- mteb/tasks/pair_classification/eng/pub_chem_synonym_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_wiki_paragraphs_pc.py +1 -1
- mteb/tasks/pair_classification/eng/sprint_duplicate_questions_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_sem_eval2015_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_url_corpus_pc.py +1 -1
- mteb/tasks/pair_classification/fas/fa_mteb_pair_classification.py +5 -5
- mteb/tasks/pair_classification/fas/fars_tail.py +2 -2
- mteb/tasks/pair_classification/hye/armenian_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/ita/dis_co_tex_pair_classification.py +1 -1
- mteb/tasks/pair_classification/kor/klue_nli.py +1 -1
- mteb/tasks/pair_classification/multilingual/rte3.py +2 -2
- mteb/tasks/pair_classification/multilingual/xnli.py +1 -1
- mteb/tasks/pair_classification/pol/polish_pc.py +4 -4
- mteb/tasks/pair_classification/por/assin2_rte.py +1 -1
- mteb/tasks/pair_classification/por/sick_br_pc.py +1 -1
- mteb/tasks/pair_classification/rus/terra.py +2 -2
- mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -1
- mteb/tasks/pair_classification/zho/cmteb_pair_classification.py +2 -2
- mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
- mteb/tasks/retrieval/ara/sadeem_question_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_edit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +16 -16
- mteb/tasks/retrieval/code/code_search_net_cc_retrieval.py +1 -1
- mteb/tasks/retrieval/code/coir_code_search_net_retrieval.py +1 -1
- mteb/tasks/retrieval/code/ds1000_retrieval.py +1 -1
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +1 -1
- mteb/tasks/retrieval/code/human_eval_retrieval.py +1 -1
- mteb/tasks/retrieval/code/mbpp_retrieval.py +1 -1
- mteb/tasks/retrieval/code/wiki_sql_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +2 -2
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +3 -3
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +3 -3
- mteb/tasks/retrieval/deu/german_gov_service_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/ell/greek_civics_qa.py +1 -1
- mteb/tasks/retrieval/eng/__init__.py +44 -0
- mteb/tasks/retrieval/eng/bright_retrieval.py +10 -2
- mteb/tasks/retrieval/eng/bright_v1_1_retrieval.py +968 -0
- mteb/tasks/retrieval/eng/chat_doctor_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
- mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/fin_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/finance_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hc3_finance_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_needle_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_passkey_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_summ_screen_fd_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_wikim_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lembqm_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/limit_retrieval.py +6 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/ml_questions.py +1 -1
- mteb/tasks/retrieval/eng/nano_argu_ana_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_climate_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_db_pedia_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fi_qa2018_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_hotpot_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_msmarco_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nf_corpus_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nq_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_quora_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_sci_fact_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_scidocs_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_touche2020_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/r2_med_retrieval.py +8 -8
- mteb/tasks/retrieval/eng/sci_mmir_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/sci_mmir_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +10 -10
- mteb/tasks/retrieval/fra/f_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/fra/syntec_retrieval.py +1 -1
- mteb/tasks/retrieval/hun/hun_sum2.py +1 -1
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +1 -1
- mteb/tasks/retrieval/kor/__init__.py +15 -1
- mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
- mteb/tasks/retrieval/multilingual/__init__.py +2 -0
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt19.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt21.py +1 -1
- mteb/tasks/retrieval/multilingual/cur_ev1_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/miracl_vision_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/mr_tidy_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +5 -5
- mteb/tasks/retrieval/multilingual/statcan_dialogue_dataset_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vdr_multilingual_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +14 -4
- mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +90 -100
- mteb/tasks/retrieval/multilingual/wit_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_flickr30k_co_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/xm3600_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_android_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_english_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gaming_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gis_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_mathematica_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_physics_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_programmers_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_stats_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_tex_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_unix_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_webmasters_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_wordpress_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nob/norquad.py +3 -3
- mteb/tasks/retrieval/nob/snl_retrieval.py +3 -3
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
- mteb/tasks/retrieval/vie/__init__.py +14 -6
- mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/fevervn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +48 -0
- mteb/tasks/retrieval/vie/nqvn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
- mteb/tasks/retrieval/vie/vie_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
- mteb/tasks/sts/fao/faroese_sts.py +1 -1
- mteb/tasks/sts/fra/sick_fr_sts.py +1 -1
- mteb/tasks/sts/kor/klue_sts.py +1 -1
- mteb/tasks/sts/por/sick_br_sts.py +1 -1
- mteb/tasks/sts/rus/ru_para_phraser_sts.py +1 -1
- mteb/tasks/zeroshot_classification/eng/sci_mmir.py +1 -1
- mteb/types/__init__.py +2 -0
- mteb/types/_encoder_io.py +13 -1
- mteb/types/_result.py +2 -1
- mteb/types/statistics.py +18 -5
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/METADATA +15 -4
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/RECORD +528 -486
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/WHEEL +1 -1
- mteb/models/model_implementations/mxbai_models.py +0 -111
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/entry_points.txt +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/top_level.txt +0 -0
mteb/abstasks/classification.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from collections import defaultdict
|
|
3
|
-
from
|
|
4
|
-
from typing import Any, TypedDict
|
|
5
|
+
from typing import TYPE_CHECKING, Any, TypedDict
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
from datasets import Dataset, DatasetDict
|
|
@@ -16,12 +17,8 @@ from sklearn.metrics import (
|
|
|
16
17
|
|
|
17
18
|
from mteb._evaluators.sklearn_evaluator import SklearnEvaluator, SklearnModelProtocol
|
|
18
19
|
from mteb.models import EncoderProtocol, MTEBModels
|
|
19
|
-
from mteb.types import HFSubset, ScoresDict
|
|
20
20
|
from mteb.types.statistics import (
|
|
21
|
-
ImageStatistics,
|
|
22
|
-
LabelStatistics,
|
|
23
21
|
SplitDescriptiveStatistics,
|
|
24
|
-
TextStatistics,
|
|
25
22
|
)
|
|
26
23
|
|
|
27
24
|
from ._statistics_calculation import (
|
|
@@ -31,6 +28,18 @@ from ._statistics_calculation import (
|
|
|
31
28
|
)
|
|
32
29
|
from .abstask import AbsTask
|
|
33
30
|
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from pathlib import Path
|
|
33
|
+
|
|
34
|
+
from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol
|
|
35
|
+
from mteb.models import MTEBModels
|
|
36
|
+
from mteb.types import EncodeKwargs, HFSubset, ScoresDict
|
|
37
|
+
from mteb.types.statistics import (
|
|
38
|
+
ImageStatistics,
|
|
39
|
+
LabelStatistics,
|
|
40
|
+
TextStatistics,
|
|
41
|
+
)
|
|
42
|
+
|
|
34
43
|
logger = logging.getLogger(__name__)
|
|
35
44
|
|
|
36
45
|
|
|
@@ -98,9 +107,8 @@ class AbsTaskClassification(AbsTask):
|
|
|
98
107
|
text: str (for text) or PIL.Image (for image). Column name can be changed via `input_column_name` attribute.
|
|
99
108
|
label: int. Column name can be changed via `label_column_name` attribute.
|
|
100
109
|
evaluator_model: The model to use for evaluation. Can be any sklearn compatible model. Default is `LogisticRegression`.
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
n_experiments: Number of experiments to run. Default is 10.
|
|
110
|
+
samples_per_label: Number of samples per label to use for training the evaluator model. Default is 8.
|
|
111
|
+
n_experiments: Number of experiments to run. Default is 10.
|
|
104
112
|
train_split: Name of the split to use for training the evaluator model. Default is "train".
|
|
105
113
|
label_column_name: Name of the column containing the labels. Default is "label".
|
|
106
114
|
input_column_name: Name of the column containing the input data. Default is "text".
|
|
@@ -126,8 +134,9 @@ class AbsTaskClassification(AbsTask):
|
|
|
126
134
|
split: str = "test",
|
|
127
135
|
subsets_to_run: list[HFSubset] | None = None,
|
|
128
136
|
*,
|
|
129
|
-
encode_kwargs:
|
|
137
|
+
encode_kwargs: EncodeKwargs,
|
|
130
138
|
prediction_folder: Path | None = None,
|
|
139
|
+
num_proc: int = 1,
|
|
131
140
|
**kwargs: Any,
|
|
132
141
|
) -> dict[HFSubset, ScoresDict]:
|
|
133
142
|
"""Evaluate a model on the classification task.
|
|
@@ -141,7 +150,10 @@ class AbsTaskClassification(AbsTask):
|
|
|
141
150
|
)
|
|
142
151
|
|
|
143
152
|
if not self.data_loaded:
|
|
144
|
-
self.load_data()
|
|
153
|
+
self.load_data(num_proc=num_proc)
|
|
154
|
+
|
|
155
|
+
if self.dataset is None:
|
|
156
|
+
raise RuntimeError("Dataset not loaded.")
|
|
145
157
|
|
|
146
158
|
if "random_state" in self.evaluator_model.get_params():
|
|
147
159
|
self.evaluator_model = self.evaluator_model.set_params(
|
|
@@ -171,23 +183,28 @@ class AbsTaskClassification(AbsTask):
|
|
|
171
183
|
hf_subset=hf_subset,
|
|
172
184
|
encode_kwargs=encode_kwargs,
|
|
173
185
|
prediction_folder=prediction_folder,
|
|
186
|
+
num_proc=num_proc,
|
|
174
187
|
**kwargs,
|
|
175
188
|
)
|
|
176
189
|
self._add_main_score(scores[hf_subset])
|
|
177
190
|
|
|
178
|
-
return scores
|
|
191
|
+
return scores # type: ignore[return-value]
|
|
179
192
|
|
|
180
193
|
def _evaluate_subset(
|
|
181
194
|
self,
|
|
182
|
-
model:
|
|
195
|
+
model: MTEBModels,
|
|
183
196
|
data_split: DatasetDict,
|
|
184
197
|
*,
|
|
185
|
-
encode_kwargs:
|
|
198
|
+
encode_kwargs: EncodeKwargs,
|
|
186
199
|
hf_split: str,
|
|
187
200
|
hf_subset: str,
|
|
188
201
|
prediction_folder: Path | None = None,
|
|
202
|
+
num_proc: int = 1,
|
|
189
203
|
**kwargs: Any,
|
|
190
204
|
) -> FullClassificationMetrics:
|
|
205
|
+
if not isinstance(model, EncoderProtocol):
|
|
206
|
+
raise TypeError("Expected model to be an instance of EncoderProtocol")
|
|
207
|
+
|
|
191
208
|
train_split = data_split[self.train_split]
|
|
192
209
|
eval_split = data_split[hf_split]
|
|
193
210
|
|
|
@@ -216,7 +233,10 @@ class AbsTaskClassification(AbsTask):
|
|
|
216
233
|
evaluator_model=self.evaluator_model,
|
|
217
234
|
)
|
|
218
235
|
y_pred, test_cache = evaluator(
|
|
219
|
-
model,
|
|
236
|
+
model,
|
|
237
|
+
encode_kwargs=encode_kwargs,
|
|
238
|
+
test_cache=test_cache,
|
|
239
|
+
num_proc=num_proc,
|
|
220
240
|
)
|
|
221
241
|
if prediction_folder:
|
|
222
242
|
all_predictions.append(y_pred.tolist())
|
|
@@ -237,7 +257,7 @@ class AbsTaskClassification(AbsTask):
|
|
|
237
257
|
# ap will be none for non binary classification tasks
|
|
238
258
|
k: (
|
|
239
259
|
float(np.mean(values))
|
|
240
|
-
if (values := [s[k] for s in scores if s[k] is not None])
|
|
260
|
+
if (values := [s[k] for s in scores if s[k] is not None]) # type: ignore[literal-required]
|
|
241
261
|
else np.nan
|
|
242
262
|
)
|
|
243
263
|
for k in scores[0].keys()
|
|
@@ -245,7 +265,7 @@ class AbsTaskClassification(AbsTask):
|
|
|
245
265
|
logger.info(f"Running {self.metadata.name} - Finished.")
|
|
246
266
|
return FullClassificationMetrics(
|
|
247
267
|
scores_per_experiment=scores,
|
|
248
|
-
**avg_scores,
|
|
268
|
+
**avg_scores, # type: ignore[typeddict-item]
|
|
249
269
|
)
|
|
250
270
|
|
|
251
271
|
def _calculate_scores(
|
|
@@ -358,11 +378,12 @@ class AbsTaskClassification(AbsTask):
|
|
|
358
378
|
label_statistics=label_statistics,
|
|
359
379
|
)
|
|
360
380
|
|
|
361
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
381
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
362
382
|
self._upload_dataset_to_hub(
|
|
363
383
|
repo_name,
|
|
364
384
|
[
|
|
365
385
|
self.input_column_name,
|
|
366
386
|
self.label_column_name,
|
|
367
387
|
],
|
|
388
|
+
num_proc=num_proc,
|
|
368
389
|
)
|
mteb/abstasks/clustering.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import itertools
|
|
2
4
|
import logging
|
|
3
5
|
import random
|
|
4
6
|
from collections import defaultdict
|
|
5
|
-
from
|
|
6
|
-
from typing import Any
|
|
7
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
from datasets import Dataset, DatasetDict
|
|
@@ -12,12 +13,9 @@ from sklearn.metrics.cluster import v_measure_score
|
|
|
12
13
|
|
|
13
14
|
from mteb._create_dataloaders import create_dataloader
|
|
14
15
|
from mteb.models import EncoderProtocol
|
|
15
|
-
from mteb.types import
|
|
16
|
+
from mteb.types import Array, HFSubset
|
|
16
17
|
from mteb.types.statistics import (
|
|
17
|
-
ImageStatistics,
|
|
18
|
-
LabelStatistics,
|
|
19
18
|
SplitDescriptiveStatistics,
|
|
20
|
-
TextStatistics,
|
|
21
19
|
)
|
|
22
20
|
|
|
23
21
|
from ._statistics_calculation import (
|
|
@@ -27,6 +25,17 @@ from ._statistics_calculation import (
|
|
|
27
25
|
)
|
|
28
26
|
from .abstask import AbsTask
|
|
29
27
|
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from pathlib import Path
|
|
30
|
+
|
|
31
|
+
from mteb.models import MTEBModels
|
|
32
|
+
from mteb.types import Array, EncodeKwargs, ScoresDict
|
|
33
|
+
from mteb.types.statistics import (
|
|
34
|
+
ImageStatistics,
|
|
35
|
+
LabelStatistics,
|
|
36
|
+
TextStatistics,
|
|
37
|
+
)
|
|
38
|
+
|
|
30
39
|
logger = logging.getLogger(__name__)
|
|
31
40
|
|
|
32
41
|
|
|
@@ -34,7 +43,7 @@ MultilingualDataset = dict[HFSubset, DatasetDict]
|
|
|
34
43
|
|
|
35
44
|
|
|
36
45
|
def _evaluate_clustering_bootstrapped(
|
|
37
|
-
embeddings:
|
|
46
|
+
embeddings: Array,
|
|
38
47
|
labels: list[list[str]],
|
|
39
48
|
n_clusters: int,
|
|
40
49
|
cluster_size: int,
|
|
@@ -61,21 +70,21 @@ def _evaluate_clustering_bootstrapped(
|
|
|
61
70
|
max_depth = max(map(len, labels))
|
|
62
71
|
# Evaluate on each level til max depth
|
|
63
72
|
for i_level in range(max_depth):
|
|
64
|
-
level_labels = []
|
|
73
|
+
level_labels: list[str | int] = []
|
|
65
74
|
# Assign -1 to gold label if the level is not there
|
|
66
75
|
for label in labels:
|
|
67
76
|
if len(label) > i_level:
|
|
68
77
|
level_labels.append(label[i_level])
|
|
69
78
|
else:
|
|
70
79
|
level_labels.append(-1)
|
|
71
|
-
|
|
80
|
+
np_level_labels = np.array(level_labels)
|
|
72
81
|
valid_idx = np.array(
|
|
73
|
-
[level_label != -1 for level_label in
|
|
82
|
+
[level_label != -1 for level_label in np_level_labels]
|
|
74
83
|
) # Could be level_labels != -1 but fails with FutureWarning: elementwise comparison failed
|
|
75
|
-
|
|
84
|
+
np_level_labels = np_level_labels[valid_idx]
|
|
76
85
|
level_embeddings = embeddings[valid_idx]
|
|
77
86
|
clustering_model = MiniBatchKMeans(
|
|
78
|
-
n_clusters=np.unique(
|
|
87
|
+
n_clusters=np.unique(np_level_labels).size,
|
|
79
88
|
batch_size=kmean_batch_size,
|
|
80
89
|
init="k-means++",
|
|
81
90
|
n_init=1, # default when kmeans++ is used
|
|
@@ -87,7 +96,7 @@ def _evaluate_clustering_bootstrapped(
|
|
|
87
96
|
cluster_indices = rng_state.choices(range(n_embeddings), k=cluster_size)
|
|
88
97
|
|
|
89
98
|
_embeddings = level_embeddings[cluster_indices]
|
|
90
|
-
_labels =
|
|
99
|
+
_labels = np_level_labels[cluster_indices]
|
|
91
100
|
cluster_assignment = clustering_model.fit_predict(_embeddings)
|
|
92
101
|
v_measure = v_measure_score(_labels, cluster_assignment)
|
|
93
102
|
v_measures[f"Level {i_level}"].append(v_measure)
|
|
@@ -153,15 +162,20 @@ class AbsTaskClustering(AbsTask):
|
|
|
153
162
|
|
|
154
163
|
def _evaluate_subset(
|
|
155
164
|
self,
|
|
156
|
-
model:
|
|
165
|
+
model: MTEBModels,
|
|
157
166
|
data_split: Dataset,
|
|
158
167
|
*,
|
|
159
|
-
encode_kwargs:
|
|
168
|
+
encode_kwargs: EncodeKwargs,
|
|
160
169
|
hf_split: str,
|
|
161
170
|
hf_subset: str,
|
|
162
171
|
prediction_folder: Path | None = None,
|
|
172
|
+
num_proc: int = 1,
|
|
163
173
|
**kwargs: Any,
|
|
164
174
|
) -> ScoresDict:
|
|
175
|
+
if not isinstance(model, EncoderProtocol):
|
|
176
|
+
raise TypeError(
|
|
177
|
+
"Expected encoder model to be an instance of EncoderProtocol."
|
|
178
|
+
)
|
|
165
179
|
if (
|
|
166
180
|
self.max_document_to_embed is not None
|
|
167
181
|
and self.max_fraction_of_documents_to_embed is not None
|
|
@@ -182,13 +196,13 @@ class AbsTaskClustering(AbsTask):
|
|
|
182
196
|
self.max_fraction_of_documents_to_embed * len(data_split)
|
|
183
197
|
)
|
|
184
198
|
else:
|
|
185
|
-
max_documents_to_embed = self.max_document_to_embed
|
|
199
|
+
max_documents_to_embed = cast("int", self.max_document_to_embed)
|
|
186
200
|
|
|
187
|
-
max_documents_to_embed = min(len(data_split), max_documents_to_embed)
|
|
201
|
+
max_documents_to_embed = min(len(data_split), max_documents_to_embed)
|
|
188
202
|
example_indices = self.rng_state.sample(
|
|
189
203
|
range(len(data_split)), k=max_documents_to_embed
|
|
190
204
|
)
|
|
191
|
-
downsampled_dataset = data_split.select(example_indices)
|
|
205
|
+
downsampled_dataset = data_split.select(example_indices)
|
|
192
206
|
|
|
193
207
|
downsampled_dataset = downsampled_dataset.select_columns(
|
|
194
208
|
[self.input_column_name, self.label_column_name]
|
|
@@ -200,6 +214,7 @@ class AbsTaskClustering(AbsTask):
|
|
|
200
214
|
downsampled_dataset,
|
|
201
215
|
self.metadata,
|
|
202
216
|
input_column=self.input_column_name,
|
|
217
|
+
num_proc=num_proc,
|
|
203
218
|
**encode_kwargs,
|
|
204
219
|
),
|
|
205
220
|
task_metadata=self.metadata,
|
|
@@ -283,9 +298,11 @@ class AbsTaskClustering(AbsTask):
|
|
|
283
298
|
labels_statistics=label_statistics,
|
|
284
299
|
)
|
|
285
300
|
|
|
286
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
301
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
287
302
|
self._upload_dataset_to_hub(
|
|
288
|
-
repo_name,
|
|
303
|
+
repo_name,
|
|
304
|
+
[self.input_column_name, self.label_column_name],
|
|
305
|
+
num_proc=num_proc,
|
|
289
306
|
)
|
|
290
307
|
|
|
291
308
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
|
-
from
|
|
3
|
-
from typing import Any, TypedDict
|
|
4
|
+
from typing import TYPE_CHECKING, Any, TypedDict
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
from datasets import Dataset
|
|
@@ -8,13 +9,9 @@ from scipy.optimize import linear_sum_assignment
|
|
|
8
9
|
from sklearn import metrics
|
|
9
10
|
|
|
10
11
|
from mteb._evaluators import ClusteringEvaluator
|
|
11
|
-
from mteb.models import EncoderProtocol
|
|
12
|
-
from mteb.types import ScoresDict
|
|
12
|
+
from mteb.models import EncoderProtocol, MTEBModels
|
|
13
13
|
from mteb.types.statistics import (
|
|
14
|
-
ImageStatistics,
|
|
15
|
-
LabelStatistics,
|
|
16
14
|
SplitDescriptiveStatistics,
|
|
17
|
-
TextStatistics,
|
|
18
15
|
)
|
|
19
16
|
|
|
20
17
|
from ._statistics_calculation import (
|
|
@@ -24,6 +21,17 @@ from ._statistics_calculation import (
|
|
|
24
21
|
)
|
|
25
22
|
from .abstask import AbsTask
|
|
26
23
|
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
|
|
27
|
+
from mteb.models import MTEBModels
|
|
28
|
+
from mteb.types import EncodeKwargs, ScoresDict
|
|
29
|
+
from mteb.types.statistics import (
|
|
30
|
+
ImageStatistics,
|
|
31
|
+
LabelStatistics,
|
|
32
|
+
TextStatistics,
|
|
33
|
+
)
|
|
34
|
+
|
|
27
35
|
logger = logging.getLogger(__name__)
|
|
28
36
|
|
|
29
37
|
|
|
@@ -80,15 +88,19 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
80
88
|
|
|
81
89
|
def _evaluate_subset(
|
|
82
90
|
self,
|
|
83
|
-
model:
|
|
91
|
+
model: MTEBModels,
|
|
84
92
|
data_split: Dataset,
|
|
85
93
|
*,
|
|
86
|
-
encode_kwargs:
|
|
94
|
+
encode_kwargs: EncodeKwargs,
|
|
87
95
|
hf_split: str,
|
|
88
96
|
hf_subset: str,
|
|
89
97
|
prediction_folder: Path | None = None,
|
|
98
|
+
num_proc: int = 1,
|
|
90
99
|
**kwargs: Any,
|
|
91
100
|
) -> ScoresDict:
|
|
101
|
+
if not isinstance(model, EncoderProtocol):
|
|
102
|
+
raise TypeError("Expected model to be an instance of EncoderProtocol")
|
|
103
|
+
|
|
92
104
|
data_split = data_split.select_columns(
|
|
93
105
|
[self.input_column_name, self.label_column_name]
|
|
94
106
|
)
|
|
@@ -139,9 +151,6 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
139
151
|
}
|
|
140
152
|
return scores
|
|
141
153
|
|
|
142
|
-
data_split = data_split.select_columns(
|
|
143
|
-
[self.input_column_name, self.label_column_name]
|
|
144
|
-
)
|
|
145
154
|
evaluator = self.evaluator(
|
|
146
155
|
data_split,
|
|
147
156
|
input_column_name=self.input_column_name,
|
|
@@ -151,10 +160,14 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
151
160
|
hf_subset=hf_subset,
|
|
152
161
|
**kwargs,
|
|
153
162
|
)
|
|
154
|
-
|
|
163
|
+
evaluate_clusters = evaluator(
|
|
164
|
+
model,
|
|
165
|
+
encode_kwargs=encode_kwargs,
|
|
166
|
+
num_proc=num_proc,
|
|
167
|
+
)
|
|
155
168
|
if prediction_folder:
|
|
156
169
|
self._save_task_predictions(
|
|
157
|
-
|
|
170
|
+
evaluate_clusters,
|
|
158
171
|
model,
|
|
159
172
|
prediction_folder,
|
|
160
173
|
hf_subset=hf_subset,
|
|
@@ -163,7 +176,7 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
163
176
|
|
|
164
177
|
return self._compute_metrics(
|
|
165
178
|
data_split[self.label_column_name],
|
|
166
|
-
|
|
179
|
+
evaluate_clusters,
|
|
167
180
|
)
|
|
168
181
|
|
|
169
182
|
def _compute_metrics(
|
|
@@ -230,11 +243,12 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
230
243
|
label_statistics=label_statistics,
|
|
231
244
|
)
|
|
232
245
|
|
|
233
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
246
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
234
247
|
self._upload_dataset_to_hub(
|
|
235
248
|
repo_name,
|
|
236
249
|
[
|
|
237
250
|
self.input_column_name,
|
|
238
251
|
self.label_column_name,
|
|
239
252
|
],
|
|
253
|
+
num_proc=num_proc,
|
|
240
254
|
)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from collections.abc import Sequence
|
|
3
|
-
from
|
|
4
|
-
from typing import Any, TypedDict
|
|
5
|
+
from typing import TYPE_CHECKING, Any, TypedDict
|
|
5
6
|
|
|
6
7
|
import torch
|
|
7
|
-
from datasets import
|
|
8
|
+
from datasets import concatenate_datasets
|
|
8
9
|
|
|
9
10
|
from mteb._evaluators import ImageTextPairClassificationEvaluator
|
|
10
11
|
from mteb.abstasks._statistics_calculation import (
|
|
@@ -14,11 +15,21 @@ from mteb.abstasks._statistics_calculation import (
|
|
|
14
15
|
from mteb.abstasks.abstask import AbsTask
|
|
15
16
|
from mteb.models.models_protocols import EncoderProtocol
|
|
16
17
|
from mteb.types.statistics import (
|
|
17
|
-
ImageStatistics,
|
|
18
18
|
SplitDescriptiveStatistics,
|
|
19
|
-
TextStatistics,
|
|
20
19
|
)
|
|
21
20
|
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
from datasets import Dataset
|
|
25
|
+
|
|
26
|
+
from mteb.models.models_protocols import MTEBModels
|
|
27
|
+
from mteb.types import EncodeKwargs
|
|
28
|
+
from mteb.types.statistics import (
|
|
29
|
+
ImageStatistics,
|
|
30
|
+
TextStatistics,
|
|
31
|
+
)
|
|
32
|
+
|
|
22
33
|
logger = logging.getLogger(__name__)
|
|
23
34
|
|
|
24
35
|
|
|
@@ -116,15 +127,18 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
116
127
|
|
|
117
128
|
def _evaluate_subset(
|
|
118
129
|
self,
|
|
119
|
-
model:
|
|
130
|
+
model: MTEBModels,
|
|
120
131
|
data_split: Dataset,
|
|
121
132
|
*,
|
|
122
|
-
encode_kwargs:
|
|
133
|
+
encode_kwargs: EncodeKwargs,
|
|
123
134
|
hf_split: str,
|
|
124
135
|
hf_subset: str,
|
|
125
136
|
prediction_folder: Path | None = None,
|
|
137
|
+
num_proc: int = 1,
|
|
126
138
|
**kwargs: Any,
|
|
127
139
|
) -> ImageTextPairClassificationMetrics:
|
|
140
|
+
if not isinstance(model, EncoderProtocol):
|
|
141
|
+
raise TypeError("Expected model to be an instance of EncoderProtocol")
|
|
128
142
|
select_columns = []
|
|
129
143
|
for columns in (self.images_column_names, self.texts_column_names):
|
|
130
144
|
if isinstance(columns, str):
|
|
@@ -154,7 +168,9 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
154
168
|
hf_subset=hf_subset,
|
|
155
169
|
**kwargs,
|
|
156
170
|
)
|
|
157
|
-
scores = evaluator(
|
|
171
|
+
scores: list[torch.Tensor] = evaluator(
|
|
172
|
+
model, encode_kwargs=encode_kwargs, num_proc=num_proc
|
|
173
|
+
) # type: ignore[assignment]
|
|
158
174
|
if prediction_folder:
|
|
159
175
|
self._save_task_predictions(
|
|
160
176
|
[score.tolist() for score in scores],
|
|
@@ -202,7 +218,7 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
202
218
|
accuracy=torch.Tensor(all_correct_scores).float().mean().item(),
|
|
203
219
|
)
|
|
204
220
|
|
|
205
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
221
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
206
222
|
text_columns = (
|
|
207
223
|
[self.texts_column_names]
|
|
208
224
|
if isinstance(self.texts_column_names, str)
|
|
@@ -217,4 +233,5 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
217
233
|
self._upload_dataset_to_hub(
|
|
218
234
|
repo_name,
|
|
219
235
|
[*text_columns, *image_columns],
|
|
236
|
+
num_proc=num_proc,
|
|
220
237
|
)
|
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import itertools
|
|
2
4
|
import logging
|
|
3
5
|
from collections import defaultdict
|
|
4
|
-
from
|
|
5
|
-
from typing import Any, TypedDict
|
|
6
|
+
from typing import TYPE_CHECKING, Any, TypedDict
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
from datasets import DatasetDict
|
|
@@ -15,23 +16,29 @@ from typing_extensions import override
|
|
|
15
16
|
|
|
16
17
|
from mteb._create_dataloaders import create_dataloader
|
|
17
18
|
from mteb._evaluators.classification_metrics import hamming_score
|
|
18
|
-
from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol
|
|
19
19
|
from mteb.models import EncoderProtocol
|
|
20
20
|
|
|
21
21
|
from .classification import AbsTaskClassification
|
|
22
22
|
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
|
|
26
|
+
from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol
|
|
27
|
+
from mteb.models import MTEBModels
|
|
28
|
+
from mteb.types import Array, EncodeKwargs
|
|
29
|
+
|
|
23
30
|
logger = logging.getLogger(__name__)
|
|
24
31
|
|
|
25
32
|
|
|
26
33
|
def _evaluate_classifier(
|
|
27
|
-
embeddings_train:
|
|
34
|
+
embeddings_train: Array,
|
|
28
35
|
y_train: np.ndarray,
|
|
29
|
-
embeddings_test:
|
|
36
|
+
embeddings_test: Array,
|
|
30
37
|
classifier: SklearnModelProtocol,
|
|
31
38
|
) -> tuple[np.ndarray, SklearnModelProtocol]:
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
return
|
|
39
|
+
classifier_copy: SklearnModelProtocol = clone(classifier)
|
|
40
|
+
classifier_copy.fit(embeddings_train, y_train)
|
|
41
|
+
return classifier_copy.predict(embeddings_test), classifier_copy
|
|
35
42
|
|
|
36
43
|
|
|
37
44
|
class MultilabelClassificationMetrics(TypedDict):
|
|
@@ -69,25 +76,29 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
|
|
|
69
76
|
input_column_name: Name of the column containing the input text.
|
|
70
77
|
label_column_name: Name of the column containing the labels.
|
|
71
78
|
samples_per_label: Number of samples to use pr. label. These samples are embedded and a classifier is fit using the labels and samples.
|
|
72
|
-
|
|
79
|
+
evaluator_model: Classifier to use for evaluation. Must implement the SklearnModelProtocol.
|
|
73
80
|
"""
|
|
74
81
|
|
|
75
|
-
|
|
82
|
+
evaluator_model: SklearnModelProtocol = KNeighborsClassifier(n_neighbors=5)
|
|
76
83
|
input_column_name: str = "text"
|
|
77
84
|
label_column_name: str = "label"
|
|
78
85
|
|
|
79
86
|
@override
|
|
80
|
-
def _evaluate_subset(
|
|
87
|
+
def _evaluate_subset( # type: ignore[override]
|
|
81
88
|
self,
|
|
82
|
-
model:
|
|
89
|
+
model: MTEBModels,
|
|
83
90
|
data_split: DatasetDict,
|
|
84
91
|
*,
|
|
85
|
-
encode_kwargs:
|
|
92
|
+
encode_kwargs: EncodeKwargs,
|
|
86
93
|
hf_split: str,
|
|
87
94
|
hf_subset: str,
|
|
88
95
|
prediction_folder: Path | None = None,
|
|
96
|
+
num_proc: int = 1,
|
|
89
97
|
**kwargs: Any,
|
|
90
98
|
) -> FullMultilabelClassificationMetrics:
|
|
99
|
+
if not isinstance(model, EncoderProtocol):
|
|
100
|
+
raise TypeError("Expected model to be an instance of EncoderProtocol")
|
|
101
|
+
|
|
91
102
|
if isinstance(data_split, DatasetDict):
|
|
92
103
|
data_split = data_split.select_columns(
|
|
93
104
|
[self.input_column_name, self.label_column_name]
|
|
@@ -115,6 +126,7 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
|
|
|
115
126
|
unique_train_dataset,
|
|
116
127
|
self.metadata,
|
|
117
128
|
input_column=self.input_column_name,
|
|
129
|
+
num_proc=num_proc,
|
|
118
130
|
**encode_kwargs,
|
|
119
131
|
)
|
|
120
132
|
|
|
@@ -165,7 +177,7 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
|
|
|
165
177
|
y_train = train_split.select(sample_indices)[self.label_column_name]
|
|
166
178
|
y_train = binarizer.transform(y_train)
|
|
167
179
|
y_pred, current_classifier = _evaluate_classifier(
|
|
168
|
-
X_train, y_train, X_test, self.
|
|
180
|
+
X_train, y_train, X_test, self.evaluator_model
|
|
169
181
|
)
|
|
170
182
|
if prediction_folder:
|
|
171
183
|
all_predictions.append(y_pred.tolist())
|
|
@@ -185,19 +197,20 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
|
|
|
185
197
|
)
|
|
186
198
|
|
|
187
199
|
avg_scores: dict[str, Any] = {
|
|
188
|
-
k: np.mean([s[k] for s in scores])
|
|
200
|
+
k: np.mean([s[k] for s in scores]) # type: ignore[literal-required]
|
|
201
|
+
for k in scores[0].keys()
|
|
189
202
|
}
|
|
190
203
|
logger.info("Running multilabel classification - Finished.")
|
|
191
204
|
return FullMultilabelClassificationMetrics(
|
|
192
205
|
scores_per_experiment=scores,
|
|
193
|
-
**avg_scores,
|
|
206
|
+
**avg_scores, # type: ignore[typeddict-item]
|
|
194
207
|
)
|
|
195
208
|
|
|
196
|
-
def _calculate_scores(
|
|
209
|
+
def _calculate_scores( # type: ignore[override]
|
|
197
210
|
self,
|
|
198
211
|
y_test: np.ndarray,
|
|
199
212
|
y_pred: np.ndarray,
|
|
200
|
-
x_test_embedding:
|
|
213
|
+
x_test_embedding: Array,
|
|
201
214
|
current_classifier: SklearnModelProtocol,
|
|
202
215
|
) -> MultilabelClassificationMetrics:
|
|
203
216
|
accuracy = current_classifier.score(x_test_embedding, y_test)
|
|
@@ -232,10 +245,9 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
|
|
|
232
245
|
"""
|
|
233
246
|
sample_indices = []
|
|
234
247
|
if idxs is None:
|
|
235
|
-
idxs = np.arange(len(y))
|
|
248
|
+
idxs = list(np.arange(len(y)))
|
|
236
249
|
self.np_rng.shuffle(idxs)
|
|
237
|
-
|
|
238
|
-
label_counter = defaultdict(int)
|
|
250
|
+
label_counter: dict[int, int] = defaultdict(int)
|
|
239
251
|
for i in idxs:
|
|
240
252
|
if any((label_counter[label] < samples_per_label) for label in y[i]):
|
|
241
253
|
sample_indices.append(i)
|