evalscope 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- evalscope/api/benchmark/__init__.py +9 -1
- evalscope/api/benchmark/adapters/__init__.py +4 -0
- evalscope/api/benchmark/adapters/agent_adapter.py +8 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +75 -4
- evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
- evalscope/api/benchmark/adapters/multi_choice_adapter.py +5 -2
- evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
- evalscope/api/benchmark/adapters/text2image_adapter.py +12 -10
- evalscope/api/benchmark/adapters/vision_language_adapter.py +8 -0
- evalscope/api/benchmark/benchmark.py +85 -2
- evalscope/api/benchmark/meta.py +10 -1
- evalscope/api/dataset/dataset.py +27 -6
- evalscope/api/dataset/loader.py +8 -3
- evalscope/api/evaluator/cache.py +31 -4
- evalscope/api/evaluator/evaluator.py +5 -0
- evalscope/api/evaluator/state.py +17 -1
- evalscope/api/messages/__init__.py +1 -0
- evalscope/api/messages/chat_message.py +52 -2
- evalscope/api/metric/__init__.py +1 -1
- evalscope/api/metric/metric.py +6 -1
- evalscope/api/metric/scorer.py +15 -7
- evalscope/api/mixin/__init__.py +1 -1
- evalscope/api/mixin/llm_judge_mixin.py +2 -0
- evalscope/api/mixin/sandbox_mixin.py +182 -0
- evalscope/api/model/generate_config.py +10 -6
- evalscope/api/model/model.py +5 -2
- evalscope/api/tool/tool_info.py +1 -1
- evalscope/app/app.py +3 -0
- evalscope/app/ui/multi_model.py +6 -1
- evalscope/app/ui/single_model.py +11 -5
- evalscope/app/utils/data_utils.py +8 -7
- evalscope/app/utils/env_utils.py +12 -0
- evalscope/app/utils/text_utils.py +14 -12
- evalscope/app/utils/visualization.py +2 -2
- evalscope/arguments.py +8 -4
- evalscope/backend/opencompass/backend_manager.py +0 -2
- evalscope/backend/rag_eval/utils/embedding.py +9 -1
- evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
- evalscope/benchmarks/ai2d/ai2d_adapter.py +54 -0
- evalscope/benchmarks/aime/aime24_adapter.py +5 -0
- evalscope/benchmarks/aime/aime25_adapter.py +136 -1
- evalscope/benchmarks/aime/grader.py +307 -0
- evalscope/benchmarks/aime/math_normalize.py +189 -0
- evalscope/benchmarks/amc/amc_adapter.py +51 -0
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -0
- evalscope/benchmarks/bbh/bbh_adapter.py +43 -17
- evalscope/benchmarks/bfcl/{bfcl_adapter.py → v3/bfcl_v3_adapter.py} +131 -19
- evalscope/benchmarks/bfcl/{generation.py → v3/generation.py} +9 -9
- evalscope/benchmarks/bfcl/v3/utils.py +23 -0
- evalscope/benchmarks/bfcl/v4/__init__.py +0 -0
- evalscope/benchmarks/bfcl/v4/bfcl_v4_adapter.py +229 -0
- evalscope/benchmarks/bfcl/v4/utils.py +410 -0
- evalscope/benchmarks/biomix_qa/__init__.py +0 -0
- evalscope/benchmarks/biomix_qa/biomix_qa_adapter.py +36 -0
- evalscope/benchmarks/blink/__init__.py +0 -0
- evalscope/benchmarks/blink/blink_adapter.py +61 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +1 -2
- evalscope/benchmarks/chartqa/__init__.py +0 -0
- evalscope/benchmarks/chartqa/chartqa_adapter.py +80 -0
- evalscope/benchmarks/chartqa/utils.py +38 -0
- evalscope/benchmarks/coin_flip/__init__.py +0 -0
- evalscope/benchmarks/coin_flip/coin_flip_adapter.py +128 -0
- evalscope/benchmarks/commonsense_qa/__init__.py +0 -0
- evalscope/benchmarks/commonsense_qa/commonsense_qa_adapter.py +32 -0
- evalscope/benchmarks/competition_math/competition_math_adapter.py +5 -0
- evalscope/benchmarks/data_collection/data_collection_adapter.py +24 -19
- evalscope/benchmarks/docvqa/__init__.py +0 -0
- evalscope/benchmarks/docvqa/docvqa_adapter.py +67 -0
- evalscope/benchmarks/drivelology/__init__.py +0 -0
- evalscope/benchmarks/drivelology/drivelology_binary_adapter.py +170 -0
- evalscope/benchmarks/drivelology/drivelology_multilabel_adapter.py +254 -0
- evalscope/benchmarks/drivelology/drivelology_selection_adapter.py +49 -0
- evalscope/benchmarks/drivelology/drivelology_writing_adapter.py +218 -0
- evalscope/benchmarks/drop/drop_adapter.py +15 -44
- evalscope/benchmarks/drop/utils.py +97 -0
- evalscope/benchmarks/frames/frames_adapter.py +2 -1
- evalscope/benchmarks/general_arena/general_arena_adapter.py +7 -2
- evalscope/benchmarks/general_arena/utils.py +2 -1
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +1 -1
- evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +25 -9
- evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
- evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +159 -0
- evalscope/benchmarks/halu_eval/__init__.py +0 -0
- evalscope/benchmarks/halu_eval/halu_eval_adapter.py +128 -0
- evalscope/benchmarks/halu_eval/halu_eval_instructions.py +84 -0
- evalscope/benchmarks/healthbench/__init__.py +0 -0
- evalscope/benchmarks/healthbench/healthbench_adapter.py +282 -0
- evalscope/benchmarks/healthbench/utils.py +102 -0
- evalscope/benchmarks/hle/hle_adapter.py +3 -2
- evalscope/benchmarks/humaneval/humaneval_adapter.py +24 -52
- evalscope/benchmarks/humaneval/utils.py +235 -0
- evalscope/benchmarks/ifeval/instructions_util.py +2 -3
- evalscope/benchmarks/image_edit/__init__.py +0 -0
- evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
- evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
- evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
- evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
- evalscope/benchmarks/infovqa/__init__.py +0 -0
- evalscope/benchmarks/infovqa/infovqa_adapter.py +66 -0
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +13 -6
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +66 -54
- evalscope/benchmarks/live_code_bench/sandbox_evaluate_utils.py +220 -0
- evalscope/benchmarks/logi_qa/__int__.py +0 -0
- evalscope/benchmarks/logi_qa/logi_qa_adapter.py +41 -0
- evalscope/benchmarks/math_500/math_500_adapter.py +5 -1
- evalscope/benchmarks/math_qa/__init__.py +0 -0
- evalscope/benchmarks/math_qa/math_qa_adapter.py +35 -0
- evalscope/benchmarks/math_verse/__init__.py +0 -0
- evalscope/benchmarks/math_verse/math_verse_adapter.py +105 -0
- evalscope/benchmarks/math_vision/__init__.py +0 -0
- evalscope/benchmarks/math_vision/math_vision_adapter.py +116 -0
- evalscope/benchmarks/math_vista/__init__.py +0 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +114 -0
- evalscope/benchmarks/med_mcqa/__init__.py +0 -0
- evalscope/benchmarks/med_mcqa/med_mcqa_adapter.py +32 -0
- evalscope/benchmarks/minerva_math/__init__.py +0 -0
- evalscope/benchmarks/minerva_math/minerva_math_adapter.py +53 -0
- evalscope/benchmarks/mm_bench/__init__.py +0 -0
- evalscope/benchmarks/mm_bench/mm_bench_adapter.py +99 -0
- evalscope/benchmarks/mm_star/__init__.py +0 -0
- evalscope/benchmarks/mm_star/mm_star_adapter.py +73 -0
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +1 -1
- evalscope/benchmarks/mmmu/__init__.py +0 -0
- evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
- evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
- evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +124 -0
- evalscope/benchmarks/mri_mcqa/__init__.py +0 -0
- evalscope/benchmarks/mri_mcqa/mri_mcqa_adapter.py +34 -0
- evalscope/benchmarks/multi_if/__init__.py +0 -0
- evalscope/benchmarks/multi_if/ifeval.py +3354 -0
- evalscope/benchmarks/multi_if/metrics.py +120 -0
- evalscope/benchmarks/multi_if/multi_if_adapter.py +161 -0
- evalscope/benchmarks/music_trivia/__init__.py +0 -0
- evalscope/benchmarks/music_trivia/music_trivia_adapter.py +36 -0
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +7 -6
- evalscope/benchmarks/ner/__init__.py +0 -0
- evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
- evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
- evalscope/benchmarks/ner/copious_adapter.py +85 -0
- evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
- evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
- evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
- evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
- evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
- evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
- evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
- evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
- evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
- evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
- evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
- evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
- evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
- evalscope/benchmarks/ocr_bench/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench/ocr_bench_adapter.py +101 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/IoUscore_metric.py +87 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/TEDS_metric.py +963 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/ocr_bench_v2_adapter.py +161 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/page_ocr_metric.py +50 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/parallel.py +46 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/readme.txt +26 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/rrc_evaluation_funcs_1_1.py +537 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/script.py +481 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_metric.py +179 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/utils.py +433 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_v2/vqa_metric.py +254 -0
- evalscope/benchmarks/olympiad_bench/__init__.py +0 -0
- evalscope/benchmarks/olympiad_bench/olympiad_bench_adapter.py +163 -0
- evalscope/benchmarks/olympiad_bench/utils.py +565 -0
- evalscope/benchmarks/omni_bench/__init__.py +0 -0
- evalscope/benchmarks/omni_bench/omni_bench_adapter.py +86 -0
- evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
- evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
- evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
- evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
- evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
- evalscope/benchmarks/piqa/__init__.py +0 -0
- evalscope/benchmarks/piqa/piqa_adapter.py +32 -0
- evalscope/benchmarks/poly_math/__init__.py +0 -0
- evalscope/benchmarks/poly_math/poly_math_adapter.py +132 -0
- evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
- evalscope/benchmarks/pope/__init__.py +0 -0
- evalscope/benchmarks/pope/pope_adapter.py +112 -0
- evalscope/benchmarks/process_bench/process_bench_adapter.py +1 -0
- evalscope/benchmarks/pumed_qa/__init__.py +0 -0
- evalscope/benchmarks/pumed_qa/pubmed_qa_adapter.py +175 -0
- evalscope/benchmarks/qasc/__init__.py +0 -0
- evalscope/benchmarks/qasc/qasc_adapter.py +35 -0
- evalscope/benchmarks/real_world_qa/__init__.py +0 -0
- evalscope/benchmarks/real_world_qa/real_world_qa_adapter.py +64 -0
- evalscope/benchmarks/sciq/__init__.py +0 -0
- evalscope/benchmarks/sciq/sciq_adapter.py +36 -0
- evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
- evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -1
- evalscope/benchmarks/simple_vqa/__init__.py +0 -0
- evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
- evalscope/benchmarks/siqa/__init__.py +0 -0
- evalscope/benchmarks/siqa/siqa_adapter.py +39 -0
- evalscope/benchmarks/tau_bench/tau2_bench/__init__.py +0 -0
- evalscope/benchmarks/tau_bench/tau2_bench/generation.py +158 -0
- evalscope/benchmarks/tau_bench/tau2_bench/tau2_bench_adapter.py +146 -0
- evalscope/benchmarks/tau_bench/tau_bench/__init__.py +0 -0
- evalscope/benchmarks/tau_bench/{generation.py → tau_bench/generation.py} +1 -1
- evalscope/benchmarks/tau_bench/{tau_bench_adapter.py → tau_bench/tau_bench_adapter.py} +29 -29
- evalscope/benchmarks/text2image/__init__.py +0 -0
- evalscope/benchmarks/{aigc/t2i → text2image}/evalmuse_adapter.py +3 -1
- evalscope/benchmarks/{aigc/t2i → text2image}/genai_bench_adapter.py +2 -2
- evalscope/benchmarks/{aigc/t2i → text2image}/general_t2i_adapter.py +1 -1
- evalscope/benchmarks/{aigc/t2i → text2image}/hpdv2_adapter.py +7 -2
- evalscope/benchmarks/{aigc/t2i → text2image}/tifa_adapter.py +1 -0
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +3 -3
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +1 -2
- evalscope/benchmarks/visu_logic/__init__.py +0 -0
- evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
- evalscope/benchmarks/wmt/__init__.py +0 -0
- evalscope/benchmarks/wmt/wmt24_adapter.py +294 -0
- evalscope/benchmarks/zerobench/__init__.py +0 -0
- evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
- evalscope/cli/start_app.py +7 -1
- evalscope/cli/start_perf.py +7 -1
- evalscope/config.py +103 -18
- evalscope/constants.py +18 -0
- evalscope/evaluator/evaluator.py +138 -82
- evalscope/metrics/bert_score/__init__.py +0 -0
- evalscope/metrics/bert_score/scorer.py +338 -0
- evalscope/metrics/bert_score/utils.py +697 -0
- evalscope/metrics/llm_judge.py +19 -7
- evalscope/metrics/math_parser.py +14 -0
- evalscope/metrics/metric.py +317 -13
- evalscope/metrics/metrics.py +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +2 -6
- evalscope/models/image_edit_model.py +125 -0
- evalscope/models/model_apis.py +22 -0
- evalscope/models/openai_compatible.py +21 -0
- evalscope/models/text2image_model.py +2 -2
- evalscope/models/utils/openai.py +16 -6
- evalscope/perf/arguments.py +26 -4
- evalscope/perf/benchmark.py +76 -89
- evalscope/perf/http_client.py +31 -16
- evalscope/perf/main.py +15 -2
- evalscope/perf/plugin/api/base.py +9 -7
- evalscope/perf/plugin/api/custom_api.py +13 -58
- evalscope/perf/plugin/api/default_api.py +188 -79
- evalscope/perf/plugin/api/openai_api.py +85 -20
- evalscope/perf/plugin/datasets/base.py +21 -0
- evalscope/perf/plugin/datasets/custom.py +2 -3
- evalscope/perf/plugin/datasets/flickr8k.py +2 -2
- evalscope/perf/plugin/datasets/kontext_bench.py +2 -2
- evalscope/perf/plugin/datasets/line_by_line.py +2 -3
- evalscope/perf/plugin/datasets/longalpaca.py +2 -3
- evalscope/perf/plugin/datasets/openqa.py +2 -4
- evalscope/perf/plugin/datasets/random_dataset.py +1 -3
- evalscope/perf/plugin/datasets/random_vl_dataset.py +2 -2
- evalscope/perf/utils/benchmark_util.py +43 -27
- evalscope/perf/utils/db_util.py +14 -19
- evalscope/perf/utils/local_server.py +3 -44
- evalscope/perf/utils/log_utils.py +21 -6
- evalscope/report/__init__.py +13 -3
- evalscope/report/combinator.py +91 -20
- evalscope/report/generator.py +8 -87
- evalscope/report/report.py +8 -4
- evalscope/run.py +13 -5
- evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
- evalscope/utils/argument_utils.py +1 -1
- evalscope/utils/chat_service.py +1 -1
- evalscope/utils/function_utils.py +249 -12
- evalscope/utils/import_utils.py +73 -1
- evalscope/utils/io_utils.py +132 -7
- evalscope/utils/json_schema.py +25 -2
- evalscope/utils/logger.py +69 -18
- evalscope/utils/model_utils.py +4 -3
- evalscope/utils/multi_choices.py +39 -7
- evalscope/utils/ner.py +377 -0
- evalscope/version.py +2 -2
- {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/METADATA +252 -408
- {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/RECORD +290 -154
- {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/WHEEL +1 -1
- {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/top_level.txt +0 -1
- evalscope/api/mixin/dataset_mixin.py +0 -105
- evalscope/benchmarks/aigc/i2i/general_i2i_adapter.py +0 -44
- tests/__init__.py +0 -1
- tests/aigc/__init__.py +0 -1
- tests/aigc/test_t2i.py +0 -142
- tests/benchmark/__init__.py +0 -1
- tests/benchmark/test_eval.py +0 -386
- tests/cli/__init__.py +0 -1
- tests/cli/test_all.py +0 -229
- tests/cli/test_collection.py +0 -96
- tests/cli/test_custom.py +0 -268
- tests/perf/__init__.py +0 -1
- tests/perf/test_perf.py +0 -176
- tests/rag/test_clip_benchmark.py +0 -90
- tests/rag/test_mteb.py +0 -213
- tests/rag/test_ragas.py +0 -128
- tests/swift/__init__.py +0 -1
- tests/swift/test_run_swift_eval.py +0 -146
- tests/swift/test_run_swift_vlm_eval.py +0 -128
- tests/swift/test_run_swift_vlm_jugde_eval.py +0 -157
- tests/test_run_all.py +0 -12
- tests/utils.py +0 -13
- tests/vlm/__init__.py +0 -1
- tests/vlm/test_vlmeval.py +0 -102
- /evalscope/benchmarks/{aigc → aa_lcr}/__init__.py +0 -0
- /evalscope/benchmarks/{aigc/i2i → ai2d}/__init__.py +0 -0
- /evalscope/benchmarks/{aigc/t2i → amc}/__init__.py +0 -0
- {tests/rag → evalscope/benchmarks/bfcl/v3}/__init__.py +0 -0
- {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/entry_points.txt +0 -0
- {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info/licenses}/LICENSE +0 -0
evalscope/evaluator/evaluator.py
CHANGED
@@ -8,15 +8,18 @@ and report generation.
 """

 import os
+import traceback
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List

 from evalscope.api.dataset import Dataset, DatasetDict, Sample
 from evalscope.api.evaluator import CacheManager, Evaluator, TaskState
 from evalscope.api.metric import AggScore, SampleScore
+from evalscope.constants import HEARTBEAT_INTERVAL_SEC
 from evalscope.report import Report, gen_table
+from evalscope.utils.function_utils import run_in_threads_with_progress
+from evalscope.utils.logger import get_logger

 if TYPE_CHECKING:
     from evalscope.api.benchmark import DataAdapter
@@ -24,8 +27,6 @@ if TYPE_CHECKING:
     from evalscope.config import TaskConfig
     from evalscope.utils.io_utils import OutputsStructure

-from evalscope.utils.logger import get_logger
-
 logger = get_logger()


@@ -91,17 +92,27 @@ class DefaultEvaluator(Evaluator):
             Report: The complete evaluation report containing all metrics and results.
         """
         # Load the dataset and evaluate each subset
+        logger.info(f'Start evaluating benchmark: {self.benchmark_name}')
         dataset_dict = self.benchmark.load_dataset()
         agg_score_dict = defaultdict(list)

         # Process each subset (e.g., test, validation) independently
+        logger.info('Evaluating all subsets of the dataset...')
         for subset, dataset in dataset_dict.items():
-
+            if len(dataset) == 0:
+                logger.info(f'No samples found in subset: {subset}, skipping.')
+                continue
+            logger.info(f'Evaluating subset: {subset}')
             subset_score = self.evaluate_subset(subset, dataset)
             agg_score_dict[subset] = subset_score

         # Generate the report based on aggregated scores
+        logger.info('Generating report...')
         report = self.get_report(agg_score_dict)
+
+        # Finalize the evaluation process
+        self.finalize()
+        logger.info(f'Benchmark {self.benchmark_name} evaluation finished.')
         return report

     def evaluate_subset(self, subset: str, dataset: Dataset) -> List[AggScore]:
@@ -121,12 +132,15 @@
             List[AggScore]: Aggregated scores for this subset.
         """
         # Get model predictions for all samples in the subset
+        logger.info(f'Getting predictions for subset: {subset}')
         task_states = self.get_answers(subset, dataset)

         # Calculate evaluation metrics for each prediction
+        logger.info(f'Getting reviews for subset: {subset}')
         sample_scores = self.get_reviews(subset, task_states)

         # Aggregate individual sample scores into subset-level metrics
+        logger.info(f'Aggregating scores for subset: {subset}')
         agg_scores = self.benchmark.aggregate_scores(sample_scores=sample_scores)
         return agg_scores

@@ -148,51 +162,48 @@
         """
         # Initialize task state list and filter cached predictions if caching is enabled
         if self.use_cache:
-
+            cached_task_state_list, dataset = self.cache_manager.filter_prediction_cache(subset, dataset)
         else:
-
+            cached_task_state_list = []

         # Get output directory for storing model predictions
         model_prediction_dir = os.path.dirname(self.cache_manager.get_prediction_cache_path(subset))

         # Convert dataset to list for parallel processing
         dataset_list = list(dataset)
-
         if not dataset_list:
-            return
-        [old lines 163-194 were removed; their content is not rendered in this diff view]
-        return task_state_list
+            return cached_task_state_list
+
+        logger.info(f'Processing {len(dataset_list)} samples, if data is large, it may take a while.')
+
+        def worker(sample: Sample) -> TaskState:
+            return self._predict_sample(sample, model_prediction_dir)
+
+        def on_result(sample: Sample, task_state: TaskState) -> None:
+            model_result = self.cache_manager.save_prediction_cache(subset, task_state, self.benchmark.save_metadata)
+            logger.debug(f'Model result: \n{model_result.pretty_print()}')
+
+        def on_error(sample: Sample, exc: Exception) -> None:
+            tb_str = traceback.format_exc()
+            logger.error(f'{sample.model_dump_json(indent=2)} prediction failed: due to {exc}\nTraceback:\n{tb_str}')
+            if self.task_config.ignore_errors:
+                logger.warning('Error ignored, continuing with next sample.')
+                return
+            raise exc
+
+        finished_task_states = run_in_threads_with_progress(
+            dataset_list,
+            worker,
+            desc=f'Predicting[{self.benchmark_name}@{subset}]: ',
+            max_workers=self.task_config.eval_batch_size,
+            heartbeat_sec=HEARTBEAT_INTERVAL_SEC,
+            on_result=on_result,
+            on_error=on_error,
+            filter_none_results=True,
+        )
+
+        logger.info(f'Finished getting predictions for subset: {subset}.')
+        return cached_task_state_list + finished_task_states

     def _predict_sample(self, sample: Sample, model_prediction_dir: str) -> TaskState:
         """
@@ -229,50 +240,58 @@
         """
         # Initialize sample score list and filter cached reviews if caching is enabled
         if self.use_cache and not self.task_config.rerun_review:
-
+            cached_score_list, task_states = self.cache_manager.filter_review_cache(subset, task_states)
         else:
             # Init a clean sample score list
-
+            cached_score_list = []
             self.cache_manager.delete_review_cache(subset)

         if not task_states:
-            return
-        [old lines 240-275 were removed; their content is not rendered in this diff view]
+            return cached_score_list
+
+        logger.info(f'Reviewing {len(task_states)} samples, if data is large, it may take a while.')
+
+        def worker(task_state: TaskState) -> SampleScore:
+            return self._review_task_state(task_state)
+
+        def on_result(task_state: TaskState, sample_score: SampleScore) -> None:
+            review_result = self.cache_manager.save_review_cache(
+                subset=subset,
+                task_state=task_state,
+                sample_score=sample_score,
+                save_metadata=self.benchmark.save_metadata
+            )
+            logger.debug(f'Review result: \n{review_result.pretty_print()}')
+
+        def on_error(task_state: TaskState, exc: Exception) -> None:
+            tb_str = traceback.format_exc()
+            logger.error(f'Error when review sample {task_state.sample_id}: due to {exc}\nTraceback:\n{tb_str}')
+            if self.task_config.ignore_errors:
+                logger.warning('Error ignored, continuing with next sample.')
+                return
+            raise exc
+
+        # Run reviews in parallel
+        reviewed_scores = run_in_threads_with_progress(
+            task_states,
+            worker,
+            desc=f'Reviewing[{self.benchmark_name}@{subset}]: ',
+            max_workers=self.task_config.judge_worker_num,
+            heartbeat_sec=HEARTBEAT_INTERVAL_SEC,
+            on_error=on_error,
+            # Do not persist interim results when batch scoring is enabled
+            on_result=None if self.benchmark.use_batch_scoring else on_result,
+            filter_none_results=False,
+        )
+
+        # Batch calculate metrics if supported by the benchmark
+        if self.benchmark.use_batch_scoring:
+            reviewed_scores = self._batch_review_task_states(
+                task_states=task_states, reviewed_scores=reviewed_scores, on_result=on_result
+            )
+
+        logger.info(f'Finished reviewing subset: {subset}. Total reviewed: {len(reviewed_scores)}')
+        return cached_score_list + reviewed_scores

     def _review_task_state(self, task_state: TaskState) -> SampleScore:
         """
@@ -288,6 +307,40 @@
         sample_score = self.benchmark.calculate_metrics(task_state=task_state)
         return sample_score

+    def _batch_review_task_states(
+        self, task_states: List[TaskState], reviewed_scores: List[SampleScore],
+        on_result: Callable[[TaskState, SampleScore], None]
+    ) -> List[SampleScore]:
+        valid_indices = [i for i, score in enumerate(reviewed_scores) if score is not None]
+        if not valid_indices:
+            return reviewed_scores
+
+        task_states = [task_states[i] for i in valid_indices]
+        reviewed_scores = [reviewed_scores[i] for i in valid_indices]
+
+        # Iterate in batches with progress bar
+        all_reviewed_scores = []
+        total = len(task_states)
+        batch_size = self.task_config.judge_worker_num
+        with tqdm(total=total, desc='Scoring (batch)', unit='sample') as pbar:
+            for start in range(0, total, batch_size):
+                # Process batch
+                end = min(start + batch_size, total)
+                batch_task_states = task_states[start:end]
+                batch_scores = reviewed_scores[start:end]
+                # Batch calculate metrics
+                updated_reviewed_scores = self.benchmark.batch_calculate_metrics(
+                    task_states=batch_task_states, sample_scores=batch_scores
+                )
+                # Append results
+                all_reviewed_scores.extend(updated_reviewed_scores)
+                # Save each result to cache
+                for task_state, sample_score in zip(batch_task_states, updated_reviewed_scores):
+                    on_result(task_state, sample_score)
+
+                pbar.update(len(batch_task_states))
+        return all_reviewed_scores
+
     def get_report(self, agg_score_dict: Dict[str, List[AggScore]]) -> Report:
         """
         Generate a comprehensive evaluation report from aggregated scores.
@@ -317,7 +370,7 @@

         # Generate and display a summary table of results
         try:
-            report_table = gen_table(report_list=[report], add_overall_metric=
+            report_table = gen_table(report_list=[report], add_overall_metric=self.benchmark.add_overall_metric)
             logger.info(f'\n{self.benchmark_name} report table:'
                         f'\n{report_table} \n')
         except Exception:
@@ -335,3 +388,6 @@
         report.to_json(report_file)
         logger.info(f'Dump report to: {report_file} \n')
         return report
+
+    def finalize(self, *args, **kwargs):
+        self.benchmark.finalize(*args, **kwargs)
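The refactor above replaces the evaluator's inline ThreadPoolExecutor loops (the removed concurrent.futures import) with a shared helper, run_in_threads_with_progress from evalscope/utils/function_utils.py, driven by worker, on_result and on_error callbacks plus a heartbeat interval. The body of that helper is not part of this diff; the sketch below is only an illustration, inferred from the call sites above, of how such a helper could be implemented. Everything beyond the visible keyword arguments is an assumption, not the actual evalscope implementation.

# Illustrative sketch only: NOT the real evalscope helper. Signature inferred from the
# call sites in the diff; behavior (completion-order results, simple heartbeat) is assumed.
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import Callable, List, Optional, TypeVar

from tqdm import tqdm

T = TypeVar('T')
R = TypeVar('R')


def run_in_threads_with_progress(
    items: List[T],
    worker: Callable[[T], R],
    desc: str = '',
    max_workers: int = 4,
    heartbeat_sec: Optional[float] = None,
    on_result: Optional[Callable[[T, R], None]] = None,
    on_error: Optional[Callable[[T, Exception], None]] = None,
    filter_none_results: bool = False,
) -> List[R]:
    """Run `worker` over `items` in a thread pool, reporting progress via tqdm."""
    results: List[R] = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor, tqdm(total=len(items), desc=desc) as pbar:
        future_to_item = {executor.submit(worker, item): item for item in items}
        pending = set(future_to_item)
        last_beat = time.monotonic()
        while pending:
            done, pending = wait(pending, timeout=heartbeat_sec, return_when=FIRST_COMPLETED)
            if heartbeat_sec and not done and time.monotonic() - last_beat >= heartbeat_sec:
                # Heartbeat log so long-running batches do not look stalled
                pbar.write(f'{desc} still running, {len(pending)} tasks pending...')
                last_beat = time.monotonic()
            for future in done:
                item = future_to_item[future]
                try:
                    result = future.result()
                except Exception as exc:  # delegate error handling to the caller's on_error
                    if on_error is not None:
                        on_error(item, exc)
                    result = None
                else:
                    if on_result is not None:
                        on_result(item, result)
                if result is not None or not filter_none_results:
                    results.append(result)
                pbar.update(1)
    return results

Routing cache persistence through on_result keeps the worker functions free of I/O, which is why the diff can reuse the same helper for both the prediction and review passes; note the real helper may additionally preserve input order, which this sketch does not.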
evalscope/metrics/bert_score/scorer.py
ADDED
@@ -0,0 +1,338 @@
+# flake8: noqa
+import numpy as np
+import os
+import pandas as pd
+import time
+import torch
+import warnings
+from collections import defaultdict
+
+from .utils import (
+    bert_cos_score_idf,
+    get_bert_embedding,
+    get_hash,
+    get_idf_dict,
+    get_model,
+    get_tokenizer,
+    lang2model,
+    model2layers,
+    sent_encode,
+)
+
+
+class BERTScorer:
+    """
+    BERTScore Scorer Object.
+    """
+
+    def __init__(
+        self,
+        model_id_or_path=None,
+        model_type=None,
+        num_layers=None,
+        batch_size=64,
+        nthreads=4,
+        all_layers=False,
+        idf=False,
+        idf_sents=None,
+        device=None,
+        lang=None,
+        rescale_with_baseline=False,
+        baseline_path=None,
+        use_fast_tokenizer=False,
+    ):
+        """
+        Args:
+            - :param: `model_type` (str): contexual embedding model specification, default using the suggested
+                      model for the target langauge; has to specify at least one of
+                      `model_type` or `lang`
+            - :param: `num_layers` (int): the layer of representation to use.
+                      default using the number of layer tuned on WMT16 correlation data
+            - :param: `verbose` (bool): turn on intermediate status update
+            - :param: `idf` (bool): a booling to specify whether to use idf or not (this should be True even if `idf_sents` is given)
+            - :param: `idf_sents` (List of str): list of sentences used to compute the idf weights
+            - :param: `device` (str): on which the contextual embedding model will be allocated on.
+                      If this argument is None, the model lives on cuda:0 if cuda is available.
+            - :param: `batch_size` (int): bert score processing batch size
+            - :param: `nthreads` (int): number of threads
+            - :param: `lang` (str): language of the sentences; has to specify
+                      at least one of `model_type` or `lang`. `lang` needs to be
+                      specified when `rescale_with_baseline` is True.
+            - :param: `return_hash` (bool): return hash code of the setting
+            - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
+            - :param: `baseline_path` (str): customized baseline file
+            - :param: `use_fast_tokenizer` (bool): `use_fast` parameter passed to HF tokenizer
+        """
+
+        assert (lang is not None or model_type is not None), 'Either lang or model_type should be specified'
+
+        if rescale_with_baseline:
+            assert (lang is not None), 'Need to specify Language when rescaling with baseline'
+
+        if device is None:
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        else:
+            self.device = device
+
+        self._lang = lang
+        self._rescale_with_baseline = rescale_with_baseline
+        self._idf = idf
+        self.batch_size = batch_size
+        self.nthreads = nthreads
+        self.all_layers = all_layers
+        self.model_id_or_path = model_id_or_path
+
+        if model_type is None:
+            lang = lang.lower()
+            self._model_type = lang2model[lang]
+        else:
+            self._model_type = model_type
+
+        if num_layers is None:
+            self._num_layers = model2layers[self.model_type]
+        else:
+            self._num_layers = num_layers
+
+        # Building model and tokenizer
+        self._use_fast_tokenizer = use_fast_tokenizer
+        self._tokenizer = get_tokenizer(self.model_id_or_path, self._use_fast_tokenizer)
+        self._model = get_model(self.model_id_or_path, self.num_layers, self.all_layers)
+        self._model.to(self.device)
+
+        self._idf_dict = None
+        if idf_sents is not None:
+            self.compute_idf(idf_sents)
+
+        self._baseline_vals = None
+        self.baseline_path = baseline_path
+        self.use_custom_baseline = self.baseline_path is not None
+        if self.baseline_path is None:
+            self.baseline_path = os.path.join(
+                os.path.dirname(__file__),
+                f'rescale_baseline/{self.lang}/{self.model_type}.tsv',
+            )
+
+    @property
+    def lang(self):
+        return self._lang
+
+    @property
+    def idf(self):
+        return self._idf
+
+    @property
+    def model_type(self):
+        return self._model_type
+
+    @property
+    def num_layers(self):
+        return self._num_layers
+
+    @property
+    def rescale_with_baseline(self):
+        return self._rescale_with_baseline
+
+    @property
+    def baseline_vals(self):
+        if self._baseline_vals is None:
+            if os.path.isfile(self.baseline_path):
+                if not self.all_layers:
+                    self._baseline_vals = torch.from_numpy(
+                        pd.read_csv(self.baseline_path).iloc[self.num_layers].to_numpy()
+                    )[1:].float()
+                else:
+                    self._baseline_vals = (
+                        torch.from_numpy(pd.read_csv(self.baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
+                    )
+            else:
+                raise ValueError(f'Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}')
+
+        return self._baseline_vals
+
+    @property
+    def use_fast_tokenizer(self):
+        return self._use_fast_tokenizer
+
+    @property
+    def hash(self):
+        return get_hash(
+            self.model_type,
+            self.num_layers,
+            self.idf,
+            self.rescale_with_baseline,
+            self.use_custom_baseline,
+            self.use_fast_tokenizer,
+        )
+
+    def compute_idf(self, sents):
+        """
+        Args:
+
+        """
+        if self._idf_dict is not None:
+            warnings.warn('Overwriting the previous importance weights.')
+
+        self._idf_dict = get_idf_dict(sents, self._tokenizer, nthreads=self.nthreads)
+
+    def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
+        """
+        Args:
+            - :param: `cands` (list of str): candidate sentences
+            - :param: `refs` (list of str or list of list of str): reference sentences
+
+        Return:
+            - :param: `(P, R, F)`: each is of shape (N); N = number of input
+                      candidate reference pairs. if returning hashcode, the
+                      output will be ((P, R, F), hashcode). If a candidate have
+                      multiple references, the returned score of this candidate is
+                      the *best* score among all references.
+        """
+
+        ref_group_boundaries = None
+        if not isinstance(refs[0], str):
+            ref_group_boundaries = []
+            ori_cands, ori_refs = cands, refs
+            cands, refs = [], []
+            count = 0
+            for cand, ref_group in zip(ori_cands, ori_refs):
+                cands += [cand] * len(ref_group)
+                refs += ref_group
+                ref_group_boundaries.append((count, count + len(ref_group)))
+                count += len(ref_group)
+
+        if verbose:
+            print('calculating scores...')
+            start = time.perf_counter()
+
+        if self.idf:
+            assert self._idf_dict, 'IDF weights are not computed'
+            idf_dict = self._idf_dict
+        else:
+            idf_dict = defaultdict(lambda: 1.0)
+            idf_dict[self._tokenizer.sep_token_id] = 0
+            idf_dict[self._tokenizer.cls_token_id] = 0
+
+        all_preds = bert_cos_score_idf(
+            self._model,
+            refs,
+            cands,
+            self._tokenizer,
+            idf_dict,
+            verbose=verbose,
+            device=self.device,
+            batch_size=batch_size,
+            all_layers=self.all_layers,
+        ).cpu()
+
+        if ref_group_boundaries is not None:
+            max_preds = []
+            for start, end in ref_group_boundaries:
+                max_preds.append(all_preds[start:end].max(dim=0)[0])
+            all_preds = torch.stack(max_preds, dim=0)
+
+        if self.rescale_with_baseline:
+            all_preds = (all_preds - self.baseline_vals) / (1 - self.baseline_vals)
+
+        out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F
+
+        if verbose:
+            time_diff = time.perf_counter() - start
+            print(f'done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec')
+
+        if return_hash:
+            out = tuple([out, self.hash])
+
+        return out
+
+    def plot_example(self, candidate, reference, fname=''):
+        """
+        Args:
+            - :param: `candidate` (str): a candidate sentence
+            - :param: `reference` (str): a reference sentence
+            - :param: `fname` (str): path to save the output plot
+        """
+        import matplotlib.pyplot as plt
+        from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+        assert isinstance(candidate, str)
+        assert isinstance(reference, str)
+
+        idf_dict = defaultdict(lambda: 1.0)
+        idf_dict[self._tokenizer.sep_token_id] = 0
+        idf_dict[self._tokenizer.cls_token_id] = 0
+
+        hyp_embedding, masks, padded_idf = get_bert_embedding(
+            [candidate],
+            self._model,
+            self._tokenizer,
+            idf_dict,
+            device=self.device,
+            all_layers=False,
+        )
+        ref_embedding, masks, padded_idf = get_bert_embedding(
+            [reference],
+            self._model,
+            self._tokenizer,
+            idf_dict,
+            device=self.device,
+            all_layers=False,
+        )
+        ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
+        hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
+        sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
+        sim = sim.squeeze(0).cpu()
+
+        r_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, reference)][1:-1]
+        h_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, candidate)][1:-1]
+        sim = sim[1:-1, 1:-1]
+
+        if self.rescale_with_baseline:
+            sim = (sim - self.baseline_vals[2].item()) / (1 - self.baseline_vals[2].item())
+
+        fig, ax = plt.subplots(figsize=(len(r_tokens), len(h_tokens)))
+        im = ax.imshow(sim, cmap='Blues', vmin=0, vmax=1)
+
+        # We want to show all ticks...
+        ax.set_xticks(np.arange(len(r_tokens)))
+        ax.set_yticks(np.arange(len(h_tokens)))
+        # ... and label them with the respective list entries
+        ax.set_xticklabels(r_tokens, fontsize=10)
+        ax.set_yticklabels(h_tokens, fontsize=10)
+        ax.grid(False)
+        plt.xlabel('Reference (tokenized)', fontsize=14)
+        plt.ylabel('Candidate (tokenized)', fontsize=14)
+        title = 'Similarity Matrix'
+        if self.rescale_with_baseline:
+            title += ' (after Rescaling)'
+        plt.title(title, fontsize=14)
+
+        divider = make_axes_locatable(ax)
+        cax = divider.append_axes('right', size='2%', pad=0.2)
+        fig.colorbar(im, cax=cax)
+
+        # Rotate the tick labels and set their alignment.
+        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
+
+        # Loop over data dimensions and create text annotations.
+        for i in range(len(h_tokens)):
+            for j in range(len(r_tokens)):
+                text = ax.text(
+                    j,
+                    i,
+                    '{:.3f}'.format(sim[i, j].item()),
+                    ha='center',
+                    va='center',
+                    color='k' if sim[i, j].item() < 0.5 else 'w',
+                )
+
+        fig.tight_layout()
+        if fname != '':
+            plt.savefig(fname, dpi=100)
+            print('Saved figure to file: ', fname)
+        plt.show()
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(hash={self.hash}, batch_size={self.batch_size}, nthreads={self.nthreads})'
+
+    def __str__(self):
+        return self.__repr__()
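For orientation, a minimal usage sketch of the BERTScorer class added above. The model id, the example sentences, and the assumption that lang='en' resolves to a RoBERTa-style checkpoint via lang2model/get_model/get_tokenizer are illustrative assumptions, not taken from the diff.

# Minimal usage sketch; model id and data are placeholders, not part of the release.
from evalscope.metrics.bert_score.scorer import BERTScorer

cands = ['The quick brown fox jumps over the lazy dog.']
refs = ['A quick brown fox leaps over a lazy dog.']

scorer = BERTScorer(
    model_id_or_path='roberta-large',  # assumption: a Hugging Face model id or local checkpoint path
    lang='en',                         # either `lang` or `model_type` must be given
    num_layers=17,                     # optional; defaults to the WMT16-tuned layer for the model type
    idf=False,
)
P, R, F = scorer.score(cands, refs)   # precision, recall, F1 tensors, one value per candidate/reference pair
print(f'BERTScore F1: {F.mean().item():.4f}')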