evalscope 0.10.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/__init__.py +4 -1
- evalscope/api/benchmark/__init__.py +11 -0
- evalscope/api/benchmark/adapters/__init__.py +7 -0
- evalscope/api/benchmark/adapters/agent_adapter.py +8 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +754 -0
- evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
- evalscope/api/benchmark/adapters/multi_choice_adapter.py +86 -0
- evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
- evalscope/api/benchmark/adapters/text2image_adapter.py +157 -0
- evalscope/api/benchmark/adapters/vision_language_adapter.py +8 -0
- evalscope/api/benchmark/benchmark.py +404 -0
- evalscope/api/benchmark/meta.py +124 -0
- evalscope/api/dataset/__init__.py +2 -0
- evalscope/api/dataset/dataset.py +370 -0
- evalscope/api/dataset/loader.py +266 -0
- evalscope/api/dataset/utils.py +143 -0
- evalscope/api/evaluator/__init__.py +3 -0
- evalscope/api/evaluator/cache.py +382 -0
- evalscope/api/evaluator/evaluator.py +61 -0
- evalscope/api/evaluator/state.py +280 -0
- evalscope/api/filter/__init__.py +1 -0
- evalscope/api/filter/filter.py +72 -0
- evalscope/api/messages/__init__.py +12 -0
- evalscope/api/messages/chat_message.py +248 -0
- evalscope/api/messages/content.py +102 -0
- evalscope/api/messages/utils.py +35 -0
- evalscope/api/metric/__init__.py +2 -0
- evalscope/api/metric/metric.py +60 -0
- evalscope/api/metric/scorer.py +113 -0
- evalscope/api/mixin/__init__.py +2 -0
- evalscope/api/mixin/llm_judge_mixin.py +170 -0
- evalscope/api/mixin/sandbox_mixin.py +182 -0
- evalscope/api/model/__init__.py +12 -0
- evalscope/api/model/generate_config.py +161 -0
- evalscope/api/model/model.py +386 -0
- evalscope/api/model/model_output.py +285 -0
- evalscope/api/registry.py +182 -0
- evalscope/api/tool/__init__.py +3 -0
- evalscope/api/tool/tool_call.py +101 -0
- evalscope/api/tool/tool_info.py +173 -0
- evalscope/api/tool/utils.py +64 -0
- evalscope/app/__init__.py +28 -0
- evalscope/app/app.py +38 -0
- evalscope/app/arguments.py +11 -0
- evalscope/app/constants.py +22 -0
- evalscope/app/ui/__init__.py +20 -0
- evalscope/app/ui/app_ui.py +53 -0
- evalscope/app/ui/multi_model.py +353 -0
- evalscope/app/ui/sidebar.py +42 -0
- evalscope/app/ui/single_model.py +220 -0
- evalscope/app/ui/visualization.py +36 -0
- evalscope/app/utils/data_utils.py +195 -0
- evalscope/app/utils/env_utils.py +12 -0
- evalscope/app/utils/localization.py +221 -0
- evalscope/app/utils/text_utils.py +119 -0
- evalscope/app/utils/visualization.py +96 -0
- evalscope/arguments.py +32 -9
- evalscope/backend/opencompass/api_meta_template.py +2 -1
- evalscope/backend/opencompass/backend_manager.py +10 -7
- evalscope/backend/rag_eval/__init__.py +1 -1
- evalscope/backend/rag_eval/backend_manager.py +23 -6
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +33 -21
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
- evalscope/backend/rag_eval/cmteb/arguments.py +14 -1
- evalscope/backend/rag_eval/cmteb/task_template.py +19 -3
- evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py +1 -1
- evalscope/backend/rag_eval/ragas/arguments.py +0 -1
- evalscope/backend/rag_eval/ragas/task_template.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +9 -3
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -6
- evalscope/backend/rag_eval/utils/embedding.py +125 -32
- evalscope/backend/rag_eval/utils/llm.py +16 -16
- evalscope/backend/vlm_eval_kit/backend_manager.py +8 -3
- evalscope/benchmarks/__init__.py +17 -5
- evalscope/benchmarks/aa_lcr/__init__.py +0 -0
- evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
- evalscope/benchmarks/ai2d/__init__.py +0 -0
- evalscope/benchmarks/ai2d/ai2d_adapter.py +54 -0
- evalscope/benchmarks/aime/__init__.py +0 -0
- evalscope/benchmarks/aime/aime24_adapter.py +55 -0
- evalscope/benchmarks/aime/aime25_adapter.py +181 -0
- evalscope/benchmarks/aime/grader.py +307 -0
- evalscope/{metrics/math_accuracy.py → benchmarks/aime/math_normalize.py} +61 -72
- evalscope/benchmarks/alpaca_eval/__init__.py +0 -0
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +133 -0
- evalscope/benchmarks/amc/__init__.py +0 -0
- evalscope/benchmarks/amc/amc_adapter.py +51 -0
- evalscope/benchmarks/arc/arc_adapter.py +34 -149
- evalscope/benchmarks/arena_hard/__init__.py +0 -0
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +149 -0
- evalscope/benchmarks/arena_hard/utils.py +186 -0
- evalscope/benchmarks/bbh/bbh_adapter.py +117 -157
- evalscope/benchmarks/bfcl/__init__.py +0 -0
- evalscope/benchmarks/bfcl/v3/__init__.py +0 -0
- evalscope/benchmarks/bfcl/v3/bfcl_v3_adapter.py +370 -0
- evalscope/benchmarks/bfcl/v3/generation.py +222 -0
- evalscope/benchmarks/bfcl/v3/utils.py +23 -0
- evalscope/benchmarks/bfcl/v4/__init__.py +0 -0
- evalscope/benchmarks/bfcl/v4/bfcl_v4_adapter.py +229 -0
- evalscope/benchmarks/bfcl/v4/utils.py +410 -0
- evalscope/benchmarks/biomix_qa/__init__.py +0 -0
- evalscope/benchmarks/biomix_qa/biomix_qa_adapter.py +36 -0
- evalscope/benchmarks/blink/__init__.py +0 -0
- evalscope/benchmarks/blink/blink_adapter.py +61 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +93 -174
- evalscope/benchmarks/chartqa/__init__.py +0 -0
- evalscope/benchmarks/chartqa/chartqa_adapter.py +80 -0
- evalscope/benchmarks/chartqa/utils.py +38 -0
- evalscope/benchmarks/chinese_simple_qa/__init__.py +0 -0
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +170 -0
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -140
- evalscope/benchmarks/coin_flip/__init__.py +0 -0
- evalscope/benchmarks/coin_flip/coin_flip_adapter.py +128 -0
- evalscope/benchmarks/commonsense_qa/__init__.py +0 -0
- evalscope/benchmarks/commonsense_qa/commonsense_qa_adapter.py +32 -0
- evalscope/benchmarks/competition_math/competition_math_adapter.py +64 -112
- evalscope/benchmarks/data_collection/__init__.py +0 -0
- evalscope/benchmarks/data_collection/data_collection_adapter.py +215 -0
- evalscope/benchmarks/docmath/__init__.py +0 -0
- evalscope/benchmarks/docmath/docmath_adapter.py +143 -0
- evalscope/benchmarks/docmath/utils.py +219 -0
- evalscope/benchmarks/docvqa/__init__.py +0 -0
- evalscope/benchmarks/docvqa/docvqa_adapter.py +67 -0
- evalscope/benchmarks/drivelology/__init__.py +0 -0
- evalscope/benchmarks/drivelology/drivelology_binary_adapter.py +170 -0
- evalscope/benchmarks/drivelology/drivelology_multilabel_adapter.py +254 -0
- evalscope/benchmarks/drivelology/drivelology_selection_adapter.py +49 -0
- evalscope/benchmarks/drivelology/drivelology_writing_adapter.py +218 -0
- evalscope/benchmarks/drop/__init__.py +0 -0
- evalscope/benchmarks/drop/drop_adapter.py +155 -0
- evalscope/benchmarks/drop/utils.py +156 -0
- evalscope/benchmarks/frames/__init__.py +0 -0
- evalscope/benchmarks/frames/frames_adapter.py +175 -0
- evalscope/benchmarks/frames/utils.py +37 -0
- evalscope/benchmarks/general_arena/__init__.py +0 -0
- evalscope/benchmarks/general_arena/general_arena_adapter.py +454 -0
- evalscope/benchmarks/general_arena/utils.py +223 -0
- evalscope/benchmarks/general_mcq/__init__.py +0 -0
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +58 -0
- evalscope/benchmarks/general_qa/general_qa_adapter.py +75 -107
- evalscope/benchmarks/gpqa/__init__.py +0 -0
- evalscope/benchmarks/gpqa/gpqa_adapter.py +90 -0
- evalscope/benchmarks/gpqa/prompt.py +88 -0
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +77 -144
- evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
- evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +159 -0
- evalscope/benchmarks/halu_eval/__init__.py +0 -0
- evalscope/benchmarks/halu_eval/halu_eval_adapter.py +128 -0
- evalscope/benchmarks/halu_eval/halu_eval_instructions.py +84 -0
- evalscope/benchmarks/healthbench/__init__.py +0 -0
- evalscope/benchmarks/healthbench/healthbench_adapter.py +282 -0
- evalscope/benchmarks/healthbench/utils.py +102 -0
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +36 -134
- evalscope/benchmarks/hle/__init__.py +0 -0
- evalscope/benchmarks/hle/hle_adapter.py +153 -0
- evalscope/benchmarks/humaneval/humaneval_adapter.py +80 -88
- evalscope/benchmarks/humaneval/utils.py +235 -0
- evalscope/benchmarks/ifeval/ifeval_adapter.py +71 -45
- evalscope/benchmarks/ifeval/instructions.py +112 -68
- evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
- evalscope/benchmarks/ifeval/instructions_util.py +2 -3
- evalscope/benchmarks/ifeval/utils.py +6 -7
- evalscope/benchmarks/image_edit/__init__.py +0 -0
- evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
- evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
- evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
- evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
- evalscope/benchmarks/infovqa/__init__.py +0 -0
- evalscope/benchmarks/infovqa/infovqa_adapter.py +66 -0
- evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -58
- evalscope/benchmarks/live_code_bench/__init__.py +0 -0
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +195 -0
- evalscope/benchmarks/live_code_bench/extract_utils.py +70 -0
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +150 -0
- evalscope/benchmarks/live_code_bench/load_utils.py +63 -0
- evalscope/benchmarks/live_code_bench/pass_k_utils.py +56 -0
- evalscope/benchmarks/live_code_bench/prompts.py +207 -0
- evalscope/benchmarks/live_code_bench/sandbox_evaluate_utils.py +220 -0
- evalscope/benchmarks/live_code_bench/testing_util.py +544 -0
- evalscope/benchmarks/logi_qa/__int__.py +0 -0
- evalscope/benchmarks/logi_qa/logi_qa_adapter.py +41 -0
- evalscope/benchmarks/maritime_bench/__init__.py +0 -0
- evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +56 -0
- evalscope/benchmarks/math_500/__init__.py +0 -0
- evalscope/benchmarks/math_500/math_500_adapter.py +55 -0
- evalscope/benchmarks/math_qa/__init__.py +0 -0
- evalscope/benchmarks/math_qa/math_qa_adapter.py +35 -0
- evalscope/benchmarks/math_verse/__init__.py +0 -0
- evalscope/benchmarks/math_verse/math_verse_adapter.py +105 -0
- evalscope/benchmarks/math_vision/__init__.py +0 -0
- evalscope/benchmarks/math_vision/math_vision_adapter.py +116 -0
- evalscope/benchmarks/math_vista/__init__.py +0 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +114 -0
- evalscope/benchmarks/med_mcqa/__init__.py +0 -0
- evalscope/benchmarks/med_mcqa/med_mcqa_adapter.py +32 -0
- evalscope/benchmarks/minerva_math/__init__.py +0 -0
- evalscope/benchmarks/minerva_math/minerva_math_adapter.py +53 -0
- evalscope/benchmarks/mm_bench/__init__.py +0 -0
- evalscope/benchmarks/mm_bench/mm_bench_adapter.py +99 -0
- evalscope/benchmarks/mm_star/__init__.py +0 -0
- evalscope/benchmarks/mm_star/mm_star_adapter.py +73 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -210
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +87 -103
- evalscope/benchmarks/mmlu_redux/__init__.py +0 -0
- evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +139 -0
- evalscope/benchmarks/mmmu/__init__.py +0 -0
- evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
- evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
- evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +124 -0
- evalscope/benchmarks/mri_mcqa/__init__.py +0 -0
- evalscope/benchmarks/mri_mcqa/mri_mcqa_adapter.py +34 -0
- evalscope/benchmarks/multi_if/__init__.py +0 -0
- evalscope/benchmarks/multi_if/ifeval.py +3354 -0
- evalscope/benchmarks/multi_if/metrics.py +120 -0
- evalscope/benchmarks/multi_if/multi_if_adapter.py +161 -0
- evalscope/benchmarks/music_trivia/__init__.py +0 -0
- evalscope/benchmarks/music_trivia/music_trivia_adapter.py +36 -0
- evalscope/benchmarks/musr/__init__.py +0 -0
- evalscope/benchmarks/musr/musr_adapter.py +43 -0
- evalscope/benchmarks/needle_haystack/__init__.py +0 -0
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +389 -0
- evalscope/benchmarks/needle_haystack/utils.py +79 -0
- evalscope/benchmarks/ner/__init__.py +0 -0
- evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
- evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
- evalscope/benchmarks/ner/copious_adapter.py +85 -0
- evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
- evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
- evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
- evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
- evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
- evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
- evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
- evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
- evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
- evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
- evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
- evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
- evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
- evalscope/benchmarks/ocr_bench/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench/ocr_bench_adapter.py +101 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/IoUscore_metric.py +87 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/TEDS_metric.py +963 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/ocr_bench_v2_adapter.py +161 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/page_ocr_metric.py +50 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/parallel.py +46 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/readme.txt +26 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/rrc_evaluation_funcs_1_1.py +537 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/script.py +481 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_metric.py +179 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/utils.py +433 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/vqa_metric.py +254 -0
- evalscope/benchmarks/olympiad_bench/__init__.py +0 -0
- evalscope/benchmarks/olympiad_bench/olympiad_bench_adapter.py +163 -0
- evalscope/benchmarks/olympiad_bench/utils.py +565 -0
- evalscope/benchmarks/omni_bench/__init__.py +0 -0
- evalscope/benchmarks/omni_bench/omni_bench_adapter.py +86 -0
- evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
- evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
- evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
- evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
- evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
- evalscope/benchmarks/piqa/__init__.py +0 -0
- evalscope/benchmarks/piqa/piqa_adapter.py +32 -0
- evalscope/benchmarks/poly_math/__init__.py +0 -0
- evalscope/benchmarks/poly_math/poly_math_adapter.py +132 -0
- evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
- evalscope/benchmarks/pope/__init__.py +0 -0
- evalscope/benchmarks/pope/pope_adapter.py +112 -0
- evalscope/benchmarks/process_bench/__init__.py +0 -0
- evalscope/benchmarks/process_bench/process_bench_adapter.py +171 -0
- evalscope/benchmarks/pumed_qa/__init__.py +0 -0
- evalscope/benchmarks/pumed_qa/pubmed_qa_adapter.py +175 -0
- evalscope/benchmarks/qasc/__init__.py +0 -0
- evalscope/benchmarks/qasc/qasc_adapter.py +35 -0
- evalscope/benchmarks/race/race_adapter.py +33 -120
- evalscope/benchmarks/real_world_qa/__init__.py +0 -0
- evalscope/benchmarks/real_world_qa/real_world_qa_adapter.py +64 -0
- evalscope/benchmarks/sciq/__init__.py +0 -0
- evalscope/benchmarks/sciq/sciq_adapter.py +36 -0
- evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
- evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
- evalscope/benchmarks/simple_qa/__init__.py +0 -0
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +169 -0
- evalscope/benchmarks/simple_vqa/__init__.py +0 -0
- evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
- evalscope/benchmarks/siqa/__init__.py +0 -0
- evalscope/benchmarks/siqa/siqa_adapter.py +39 -0
- evalscope/benchmarks/super_gpqa/__init__.py +0 -0
- evalscope/benchmarks/super_gpqa/prompt.py +88 -0
- evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +165 -0
- evalscope/benchmarks/super_gpqa/utils.py +86 -0
- evalscope/benchmarks/tau_bench/__init__.py +0 -0
- evalscope/benchmarks/tau_bench/tau2_bench/__init__.py +0 -0
- evalscope/benchmarks/tau_bench/tau2_bench/generation.py +158 -0
- evalscope/benchmarks/tau_bench/tau2_bench/tau2_bench_adapter.py +146 -0
- evalscope/benchmarks/tau_bench/tau_bench/__init__.py +0 -0
- evalscope/benchmarks/tau_bench/tau_bench/generation.py +147 -0
- evalscope/benchmarks/tau_bench/tau_bench/tau_bench_adapter.py +168 -0
- evalscope/benchmarks/text2image/__init__.py +0 -0
- evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
- evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
- evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
- evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
- evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
- evalscope/benchmarks/tool_bench/__init__.py +0 -0
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +102 -0
- evalscope/benchmarks/tool_bench/utils.py +203 -0
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -118
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -270
- evalscope/benchmarks/visu_logic/__init__.py +0 -0
- evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
- evalscope/benchmarks/winogrande/__init__.py +0 -0
- evalscope/benchmarks/winogrande/winogrande_adapter.py +34 -0
- evalscope/benchmarks/wmt/__init__.py +0 -0
- evalscope/benchmarks/wmt/wmt24_adapter.py +294 -0
- evalscope/benchmarks/zerobench/__init__.py +0 -0
- evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
- evalscope/cli/cli.py +2 -0
- evalscope/cli/start_app.py +12 -2
- evalscope/cli/start_eval.py +4 -3
- evalscope/cli/start_perf.py +10 -2
- evalscope/cli/start_server.py +6 -3
- evalscope/collections/__init__.py +27 -3
- evalscope/collections/sampler.py +12 -11
- evalscope/collections/schema.py +13 -12
- evalscope/config.py +218 -147
- evalscope/constants.py +78 -82
- evalscope/evaluator/__init__.py +1 -1
- evalscope/evaluator/evaluator.py +334 -318
- evalscope/filters/__init__.py +2 -0
- evalscope/filters/extraction.py +126 -0
- evalscope/filters/selection.py +57 -0
- evalscope/metrics/__init__.py +59 -3
- evalscope/metrics/bert_score/__init__.py +0 -0
- evalscope/metrics/bert_score/scorer.py +338 -0
- evalscope/metrics/bert_score/utils.py +697 -0
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +20 -15
- evalscope/metrics/llm_judge.py +211 -0
- evalscope/metrics/math_parser.py +545 -0
- evalscope/metrics/metric.py +611 -0
- evalscope/metrics/metrics.py +112 -23
- evalscope/metrics/rouge_metric.py +11 -13
- evalscope/metrics/t2v_metrics/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +134 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +282 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +115 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +87 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +99 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +176 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +82 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +74 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +306 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +84 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +223 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +153 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +24 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +190 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +100 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +313 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +192 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +320 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1111 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +457 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +370 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +765 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +274 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +896 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1876 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +83 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +58 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +187 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +179 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +115 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +348 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +870 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +514 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1291 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +476 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +35 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +393 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +129 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +18 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +23 -13
- evalscope/models/image_edit_model.py +125 -0
- evalscope/models/mockllm.py +65 -0
- evalscope/models/model_apis.py +69 -0
- evalscope/models/modelscope.py +455 -0
- evalscope/models/openai_compatible.py +144 -0
- evalscope/models/text2image_model.py +124 -0
- evalscope/models/utils/openai.py +708 -0
- evalscope/perf/__init__.py +0 -1
- evalscope/perf/arguments.py +103 -69
- evalscope/perf/benchmark.py +114 -163
- evalscope/perf/http_client.py +59 -89
- evalscope/perf/main.py +91 -18
- evalscope/perf/plugin/__init__.py +3 -2
- evalscope/perf/plugin/api/__init__.py +4 -3
- evalscope/perf/plugin/api/base.py +27 -7
- evalscope/perf/plugin/api/custom_api.py +170 -57
- evalscope/perf/plugin/api/dashscope_api.py +4 -10
- evalscope/perf/plugin/api/default_api.py +214 -0
- evalscope/perf/plugin/api/openai_api.py +120 -41
- evalscope/perf/plugin/datasets/__init__.py +10 -6
- evalscope/perf/plugin/datasets/base.py +43 -1
- evalscope/perf/plugin/datasets/custom.py +22 -3
- evalscope/perf/plugin/datasets/flickr8k.py +5 -27
- evalscope/perf/plugin/datasets/kontext_bench.py +28 -0
- evalscope/perf/plugin/datasets/line_by_line.py +7 -3
- evalscope/perf/plugin/datasets/longalpaca.py +7 -3
- evalscope/perf/plugin/datasets/openqa.py +13 -14
- evalscope/perf/plugin/datasets/random_dataset.py +67 -0
- evalscope/perf/plugin/datasets/random_vl_dataset.py +80 -0
- evalscope/perf/plugin/datasets/speed_benchmark.py +11 -0
- evalscope/perf/plugin/registry.py +36 -16
- evalscope/perf/utils/analysis_result.py +24 -23
- evalscope/perf/utils/benchmark_util.py +95 -55
- evalscope/perf/utils/db_util.py +115 -78
- evalscope/perf/utils/local_server.py +12 -47
- evalscope/perf/utils/log_utils.py +63 -0
- evalscope/perf/utils/rich_display.py +192 -0
- evalscope/report/__init__.py +46 -3
- evalscope/report/combinator.py +143 -32
- evalscope/report/generator.py +74 -34
- evalscope/report/report.py +238 -0
- evalscope/run.py +71 -46
- evalscope/summarizer.py +5 -5
- evalscope/third_party/longbench_write/infer.py +1 -1
- evalscope/third_party/thinkbench/__init__.py +3 -0
- evalscope/third_party/thinkbench/eval.py +441 -0
- evalscope/third_party/thinkbench/infer.py +130 -0
- evalscope/third_party/thinkbench/resources/critique_template.txt +17 -0
- evalscope/third_party/thinkbench/resources/reformat_template.txt +31 -0
- evalscope/third_party/thinkbench/tools/__init__.py +0 -0
- evalscope/third_party/thinkbench/tools/llm.py +48 -0
- evalscope/third_party/thinkbench/tools/utils.py +13 -0
- evalscope/third_party/toolbench_static/llm/swift_infer.py +46 -20
- evalscope/third_party/toolbench_static/toolbench_static.py +2 -1
- evalscope/utils/__init__.py +82 -2
- evalscope/utils/argument_utils.py +64 -0
- evalscope/utils/chat_service.py +8 -6
- evalscope/utils/deprecation_utils.py +53 -0
- evalscope/utils/function_utils.py +266 -0
- evalscope/utils/import_utils.py +154 -0
- evalscope/utils/io_utils.py +336 -8
- evalscope/utils/json_schema.py +231 -0
- evalscope/utils/logger.py +121 -31
- evalscope/utils/model_utils.py +57 -1
- evalscope/utils/multi_choices.py +303 -0
- evalscope/utils/ner.py +377 -0
- evalscope/utils/url_utils.py +65 -0
- evalscope/version.py +2 -2
- evalscope-1.2.0.dist-info/METADATA +553 -0
- evalscope-1.2.0.dist-info/RECORD +628 -0
- {evalscope-0.10.0.dist-info → evalscope-1.2.0.dist-info}/WHEEL +1 -1
- {evalscope-0.10.0.dist-info → evalscope-1.2.0.dist-info}/top_level.txt +0 -1
- evalscope/backend/vlm_eval_kit/custom_dataset.py +0 -46
- evalscope/benchmarks/arc/ai2_arc.py +0 -151
- evalscope/benchmarks/benchmark.py +0 -76
- evalscope/benchmarks/ceval/ceval_exam.py +0 -146
- evalscope/benchmarks/ceval/samples.jsonl +0 -1
- evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
- evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
- evalscope/benchmarks/competition_math/competition_math.py +0 -79
- evalscope/benchmarks/data_adapter.py +0 -291
- evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
- evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
- evalscope/benchmarks/humaneval/humaneval.py +0 -79
- evalscope/benchmarks/mmlu/mmlu.py +0 -160
- evalscope/benchmarks/mmlu/samples.jsonl +0 -5
- evalscope/benchmarks/race/race.py +0 -104
- evalscope/benchmarks/race/samples.jsonl +0 -5
- evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
- evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
- evalscope/collections/evaluator.py +0 -198
- evalscope/evaluator/rating_eval.py +0 -157
- evalscope/evaluator/reviewer/__init__.py +0 -1
- evalscope/evaluator/reviewer/auto_reviewer.py +0 -391
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/named_metrics.py +0 -17
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- evalscope/models/base_adapter.py +0 -52
- evalscope/models/chat_adapter.py +0 -138
- evalscope/models/choice_adapter.py +0 -211
- evalscope/models/custom/__init__.py +0 -3
- evalscope/models/custom/custom_model.py +0 -53
- evalscope/models/custom/dummy_model.py +0 -63
- evalscope/models/custom_adapter.py +0 -67
- evalscope/models/local_model.py +0 -74
- evalscope/models/model.py +0 -229
- evalscope/models/server_adapter.py +0 -111
- evalscope/registry/__init__.py +0 -1
- evalscope/registry/config/cfg_arena.yaml +0 -77
- evalscope/registry/config/cfg_arena_zhihu.yaml +0 -63
- evalscope/registry/config/cfg_pairwise_baseline.yaml +0 -83
- evalscope/registry/config/cfg_single.yaml +0 -78
- evalscope/registry/data/prompt_template/lmsys_v2.jsonl +0 -8
- evalscope/registry/data/prompt_template/prompt_templates.jsonl +0 -8
- evalscope/registry/data/qa_browser/battle.jsonl +0 -634
- evalscope/registry/data/qa_browser/category_mapping.yaml +0 -10
- evalscope/registry/data/question.jsonl +0 -80
- evalscope/registry/tasks/arc.yaml +0 -28
- evalscope/registry/tasks/bbh.yaml +0 -26
- evalscope/registry/tasks/bbh_mini.yaml +0 -26
- evalscope/registry/tasks/ceval.yaml +0 -27
- evalscope/registry/tasks/ceval_mini.yaml +0 -26
- evalscope/registry/tasks/cmmlu.yaml +0 -27
- evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +0 -28
- evalscope/registry/tasks/general_qa.yaml +0 -27
- evalscope/registry/tasks/gsm8k.yaml +0 -29
- evalscope/registry/tasks/mmlu.yaml +0 -29
- evalscope/registry/tasks/mmlu_mini.yaml +0 -27
- evalscope/report/app.py +0 -506
- evalscope/report/utils.py +0 -133
- evalscope/run_arena.py +0 -202
- evalscope/utils/arena_utils.py +0 -217
- evalscope/utils/completion_parsers.py +0 -82
- evalscope/utils/utils.py +0 -301
- evalscope-0.10.0.dist-info/METADATA +0 -565
- evalscope-0.10.0.dist-info/RECORD +0 -286
- tests/__init__.py +0 -1
- tests/cli/__init__.py +0 -1
- tests/cli/test_collection.py +0 -57
- tests/cli/test_run.py +0 -165
- tests/perf/__init__.py +0 -1
- tests/perf/test_perf.py +0 -101
- tests/rag/test_clip_benchmark.py +0 -85
- tests/rag/test_mteb.py +0 -138
- tests/rag/test_ragas.py +0 -120
- tests/swift/__init__.py +0 -1
- tests/swift/test_run_swift_eval.py +0 -145
- tests/swift/test_run_swift_vlm_eval.py +0 -127
- tests/swift/test_run_swift_vlm_jugde_eval.py +0 -156
- tests/test_run_all.py +0 -12
- tests/vlm/__init__.py +0 -1
- tests/vlm/test_vlmeval.py +0 -60
- {tests/rag → evalscope/api}/__init__.py +0 -0
- {evalscope-0.10.0.dist-info → evalscope-1.2.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.10.0.dist-info → evalscope-1.2.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,697 @@
|
|
|
1
|
+
# flake8: noqa
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
import torch
|
|
5
|
+
from collections import Counter, defaultdict
|
|
6
|
+
from functools import partial
|
|
7
|
+
from itertools import chain
|
|
8
|
+
from math import log
|
|
9
|
+
from modelscope import AutoModel, AutoTokenizer
|
|
10
|
+
from multiprocessing import Pool
|
|
11
|
+
from packaging import version
|
|
12
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
13
|
+
from tqdm.auto import tqdm
|
|
14
|
+
from transformers import GPT2Tokenizer, RobertaTokenizer
|
|
15
|
+
from transformers import __version__ as trans_version
|
|
16
|
+
|
|
17
|
+
from evalscope import __version__
|
|
18
|
+
|
|
19
|
+
# Empty `__all__`: a star-import of this module exports no names, so every
# helper below must be imported explicitly by name.
__all__ = []
|
|
20
|
+
|
|
21
|
+
# Download URLs of the SciBERT PyTorch archives, keyed by short model name.
# Every archive sits under the same S3 prefix and is named after its key with
# '-' replaced by '_' (e.g. 'scibert-scivocab-cased' -> 'scibert_scivocab_cased.tar').
SCIBERT_URL_DICT = {
    _scibert_name: ('https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/'
                    + _scibert_name.replace('-', '_') + '.tar')
    for _scibert_name in (
        'scibert-scivocab-uncased',  # recommended by the SciBERT authors
        'scibert-scivocab-cased',
        'scibert-basevocab-uncased',
        'scibert-basevocab-cased',
    )
}
|
|
31
|
+
|
|
32
|
+
# Default model for each language code; any language not listed here falls
# back to multilingual BERT through the defaultdict factory.
lang2model = defaultdict(
    lambda: 'bert-base-multilingual-cased',
    {
        'en': 'roberta-large',
        'zh': 'bert-base-chinese',
        'tr': 'dbmdz/bert-base-turkish-cased',
        'en-sci': 'allenai/scibert_scivocab_uncased',
    },
)
|
|
39
|
+
|
|
40
|
+
# Per-model-type layer counts. `get_model` keeps only the first `num_layers`
# transformer layers of the encoder, so each value here is a layer count in
# that sense.
# NOTE(review): presumably the caller looks this table up to choose
# `num_layers` for get_model — confirm at the call site.
# NOTE(review): the trailing numbers look like the correlation reached with
# that layer count (two entries cite "WMT18 seg en-tr"); they are provenance
# notes only and are not read by the code.
# The short "scibert-*" aliases are commented out and therefore have no entry.
model2layers = {
    'bert-base-uncased': 9,  # 0.6925188074454226
    'bert-large-uncased': 18,  # 0.7210358126642836
    'bert-base-cased-finetuned-mrpc': 9,  # 0.6721947475618048
    'bert-base-multilingual-cased': 9,  # 0.6680687802637132
    'bert-base-chinese': 8,
    'roberta-base': 10,  # 0.706288719158983
    'roberta-large': 17,  # 0.7385974720781534
    'roberta-large-mnli': 19,  # 0.7535618640417984
    'roberta-base-openai-detector': 7,  # 0.7048158349432633
    'roberta-large-openai-detector': 15,  # 0.7462770207355116
    'xlnet-base-cased': 5,  # 0.6630103662114238
    'xlnet-large-cased': 7,  # 0.6598800720297179
    'xlm-mlm-en-2048': 6,  # 0.651262570131464
    'xlm-mlm-100-1280': 10,  # 0.6475166424401905
    # "scibert-scivocab-uncased": 8, # 0.6590354319927313
    # "scibert-scivocab-cased": 9, # 0.6536375053937445
    # "scibert-basevocab-uncased": 9, # 0.6748944832703548
    # "scibert-basevocab-cased": 9, # 0.6524624150542374
    'allenai/scibert_scivocab_uncased': 8,  # 0.6590354393124127
    'allenai/scibert_scivocab_cased': 9,  # 0.6536374902465466
    'nfliu/scibert_basevocab_uncased': 9,  # 0.6748945076082333
    'distilroberta-base': 5,  # 0.6797558139322964
    'distilbert-base-uncased': 5,  # 0.6756659152782033
    'distilbert-base-uncased-distilled-squad': 4,  # 0.6718318036382493
    'distilbert-base-multilingual-cased': 5,  # 0.6178131050889238
    'albert-base-v1': 10,  # 0.654237567249745
    'albert-large-v1': 17,  # 0.6755890754323239
    'albert-xlarge-v1': 16,  # 0.7031844211905911
    'albert-xxlarge-v1': 8,  # 0.7508642218461096
    'albert-base-v2': 9,  # 0.6682455591837927
    'albert-large-v2': 14,  # 0.7008537594374035
    'albert-xlarge-v2': 13,  # 0.7317228357869254
    'albert-xxlarge-v2': 8,  # 0.7505160257184014
    'xlm-roberta-base': 9,  # 0.6506799445871697
    'xlm-roberta-large': 17,  # 0.6941551437476826
    'google/electra-small-generator': 9,  # 0.6659421842117754
    'google/electra-small-discriminator': 11,  # 0.6534639151385759
    'google/electra-base-generator': 10,  # 0.6730033453857188
    'google/electra-base-discriminator': 9,  # 0.7032089590812965
    'google/electra-large-generator': 18,  # 0.6813370013104459
    'google/electra-large-discriminator': 14,  # 0.6896675824733477
    'google/bert_uncased_L-2_H-128_A-2': 1,  # 0.5887998733228855
    'google/bert_uncased_L-2_H-256_A-4': 1,  # 0.6114863547661203
    'google/bert_uncased_L-2_H-512_A-8': 1,  # 0.6177345529192847
    'google/bert_uncased_L-2_H-768_A-12': 2,  # 0.6191261237956839
    'google/bert_uncased_L-4_H-128_A-2': 3,  # 0.6076202863798991
    'google/bert_uncased_L-4_H-256_A-4': 3,  # 0.6205239036810148
    'google/bert_uncased_L-4_H-512_A-8': 3,  # 0.6375351621856903
    'google/bert_uncased_L-4_H-768_A-12': 3,  # 0.6561849979644787
    'google/bert_uncased_L-6_H-128_A-2': 5,  # 0.6200458425360283
    'google/bert_uncased_L-6_H-256_A-4': 5,  # 0.6277501629539081
    'google/bert_uncased_L-6_H-512_A-8': 5,  # 0.641952305130849
    'google/bert_uncased_L-6_H-768_A-12': 5,  # 0.6762186226247106
    'google/bert_uncased_L-8_H-128_A-2': 7,  # 0.6186876506711779
    'google/bert_uncased_L-8_H-256_A-4': 7,  # 0.6447993208267708
    'google/bert_uncased_L-8_H-512_A-8': 6,  # 0.6489729408169956
    'google/bert_uncased_L-8_H-768_A-12': 7,  # 0.6705203359541737
    'google/bert_uncased_L-10_H-128_A-2': 8,  # 0.6126762064125278
    'google/bert_uncased_L-10_H-256_A-4': 8,  # 0.6376350032576573
    'google/bert_uncased_L-10_H-512_A-8': 9,  # 0.6579006292799915
    'google/bert_uncased_L-10_H-768_A-12': 8,  # 0.6861146692220176
    'google/bert_uncased_L-12_H-128_A-2': 10,  # 0.6184105693383591
    'google/bert_uncased_L-12_H-256_A-4': 11,  # 0.6374004994430261
    'google/bert_uncased_L-12_H-512_A-8': 10,  # 0.65880012149526
    'google/bert_uncased_L-12_H-768_A-12': 9,  # 0.675911357700092
    'amazon/bort': 0,  # 0.41927911053036643
    'facebook/bart-base': 6,  # 0.7122259132414092
    'facebook/bart-large': 10,  # 0.7448671872459683
    'facebook/bart-large-cnn': 10,  # 0.7393148105835096
    'facebook/bart-large-mnli': 11,  # 0.7531665445691358
    'facebook/bart-large-xsum': 9,  # 0.7496408866539556
    't5-small': 6,  # 0.6813843919496912
    't5-base': 11,  # 0.7096044814981418
    't5-large': 23,  # 0.7244153820191929
    'vinai/bertweet-base': 9,  # 0.6529471006118857
    'microsoft/deberta-base': 9,  # 0.7088459455930344
    'microsoft/deberta-base-mnli': 9,  # 0.7395257063907247
    'microsoft/deberta-large': 16,  # 0.7511806792052013
    'microsoft/deberta-large-mnli': 18,  # 0.7736263649679905
    'microsoft/deberta-xlarge': 18,  # 0.7568670944373346
    'microsoft/deberta-xlarge-mnli': 40,  # 0.7780600929333213
    'YituTech/conv-bert-base': 10,  # 0.7058253551080789
    'YituTech/conv-bert-small': 10,  # 0.6544473011107349
    'YituTech/conv-bert-medium-small': 9,  # 0.6590097075123257
    'microsoft/mpnet-base': 8,  # 0.724976539498804
    'squeezebert/squeezebert-uncased': 9,  # 0.6543868703018726
    'squeezebert/squeezebert-mnli': 9,  # 0.6654799051284791
    'squeezebert/squeezebert-mnli-headless': 9,  # 0.6654799051284791
    'tuner007/pegasus_paraphrase': 15,  # 0.7188349436772694
    'google/pegasus-large': 8,  # 0.63960462272448
    'google/pegasus-xsum': 11,  # 0.6836878575233349
    'sshleifer/tiny-mbart': 2,  # 0.028246072231946733
    'facebook/mbart-large-cc25': 12,  # 0.6582922975802958
    'facebook/mbart-large-50': 12,  # 0.6464972230103133
    'facebook/mbart-large-en-ro': 12,  # 0.6791285137459857
    'facebook/mbart-large-50-many-to-many-mmt': 12,  # 0.6904136529270892
    'facebook/mbart-large-50-one-to-many-mmt': 12,  # 0.6847906439540236
    'allenai/led-base-16384': 6,  # 0.7122259170564179
    'facebook/blenderbot_small-90M': 7,  # 0.6489176335400088
    'facebook/blenderbot-400M-distill': 2,  # 0.5874774070540008
    'microsoft/prophetnet-large-uncased': 4,  # 0.586496184234925
    'microsoft/prophetnet-large-uncased-cnndm': 7,  # 0.6478379437729287
    'SpanBERT/spanbert-base-cased': 8,  # 0.6824006863686848
    'SpanBERT/spanbert-large-cased': 17,  # 0.705352690855603
    'microsoft/xprophetnet-large-wiki100-cased': 7,  # 0.5852499775879524
    'ProsusAI/finbert': 10,  # 0.6923213940752796
    'Vamsi/T5_Paraphrase_Paws': 12,  # 0.6941611753807352
    'ramsrigouthamg/t5_paraphraser': 11,  # 0.7200917597031539
    'microsoft/deberta-v2-xlarge': 10,  # 0.7393675784473045
    'microsoft/deberta-v2-xlarge-mnli': 17,  # 0.7620620803716714
    'microsoft/deberta-v2-xxlarge': 21,  # 0.7520547670281869
    'microsoft/deberta-v2-xxlarge-mnli': 22,  # 0.7742603457742682
    'allenai/longformer-base-4096': 7,  # 0.7089559593129316
    'allenai/longformer-large-4096': 14,  # 0.732408493548181
    'allenai/longformer-large-4096-finetuned-triviaqa': 14,  # 0.7365882744744722
    'zhiheng-huang/bert-base-uncased-embedding-relative-key': 4,  # 0.5995636595368777
    'zhiheng-huang/bert-base-uncased-embedding-relative-key-query': 7,  # 0.6303599452145718
    'zhiheng-huang/bert-large-uncased-whole-word-masking-embedding-relative-key-query': 19,  # 0.6896878492850327
    'google/mt5-small': 8,  # 0.6401166527273479
    'google/mt5-base': 11,  # 0.5663956536597241
    'google/mt5-large': 19,  # 0.6430931371732798
    'google/mt5-xl': 24,  # 0.6707200963021145
    'google/bigbird-roberta-base': 10,  # 0.6695606423502717
    'google/bigbird-roberta-large': 14,  # 0.6755874042374509
    'google/bigbird-base-trivia-itc': 8,  # 0.6930725491629892
    'princeton-nlp/unsup-simcse-bert-base-uncased': 10,  # 0.6703066531921142
    'princeton-nlp/unsup-simcse-bert-large-uncased': 18,  # 0.6958302800755326
    'princeton-nlp/unsup-simcse-roberta-base': 8,  # 0.6436615893535319
    'princeton-nlp/unsup-simcse-roberta-large': 13,  # 0.6812864385585965
    'princeton-nlp/sup-simcse-bert-base-uncased': 10,  # 0.7068074935240984
    'princeton-nlp/sup-simcse-bert-large-uncased': 18,  # 0.7111049471332378
    'princeton-nlp/sup-simcse-roberta-base': 10,  # 0.7253123806661946
    'princeton-nlp/sup-simcse-roberta-large': 16,  # 0.7497820277237173
    'dbmdz/bert-base-turkish-cased': 10,  # WMT18 seg en-tr 0.5522827687776142
    'dbmdz/distilbert-base-turkish-cased': 4,  # WMT18 seg en-tr 0.4742268041237113
    'google/byt5-small': 1,  # 0.5100025975052146
    'google/byt5-base': 17,  # 0.5810347173565313
    'google/byt5-large': 30,  # 0.6151895697554877
    'microsoft/deberta-v3-xsmall': 10,  # 0.6941803815412021
    'microsoft/deberta-v3-small': 4,  # 0.6651551203179679
    'microsoft/deberta-v3-base': 9,  # 0.7261586651018335
    'microsoft/mdeberta-v3-base': 10,  # 0.6778713684091584
    'microsoft/deberta-v3-large': 12,  # 0.6927693082293821
    'khalidalt/DeBERTa-v3-large-mnli': 18,  # 0.7428756686018376
}
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def sent_encode(tokenizer, sent):
    """
    Encode one sentence into token ids, special tokens included.

    A sentence that is empty after stripping yields only the tokenizer's special
    tokens. GPT-2 / RoBERTa tokenizers additionally get `add_prefix_space=True`.
    With transformers >= 4.0.0, a `model_max_length` above 10,000,000 is
    overwritten with 512 on the tokenizer object itself before encoding.

    Args:
        - :param: `tokenizer` : a HuggingFace-style tokenizer.
        - :param: `sent` (str) : the sentence to encode.

    Returns:
        list of int token ids.

    Raises:
        NotImplementedError: if the installed transformers is older than 2.0.0.
    """
    sent = sent.strip()
    if sent == '':
        return tokenizer.build_inputs_with_special_tokens([])

    encode_kwargs = {'add_special_tokens': True}
    if isinstance(tokenizer, (GPT2Tokenizer, RobertaTokenizer)):
        # for RoBERTa and GPT-2
        encode_kwargs['add_prefix_space'] = True

    installed = version.parse(trans_version)
    if installed >= version.parse('4.0.0'):
        # treat an absurdly large model_max_length as "unset" and cap it
        if tokenizer.model_max_length > 10000000:
            tokenizer.model_max_length = 512
        encode_kwargs['max_length'] = tokenizer.model_max_length
        encode_kwargs['truncation'] = True
    elif installed >= version.parse('3.0.0'):
        encode_kwargs['max_length'] = tokenizer.max_len
        encode_kwargs['truncation'] = True
    elif installed >= version.parse('2.0.0'):
        # transformers 2.x has no `truncation` keyword
        encode_kwargs['max_length'] = tokenizer.max_len
    else:
        raise NotImplementedError(f'transformers version {trans_version} is not supported')

    return tokenizer.encode(sent, **encode_kwargs)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def get_model(model_type, num_layers, all_layers=None):
    """
    Load an encoder for `model_type` and prepare it for embedding extraction.

    Model types starting with 'scibert' are resolved through `cache_scibert`,
    types containing 't5' are loaded with `T5EncoderModel`, and everything else
    goes through `AutoModel`. The model is switched to eval mode. When it has
    both `encoder` and `decoder` attributes, only the encoder is kept.

    Args:
        - :param: `model_type` (str): model name or path to load.
        - :param: `num_layers` (int): number of leading layers to keep when
                  `all_layers` is falsy.
        - :param: `all_layers`: if truthy, keep every layer and turn on
                  `output_hidden_states` instead of truncating.

    Returns:
        The (possibly truncated) model / encoder module.

    Raises:
        ValueError: if truncation is requested for an unrecognized layout.
        AssertionError: if `num_layers` is outside [0, available layers]
            (unless Python runs with -O).
    """
    if model_type.startswith('scibert'):
        model = AutoModel.from_pretrained(cache_scibert(model_type))
    elif 't5' in model_type:
        from transformers import T5EncoderModel

        model = T5EncoderModel.from_pretrained(model_type)
    else:
        model = AutoModel.from_pretrained(model_type)
    model.eval()

    # encoder-decoder models: only the encoder is used for embeddings
    if hasattr(model, 'decoder') and hasattr(model, 'encoder'):
        model = model.encoder

    def _check_bound(upper):
        assert 0 <= num_layers <= upper, (
            f'Invalid num_layers: num_layers should be between 0 and {upper} for {model_type}'
        )

    def _keep_leading(layers):
        # validate against the current depth, then keep the first `num_layers` modules
        _check_bound(len(layers))
        return torch.nn.ModuleList(list(layers[:num_layers]))

    if all_layers:
        if hasattr(model, 'output_hidden_states'):
            model.output_hidden_states = True
        elif hasattr(model, 'encoder'):
            model.encoder.output_hidden_states = True
        elif hasattr(model, 'transformer'):
            model.transformer.output_hidden_states = True
        return model

    # drop unused layers
    if hasattr(model, 'n_layers'):  # xlm
        _check_bound(model.n_layers)
        model.n_layers = num_layers
    elif hasattr(model, 'layer'):  # xlnet
        model.layer = _keep_leading(model.layer)
    elif hasattr(model, 'encoder'):
        encoder = model.encoder
        if hasattr(encoder, 'albert_layer_groups'):  # albert
            _check_bound(encoder.config.num_hidden_layers)
            encoder.config.num_hidden_layers = num_layers
        elif hasattr(encoder, 'block'):  # t5
            encoder.block = _keep_leading(encoder.block)
        else:  # bert, roberta
            encoder.layer = _keep_leading(encoder.layer)
    elif hasattr(model, 'transformer'):  # layers stored as `transformer.layer`
        model.transformer.layer = _keep_leading(model.transformer.layer)
    elif hasattr(model, 'layers'):  # bart
        model.layers = _keep_leading(model.layers)
    else:
        raise ValueError('Not supported')
    return model
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def get_tokenizer(model_type, use_fast=False):
    """
    Load the tokenizer for `model_type`.

    'scibert*' model types are first resolved with `cache_scibert`. On
    transformers < 4.0.0 the fast tokenizer cannot be requested.

    Args:
        - :param: `model_type` (str): model name or path.
        - :param: `use_fast` (bool): request a fast tokenizer (transformers >= 4.0.0 only).

    Returns:
        The tokenizer produced by `AutoTokenizer.from_pretrained`.

    Raises:
        AssertionError: if `use_fast` is set on transformers < 4.0.0.
    """
    source = cache_scibert(model_type) if model_type.startswith('scibert') else model_type

    if version.parse(trans_version) < version.parse('4.0.0'):
        assert not use_fast, 'Fast tokenizer is not available for version < 4.0.0'
        return AutoTokenizer.from_pretrained(source)
    return AutoTokenizer.from_pretrained(source, use_fast=use_fast)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def padding(arr, pad_token, dtype=torch.long):
    """
    Right-pad a list of variable-length sequences into a single 2-D tensor.

    Args:
        - :param: `arr` (list of sequences): the sequences to pad.
        - :param: `pad_token`: value written into padded positions.
        - :param: `dtype` (torch.dtype): dtype of the padded tensor.

    Returns:
        A tuple `(padded, lens, mask)`:
            - `padded`: tensor of shape (len(arr), max_len) with dtype `dtype`;
            - `lens`: LongTensor with the length of each sequence;
            - `mask`: LongTensor shaped like `padded`, 1 at real positions and 0 at padding.
        An empty `arr` yields tensors with zero rows/columns instead of raising.
    """
    lens = torch.LongTensor([len(a) for a in arr])
    # `lens.max()` raises on an empty tensor, so handle the empty batch explicitly
    max_len = int(lens.max().item()) if len(arr) > 0 else 0
    # torch.full honours `dtype`; `torch.ones(...) * pad_token` silently promoted
    # the result to float whenever `pad_token` was a float
    padded = torch.full((len(arr), max_len), pad_token, dtype=dtype)
    mask = torch.zeros(len(arr), max_len, dtype=torch.long)
    for i, a in enumerate(arr):
        length = len(a)
        padded[i, :length] = torch.tensor(a, dtype=dtype)
        mask[i, :length] = 1
    return padded, lens, mask
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def bert_encode(model, x, attention_mask, all_layers=False):
    """
    Run `model` on token ids `x` in eval mode without gradient tracking.

    Args:
        - :param: `model`: the encoder to call.
        - :param: `x` (torch.Tensor): padded token ids.
        - :param: `attention_mask` (torch.Tensor): mask passed straight to the model.
        - :param: `all_layers` (bool): also request hidden states of every layer.

    Returns:
        `outputs[0]` of the model call when `all_layers` is False; otherwise the
        sequence `outputs[-1]` stacked along a new dimension 2.
        NOTE(review): this assumes `outputs[-1]` is the per-layer hidden-states
        tuple (i.e. the model returns no attentions) — confirm for new models.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(x, attention_mask=attention_mask, output_hidden_states=all_layers)
    if not all_layers:
        return outputs[0]
    return torch.stack(outputs[-1], dim=2)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def process(a, tokenizer=None):
    """
    Return the set of distinct tokens of `a`.

    When `tokenizer` is given, `a` is first encoded with `sent_encode`;
    otherwise `a` is assumed to already be an iterable of tokens.
    """
    tokens = a if tokenizer is None else sent_encode(tokenizer, a)
    return set(tokens)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def get_idf_dict(arr, tokenizer, nthreads=4):
    """
    Returns mapping from word piece index to its inverse document frequency.

    Each sentence is tokenized through `process` and every distinct token is
    counted once per sentence. With `N = len(arr)`, a token found in `c`
    sentences gets idf `log((N + 1) / (c + 1))`. Tokens never seen default to
    `log(N + 1)`.

    Args:
        - :param: `arr` (list of str) : sentences to process.
        - :param: `tokenizer` : a BERT tokenizer corresponds to `model`.
        - :param: `nthreads` (int) : size of the `multiprocessing.Pool` used for
                  tokenization; a value <= 0 tokenizes in the current process.

    Returns:
        collections.defaultdict mapping token id -> idf value.
    """
    num_docs = len(arr)
    tokens_of = partial(process, tokenizer=tokenizer)

    if nthreads > 0:
        with Pool(nthreads) as pool:
            token_sets = pool.map(tokens_of, arr)
    else:
        token_sets = map(tokens_of, arr)
    doc_freq = Counter(chain.from_iterable(token_sets))

    idf_dict = defaultdict(lambda: log((num_docs + 1) / (1)))
    for token, count in doc_freq.items():
        idf_dict[token] = log((num_docs + 1) / (count + 1))
    return idf_dict
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def collate_idf(arr, tokenizer, idf_dict, device='cuda:0'):
    """
    Encode a list of sentences, pad them to the same length, and look up the
    idf score of every word piece.

    Args:
        - :param: `arr` (list of str): sentences to process.
        - :param: `tokenizer`: tokenizer used (through `sent_encode`) to turn each
                  sentence into token ids; its `pad_token_id` fills padded positions.
        - :param: `idf_dict` (dict): mapping a word piece index to its
                  inverse document frequency.
        - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda'.

    Returns:
        A tuple `(padded, padded_idf, lens, mask)`:
            - `padded`: padded token ids, moved to `device`;
            - `padded_idf`: idf weights padded with 0 (float); NOT moved to `device`;
            - `lens`: number of tokens of each sentence, moved to `device`;
            - `mask`: 1 for real tokens and 0 for padding, moved to `device`.
    """
    token_ids = [sent_encode(tokenizer, sentence) for sentence in arr]
    idf_weights = [[idf_dict[token] for token in ids] for ids in token_ids]

    padded, lens, mask = padding(token_ids, tokenizer.pad_token_id, dtype=torch.long)
    padded_idf, _, _ = padding(idf_weights, 0, dtype=torch.float)

    # padded_idf is returned as-is (on the CPU); only ids, lengths and mask are moved
    return padded.to(device=device), padded_idf, lens.to(device=device), mask.to(device=device)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def get_bert_embedding(
    all_sens,
    model,
    tokenizer,
    idf_dict,
    batch_size=-1,
    device='cuda:0',
    all_layers=False,
):
    """
    Compute BERT embedding in batches.

    Args:
        - :param: `all_sens` (list of str) : sentences to encode.
        - :param: `model` : the encoder to run (e.g. as returned by `get_model`).
        - :param: `tokenizer` : a BERT tokenizer corresponds to `model`.
        - :param: `idf_dict` (dict) : mapping a word piece index to its
                  inverse document frequency
        - :param: `batch_size` (int): sentences per forward pass; any value
                  <= 0 (default -1) encodes all sentences in one batch.
        - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda'
        - :param: `all_layers` (bool): return the hidden states of every layer.

    Returns:
        A tuple `(total_embedding, mask, padded_idf)`: the embeddings of all
        sentences concatenated along dim 0, the padding mask, and the padded
        idf weights (as produced by `collate_idf`).
    """
    padded_sens, padded_idf, lens, mask = collate_idf(all_sens, tokenizer, idf_dict, device=device)

    # Previously only -1 meant "everything at once": 0 crashed range() and other
    # negatives produced an empty torch.cat. Clamp to >= 1 so range() always has a valid step.
    if batch_size <= 0:
        batch_size = max(len(all_sens), 1)

    embeddings = []
    with torch.no_grad():
        for start in range(0, len(all_sens), batch_size):
            stop = start + batch_size
            embeddings.append(
                bert_encode(
                    model,
                    padded_sens[start:stop],
                    attention_mask=mask[start:stop],
                    all_layers=all_layers,
                ))

    total_embedding = torch.cat(embeddings, dim=0)
    return total_embedding, mask, padded_idf
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def greedy_cos_idf(
    ref_embedding,
    ref_masks,
    ref_idf,
    hyp_embedding,
    hyp_masks,
    hyp_idf,
    all_layers=False,
):
    """
    Compute greedy matching based on cosine similarity.

    Args:
        - :param: `ref_embedding` (torch.Tensor): embeddings of reference sentences,
                  BxKxd, or BxKxLxd when `all_layers` is True
                  (B: batch size, K: longest length, L: layers, d: hidden size).
        - :param: `ref_masks` (torch.Tensor): BxK, non-zero for real tokens of the references.
        - :param: `ref_idf` (torch.Tensor): BxK, idf score of each word piece in the references.
        - :param: `hyp_embedding` (torch.Tensor): embeddings of candidate sentences,
                  same layout as `ref_embedding`.
        - :param: `hyp_masks` (torch.Tensor): BxK, non-zero for real tokens of the candidates.
        - :param: `hyp_idf` (torch.Tensor): BxK, idf score of each word piece in the candidates.
        - :param: `all_layers` (bool): whether the embeddings carry a layer axis.

    Returns:
        `(P, R, F)`, each of shape (B,), or (L, B) when `all_layers` is True.
        Sentences whose mask sums to 2 (only special tokens) get P = R = F = 0.

    Note:
        `ref_embedding`, `hyp_embedding`, `ref_idf` and `hyp_idf` are normalized
        in place.
    """
    ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
    hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))

    if all_layers:
        B, _, L, D = hyp_embedding.size()
        # (B, K, L, D) -> (L, B, K, D) -> (L*B, K, D): each layer becomes its own batch entry
        hyp_embedding = (
            hyp_embedding.transpose(1, 2).transpose(0, 1).contiguous().view(L * B, hyp_embedding.size(1), D)
        )
        ref_embedding = (
            ref_embedding.transpose(1, 2).transpose(0, 1).contiguous().view(L * B, ref_embedding.size(1), D)
        )
    batch_size = ref_embedding.size(0)
    sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
    masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float())
    if all_layers:
        masks = masks.unsqueeze(0).expand(L, -1, -1, -1).contiguous().view_as(sim)
    else:
        masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim)

    masks = masks.float().to(sim.device)
    sim = sim * masks

    word_precision = sim.max(dim=2)[0]
    word_recall = sim.max(dim=1)[0]

    hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True))
    ref_idf.div_(ref_idf.sum(dim=1, keepdim=True))
    precision_scale = hyp_idf.to(word_precision.device)
    recall_scale = ref_idf.to(word_recall.device)
    if all_layers:
        precision_scale = (precision_scale.unsqueeze(0).expand(L, B, -1).contiguous().view_as(word_precision))
        recall_scale = (recall_scale.unsqueeze(0).expand(L, B, -1).contiguous().view_as(word_recall))
    P = (word_precision * precision_scale).sum(dim=1)
    R = (word_recall * recall_scale).sum(dim=1)

    # a mask summing to 2 means the sentence encoded to special tokens only
    hyp_zero_mask = hyp_masks.sum(dim=1).eq(2)
    ref_zero_mask = ref_masks.sum(dim=1).eq(2)

    if all_layers:
        P = P.view(L, B)
        R = R.view(L, B)

    if torch.any(hyp_zero_mask):
        print(
            'Warning: Empty candidate sentence detected; setting raw BERTscores to 0.',
            file=sys.stderr,
        )
        P = P.masked_fill(hyp_zero_mask, 0.0)
        R = R.masked_fill(hyp_zero_mask, 0.0)

    if torch.any(ref_zero_mask):
        print(
            'Warning: Empty reference sentence detected; setting raw BERTScores to 0.',
            file=sys.stderr,
        )
        P = P.masked_fill(ref_zero_mask, 0.0)
        R = R.masked_fill(ref_zero_mask, 0.0)

    # F is derived only after the empty-sentence masking so it agrees with P and R
    # (previously F kept its pre-mask value while P and R were zeroed).
    # P + R == 0 gives 0/0 = NaN, which is mapped to 0.
    F = 2 * P * R / (P + R)
    F = F.masked_fill(torch.isnan(F), 0.0)

    return P, R, F
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def bert_cos_score_idf(
    model,
    refs,
    hyps,
    tokenizer,
    idf_dict,
    verbose=False,
    batch_size=64,
    device='cuda:0',
    all_layers=False,
):
    """
    Compute BERTScore (precision, recall, F1) for aligned reference/candidate pairs.

    Every unique sentence in ``refs + hyps`` is embedded exactly once. The
    per-sentence embeddings and idf weights are cached on CPU, then re-padded
    batch by batch for greedy cosine matching.

    Args:
        - :param: `model` : a BERT model in `pytorch_pretrained_bert`
        - :param: `refs` (list of str): reference sentences
        - :param: `hyps` (list of str): candidate sentences, paired with
                  `refs` by position
        - :param: `tokenizer` : a BERT tokenizer corresponding to `model`
        - :param: `idf_dict` : a dictionary mapping a word piece index to its
                  inverse document frequency
        - :param: `verbose` (bool): print progress messages and wrap the
                  batch loops in tqdm
        - :param: `batch_size` (int): number of sentences per embedding batch
                  and number of pairs per matching batch
        - :param: `device` (str): device passed to ``get_bert_embedding``,
                  e.g. 'cpu' or 'cuda'. NOTE: the greedy-matching stage runs
                  on the device of the model's parameters, not on this value.
        - :param: `all_layers` (bool): forwarded to ``get_bert_embedding`` and
                  ``greedy_cos_idf`` (presumably scores every layer rather
                  than a single one — TODO confirm)

    Returns:
        A CPU tensor whose last dimension stacks (P, R, F1). Per-batch results
        are concatenated along dim 0, or along dim 1 when `all_layers` is True
        (presumably shapes (N, 3) and (layers, N, 3) — TODO confirm against
        ``greedy_cos_idf``).

    NOTE(review): pairs are formed from matching slices of `refs` and `hyps`,
    so `hyps` is presumably expected to have the same length as `refs`. An
    empty `refs` leaves no batches, and the final ``torch.cat`` then raises.
    """

    def dedup_and_sort(sentences):
        # Unique sentences, longest first (by space-separated word count).
        # Presumably this groups similar lengths together to reduce padding.
        return sorted(set(sentences), key=lambda s: len(s.split(' ')), reverse=True)

    def length_to_mask(lens):
        # Boolean (batch, max_len) mask: True for real tokens, False for padding.
        lens = torch.tensor(lens, dtype=torch.long)
        max_len = int(lens.max())
        base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len)
        return base < lens.unsqueeze(1)

    def pad_batch_stats(sen_batch, stats_dict, device):
        # Gather the cached (embedding, idf) pairs for a batch, move them to
        # `device` and pad them to a common length.
        emb, idf = zip(*(stats_dict[s] for s in sen_batch))
        emb = [e.to(device) for e in emb]
        idf = [i.to(device) for i in idf]
        lens = [e.size(0) for e in emb]
        # Embedding padding value 2.0 is kept from the original implementation.
        emb_pad = pad_sequence(emb, batch_first=True, padding_value=2.0)
        idf_pad = pad_sequence(idf, batch_first=True)
        pad_mask = length_to_mask(lens).to(device)
        return emb_pad, pad_mask, idf_pad

    # Stage 1: embed each unique sentence once and cache the unpadded results on CPU.
    sentences = dedup_and_sort(refs + hyps)
    iter_range = range(0, len(sentences), batch_size)
    if verbose:
        print('computing bert embedding.')
        iter_range = tqdm(iter_range)
    stats_dict = dict()
    for batch_start in iter_range:
        sen_batch = sentences[batch_start:batch_start + batch_size]
        embs, masks, padded_idf = get_bert_embedding(
            sen_batch, model, tokenizer, idf_dict, device=device, all_layers=all_layers
        )
        embs = embs.cpu()
        masks = masks.cpu()
        padded_idf = padded_idf.cpu()
        for i, sen in enumerate(sen_batch):
            # Strip padding: keep only the first `sequence_len` token positions.
            sequence_len = masks[i].sum().item()
            stats_dict[sen] = (embs[i, :sequence_len], padded_idf[i, :sequence_len])

    # Stage 2: greedy matching, run on whatever device holds the model weights.
    device = next(model.parameters()).device
    iter_range = range(0, len(refs), batch_size)
    if verbose:
        print('computing greedy matching.')
        iter_range = tqdm(iter_range)

    preds = []
    with torch.no_grad():
        for batch_start in iter_range:
            batch_refs = refs[batch_start:batch_start + batch_size]
            batch_hyps = hyps[batch_start:batch_start + batch_size]
            ref_stats = pad_batch_stats(batch_refs, stats_dict, device)
            hyp_stats = pad_batch_stats(batch_hyps, stats_dict, device)

            P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats, all_layers)
            preds.append(torch.stack((P, R, F1), dim=-1).cpu())
    return torch.cat(preds, dim=1 if all_layers else 0)
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def get_hash(
    model,
    num_layers,
    idf,
    rescale_with_baseline,
    use_custom_baseline,
    use_fast_tokenizer,
):
    """
    Build a string tag identifying a BERTScore scoring configuration.

    The tag combines the model name, the layer count, and an idf marker
    (``_idf`` / ``_no-idf``). It also includes the module-level
    ``__version__`` and ``trans_version`` values (the latter presumably the
    transformers version, per the ``hug_trans=`` label). Optional suffixes
    mark baseline rescaling (custom or default) and use of the fast tokenizer.
    """
    idf_tag = '_idf' if idf else '_no-idf'
    pieces = [f'{model}_L{num_layers}{idf_tag}_version={__version__}(hug_trans={trans_version})']
    if rescale_with_baseline:
        pieces.append('-custom-rescaled' if use_custom_baseline else '-rescaled')
    if use_fast_tokenizer:
        pieces.append('_fast-tokenizer')
    return ''.join(pieces)
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def cache_scibert(model_type, cache_folder='~/.cache/torch/transformers'):
    """
    Ensure a SciBERT checkpoint exists locally and return its path.

    Model types that do not start with ``'scibert'`` are returned unchanged.
    For SciBERT types, the checkpoint directory is
    ``<cache_folder>/<model_type with '-' replaced by '_'>``. If it does not
    exist, the archive at ``SCIBERT_URL_DICT[model_type]`` is fetched and
    unpacked with a shell pipeline (``wget``/``tar``/``mv``). Tokenizer
    metadata files missing from the distribution are then written if absent.

    Args:
        model_type (str): model identifier. Only values starting with
            ``'scibert'`` are handled here, and they must be keys of
            ``SCIBERT_URL_DICT`` when a download is needed.
        cache_folder (str): root cache directory; ``~`` is expanded and the
            path is made absolute.

    Returns:
        str: ``model_type`` unchanged for non-SciBERT types, otherwise the
        absolute path of the local checkpoint directory.

    Raises:
        FileNotFoundError: if the checkpoint directory still does not exist
            after the download command has run.
    """
    if not model_type.startswith('scibert'):
        return model_type

    import shlex  # only needed on the SciBERT path

    underscore_model_type = model_type.replace('-', '_')
    cache_folder = os.path.abspath(os.path.expanduser(cache_folder))
    filename = os.path.join(cache_folder, underscore_model_type)

    # download SciBERT models
    if not os.path.exists(filename):
        # Quote every interpolated value so that paths containing spaces or
        # shell metacharacters cannot break (or inject into) the command.
        q_folder = shlex.quote(cache_folder)
        q_name = shlex.quote(underscore_model_type)
        q_tar = shlex.quote(underscore_model_type + '.tar')
        q_url = shlex.quote(SCIBERT_URL_DICT[model_type])
        cmd = f'mkdir -p {q_folder}; cd {q_folder};'
        cmd += f'wget {q_url}; tar -xvf {q_tar};'
        cmd += f'rm -f {q_tar} ; cd {q_name}; tar -zxvf weights.tar.gz; mv weights/* .;'
        cmd += 'rm -f weights.tar.gz; rmdir weights; mv bert_config.json config.json;'
        print(cmd)
        print(f'downloading {model_type} model')
        # Commands are ';'-chained, so the status reflects only the last one.
        # The existence check below is the real success test.
        status = os.system(cmd)
        if not os.path.exists(filename):
            raise FileNotFoundError(
                f'Failed to download {model_type} into {filename} (download command exit status {status})'
            )

    # fix the missing files in scibert
    json_file = os.path.join(filename, 'special_tokens_map.json')
    if not os.path.exists(json_file):
        with open(json_file, 'w', encoding='utf-8') as f:
            print(
                '{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}',
                file=f,
            )

    json_file = os.path.join(filename, 'added_tokens.json')
    if not os.path.exists(json_file):
        with open(json_file, 'w', encoding='utf-8') as f:
            print('{}', file=f)

    # Uncased variants need an explicit lower-casing tokenizer config.
    if 'uncased' in model_type:
        json_file = os.path.join(filename, 'tokenizer_config.json')
        if not os.path.exists(json_file):
            with open(json_file, 'w', encoding='utf-8') as f:
                print('{"do_lower_case": true, "max_len": 512, "init_inputs": []}', file=f)

    return filename
|