evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
- evalscope/__init__.py +4 -1
- evalscope/api/benchmark/__init__.py +3 -0
- evalscope/api/benchmark/adapters/__init__.py +5 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
- evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
- evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
- evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
- evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
- evalscope/api/benchmark/benchmark.py +356 -0
- evalscope/api/benchmark/meta.py +121 -0
- evalscope/api/dataset/__init__.py +2 -0
- evalscope/api/dataset/dataset.py +349 -0
- evalscope/api/dataset/loader.py +262 -0
- evalscope/api/dataset/utils.py +143 -0
- evalscope/api/evaluator/__init__.py +3 -0
- evalscope/api/evaluator/cache.py +378 -0
- evalscope/api/evaluator/evaluator.py +56 -0
- evalscope/api/evaluator/state.py +275 -0
- evalscope/api/filter/__init__.py +1 -0
- evalscope/api/filter/filter.py +72 -0
- evalscope/api/messages/__init__.py +12 -0
- evalscope/api/messages/chat_message.py +243 -0
- evalscope/api/messages/content.py +102 -0
- evalscope/api/messages/utils.py +35 -0
- evalscope/api/metric/__init__.py +2 -0
- evalscope/api/metric/metric.py +55 -0
- evalscope/api/metric/scorer.py +113 -0
- evalscope/api/mixin/__init__.py +1 -0
- evalscope/api/mixin/llm_judge_mixin.py +168 -0
- evalscope/api/model/__init__.py +12 -0
- evalscope/api/model/generate_config.py +155 -0
- evalscope/api/model/model.py +386 -0
- evalscope/api/model/model_output.py +285 -0
- evalscope/api/registry.py +182 -0
- evalscope/api/tool/__init__.py +3 -0
- evalscope/api/tool/tool_call.py +101 -0
- evalscope/api/tool/tool_info.py +173 -0
- evalscope/api/tool/utils.py +64 -0
- evalscope/app/app.py +3 -0
- evalscope/app/ui/app_ui.py +2 -1
- evalscope/app/ui/multi_model.py +50 -25
- evalscope/app/ui/single_model.py +26 -14
- evalscope/app/utils/data_utils.py +43 -27
- evalscope/app/utils/env_utils.py +12 -0
- evalscope/app/utils/text_utils.py +14 -14
- evalscope/app/utils/visualization.py +9 -4
- evalscope/arguments.py +7 -10
- evalscope/backend/opencompass/api_meta_template.py +2 -1
- evalscope/backend/opencompass/backend_manager.py +6 -5
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
- evalscope/backend/rag_eval/ragas/task_template.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
- evalscope/backend/rag_eval/utils/embedding.py +10 -1
- evalscope/backend/rag_eval/utils/llm.py +13 -12
- evalscope/benchmarks/__init__.py +0 -2
- evalscope/benchmarks/aime/aime24_adapter.py +38 -40
- evalscope/benchmarks/aime/aime25_adapter.py +34 -40
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
- evalscope/benchmarks/arc/arc_adapter.py +34 -147
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
- evalscope/benchmarks/arena_hard/utils.py +37 -1
- evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
- evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
- evalscope/benchmarks/bfcl/generation.py +222 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
- evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
- evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
- evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
- evalscope/benchmarks/docmath/utils.py +4 -5
- evalscope/benchmarks/drop/drop_adapter.py +88 -40
- evalscope/benchmarks/frames/frames_adapter.py +136 -52
- evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
- evalscope/benchmarks/general_arena/utils.py +23 -27
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
- evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
- evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
- evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
- evalscope/benchmarks/hle/hle_adapter.py +127 -93
- evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
- evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
- evalscope/benchmarks/ifeval/instructions.py +109 -64
- evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
- evalscope/benchmarks/ifeval/instructions_util.py +2 -3
- evalscope/benchmarks/ifeval/utils.py +6 -7
- evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
- evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
- evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
- evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
- evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
- evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
- evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
- evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
- evalscope/benchmarks/math_vista/__init__.py +0 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
- evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
- evalscope/benchmarks/mmmu/__init__.py +0 -0
- evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
- evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
- evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
- evalscope/benchmarks/musr/musr_adapter.py +33 -64
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
- evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
- evalscope/benchmarks/race/race_adapter.py +33 -119
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
- evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
- evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
- evalscope/benchmarks/super_gpqa/utils.py +2 -1
- evalscope/benchmarks/tau_bench/generation.py +147 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
- evalscope/benchmarks/text2image/__init__.py +0 -0
- evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
- evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
- evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
- evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
- evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
- evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
- evalscope/cli/cli.py +2 -0
- evalscope/cli/start_app.py +7 -1
- evalscope/cli/start_perf.py +7 -1
- evalscope/cli/start_server.py +6 -3
- evalscope/collections/__init__.py +2 -10
- evalscope/collections/sampler.py +10 -10
- evalscope/collections/schema.py +13 -11
- evalscope/config.py +157 -57
- evalscope/constants.py +37 -61
- evalscope/evaluator/__init__.py +1 -1
- evalscope/evaluator/evaluator.py +275 -419
- evalscope/filters/__init__.py +2 -0
- evalscope/filters/extraction.py +126 -0
- evalscope/filters/selection.py +57 -0
- evalscope/metrics/__init__.py +13 -13
- evalscope/metrics/llm_judge.py +47 -33
- evalscope/metrics/math_parser.py +27 -22
- evalscope/metrics/metric.py +307 -0
- evalscope/metrics/metrics.py +22 -18
- evalscope/metrics/t2v_metrics/__init__.py +0 -52
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
- evalscope/models/__init__.py +6 -29
- evalscope/models/image_edit_model.py +125 -0
- evalscope/models/mockllm.py +65 -0
- evalscope/models/model_apis.py +67 -0
- evalscope/models/modelscope.py +455 -0
- evalscope/models/openai_compatible.py +126 -0
- evalscope/models/text2image_model.py +124 -0
- evalscope/models/utils/openai.py +701 -0
- evalscope/perf/benchmark.py +4 -1
- evalscope/perf/http_client.py +4 -2
- evalscope/perf/plugin/api/custom_api.py +5 -4
- evalscope/perf/plugin/api/openai_api.py +11 -9
- evalscope/perf/plugin/datasets/custom.py +2 -1
- evalscope/perf/plugin/datasets/flickr8k.py +1 -1
- evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
- evalscope/perf/plugin/datasets/line_by_line.py +2 -1
- evalscope/perf/plugin/datasets/longalpaca.py +2 -1
- evalscope/perf/plugin/datasets/openqa.py +4 -2
- evalscope/perf/utils/benchmark_util.py +15 -10
- evalscope/perf/utils/db_util.py +9 -6
- evalscope/perf/utils/local_server.py +11 -3
- evalscope/perf/utils/rich_display.py +16 -10
- evalscope/report/__init__.py +2 -3
- evalscope/report/combinator.py +18 -12
- evalscope/report/generator.py +51 -35
- evalscope/report/{utils.py → report.py} +8 -6
- evalscope/run.py +33 -47
- evalscope/summarizer.py +1 -1
- evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
- evalscope/utils/__init__.py +21 -2
- evalscope/utils/chat_service.py +3 -2
- evalscope/utils/deprecation_utils.py +12 -1
- evalscope/utils/function_utils.py +29 -0
- evalscope/utils/import_utils.py +23 -1
- evalscope/utils/io_utils.py +142 -6
- evalscope/utils/json_schema.py +208 -0
- evalscope/utils/logger.py +51 -12
- evalscope/utils/model_utils.py +11 -7
- evalscope/utils/multi_choices.py +288 -0
- evalscope/utils/url_utils.py +65 -0
- evalscope/version.py +2 -2
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
- tests/benchmark/test_eval.py +385 -0
- tests/benchmark/test_image_edit.py +65 -0
- tests/{aigc → benchmark}/test_t2i.py +22 -4
- tests/benchmark/test_vlm.py +80 -0
- tests/cli/test_all.py +85 -47
- tests/cli/test_collection.py +20 -8
- tests/cli/test_custom.py +22 -15
- tests/cli/test_reasoning.py +81 -0
- tests/common.py +73 -0
- tests/perf/test_perf.py +4 -2
- tests/rag/test_clip_benchmark.py +0 -2
- evalscope/benchmarks/aigc/t2i/base.py +0 -56
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
- evalscope/benchmarks/arc/ai2_arc.py +0 -151
- evalscope/benchmarks/benchmark.py +0 -81
- evalscope/benchmarks/ceval/ceval_exam.py +0 -146
- evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
- evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
- evalscope/benchmarks/competition_math/competition_math.py +0 -79
- evalscope/benchmarks/data_adapter.py +0 -528
- evalscope/benchmarks/filters.py +0 -59
- evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
- evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
- evalscope/benchmarks/humaneval/humaneval.py +0 -79
- evalscope/benchmarks/mmlu/mmlu.py +0 -160
- evalscope/benchmarks/mmlu/samples.jsonl +0 -5
- evalscope/benchmarks/process_bench/critique_template.txt +0 -13
- evalscope/benchmarks/race/race.py +0 -104
- evalscope/benchmarks/race/samples.jsonl +0 -5
- evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
- evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
- evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
- evalscope/benchmarks/utils.py +0 -60
- evalscope/collections/evaluator.py +0 -375
- evalscope/metrics/completion_parsers.py +0 -227
- evalscope/metrics/named_metrics.py +0 -55
- evalscope/models/adapters/__init__.py +0 -14
- evalscope/models/adapters/base_adapter.py +0 -84
- evalscope/models/adapters/bfcl_adapter.py +0 -246
- evalscope/models/adapters/chat_adapter.py +0 -207
- evalscope/models/adapters/choice_adapter.py +0 -222
- evalscope/models/adapters/custom_adapter.py +0 -71
- evalscope/models/adapters/server_adapter.py +0 -236
- evalscope/models/adapters/t2i_adapter.py +0 -79
- evalscope/models/adapters/tau_bench_adapter.py +0 -189
- evalscope/models/custom/__init__.py +0 -4
- evalscope/models/custom/custom_model.py +0 -50
- evalscope/models/custom/dummy_model.py +0 -99
- evalscope/models/local_model.py +0 -128
- evalscope/models/register.py +0 -41
- tests/cli/test_run.py +0 -489
- /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
- /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
- /tests/{aigc → benchmark}/__init__.py +0 -0
evalscope/metrics/metric.py
ADDED
@@ -0,0 +1,307 @@
+from collections import defaultdict
+from typing import List
+
+from evalscope.api.metric import Aggregator, AggScore, Metric, SampleScore, T2IMetric
+from evalscope.api.registry import register_aggregation, register_metric
+from .metrics import mean
+
+
+@register_metric(name='exact_match')
+class ExactMatch(Metric):
+
+    def apply(self, predictions, references):
+        return [float(prediction == reference) for prediction, reference in zip(predictions, references)]
+
+
+@register_metric(name='acc')
+class Accuracy(ExactMatch):
+
+    def __init__(self, allow_inclusion: bool = False, numeric: bool = False):
+        self.allow_inclusion = allow_inclusion
+        self.numeric = numeric
+
+    def apply(self, predictions, references):
+        if self.allow_inclusion:
+            results = []
+            for prediction, reference in zip(predictions, references):
+                if prediction and prediction in reference:
+                    results.append(1.0)
+                else:
+                    results.append(0.0)
+            return results
+        elif self.numeric:
+            from .math_parser import extract_answer, math_equal, strip_answer_string
+
+            results = []
+            for prediction, reference in zip(predictions, references):
+                pred_answer = strip_answer_string(extract_answer(prediction))
+                ref_answer = strip_answer_string(reference)
+                results.append(float(math_equal(pred_answer, ref_answer)))
+
+            return results
+        else:
+            return super().apply(predictions, references)
+
+
+@register_metric(name='numeric_match')
+class NumericMatch(Metric):
+
+    def apply(self, predictions, references):
+        return [float(prediction == reference) for prediction, reference in zip(predictions, references)]
+
+
+@register_metric(name='math_acc')
+class MathAcc(Metric):
+
+    def apply(self, predictions, references):
+        from .math_parser import extract_answer, math_equal, strip_answer_string
+
+        results = []
+        for prediction, reference in zip(predictions, references):
+            pred_answer = strip_answer_string(extract_answer(prediction))
+            ref_answer = strip_answer_string(reference)
+            results.append(float(math_equal(pred_answer, ref_answer)))
+
+        return results
+
+
+@register_metric(name='multi_choice_acc')
+class MultiChoiceAcc(Metric):
+
+    def apply(self, predictions, references):
+        """
+        Calculate accuracy for multiple-choice questions.
+
+        Args:
+            predictions (List[str]): List of predicted answers.
+            references (List[str]): List of correct answers.
+
+        Returns:
+            List[float]: List of accuracy scores (1.0 for correct, 0.0 for incorrect).
+        """
+        res = []
+        for prediction, reference in zip(predictions, references):
+            prediction = set(prediction.strip().upper())
+            reference = set(reference.strip().upper())
+            # if the prediction has answer that not in reference, it is wrong
+            if not prediction.issubset(reference):
+                res.append(0.0)
+                continue
+            common = prediction.intersection(reference)
+            res.append(len(common) / len(reference) if reference else 0.0)
+        return res
+
+
+# ##################
+# T2I Metrics ######
+####################
+@register_metric(name='VQAScore')
+class VQAScore(T2IMetric):
+
+    def _init_once(self, model: str = 'clip-flant5-xxl'):
+        from .t2v_metrics.vqascore import VQAScore
+        self.model = VQAScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+@register_metric(name='PickScore')
+class PickScore(T2IMetric):
+
+    def _init_once(self, model: str = 'pickscore-v1'):
+        from .t2v_metrics.clipscore import CLIPScore
+        self.model = CLIPScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+@register_metric(name='CLIPScore')
+class CLIPScore(T2IMetric):
+
+    def _init_once(self, model: str = 'openai:ViT-L-14-336'):
+        from .t2v_metrics.clipscore import CLIPScore
+        self.model = CLIPScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+@register_metric(name='BLIPv2Score')
+class BLIPv2Score(T2IMetric):
+
+    def _init_once(self, model: str = 'blip2-itm'):
+        from .t2v_metrics.itmscore import ITMScore
+        self.model = ITMScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+@register_metric(name='HPSv2Score')
+class HPSv2Score(T2IMetric):
+
+    def _init_once(self, model: str = 'hpsv2'):
+        from .t2v_metrics.clipscore import CLIPScore
+        self.model = CLIPScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+@register_metric(name='HPSv2.1Score')
+class HPSv2_1Score(T2IMetric):
+
+    def _init_once(self, model: str = 'hpsv2.1'):
+        from .t2v_metrics.clipscore import CLIPScore
+        self.model = CLIPScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+@register_metric(name='ImageRewardScore')
+class ImageRewardScore(T2IMetric):
+
+    def _init_once(self, model: str = 'image-reward-v1'):
+        from .t2v_metrics.itmscore import ITMScore
+        self.model = ITMScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+@register_metric(name='FGA_BLIP2Score')
+class FGA_BLIP2Score(T2IMetric):
+
+    def _init_once(self, model: str = 'fga_blip2'):
+        from .t2v_metrics.itmscore import ITMScore
+        self.model = ITMScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+@register_metric(name='MPS')
+class MPS(T2IMetric):
+
+    def _init_once(self, model: str = 'mps'):
+        from .t2v_metrics.clipscore import CLIPScore
+        self.model = CLIPScore(model=model)
+
+    def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+        return self.model(images, texts, **kwargs)
+
+
+# ##################
+# Aggregators ######
+# ##################
+@register_aggregation(name='mean')
+class Mean(Aggregator):
+
+    name = 'mean'
+
+    def __call__(self, scores: List[SampleScore]) -> List[AggScore]:
+        """Aggregate scores by computing the mean for each metric.
+
+        Args:
+            scores: List of sample scores to aggregate
+
+        Returns:
+            List of aggregated scores with mean values
+        """
+        if not scores:
+            return []
+
+        # Group score values by metric name
+        metric_values = defaultdict(list)
+        metric_sample_ids = defaultdict(list)
+
+        for score in scores:
+
+            for metric_name, value in score.score.value.items():
+                metric_values[metric_name].append(value)
+                metric_sample_ids[metric_name].append(score.sample_id)
+
+        # Calculate mean for each metric
+        aggregated_scores = []
+        for metric_name, values in metric_values.items():
+            if values:  # Only process non-empty value lists
+                aggregated_scores.append(
+                    AggScore(
+                        score=mean(values),
+                        metric_name=metric_name,
+                        aggregation_name=self.name,
+                        num=len(values),
+                        ids=metric_sample_ids[metric_name]
+                    )
+                )
+
+        return aggregated_scores
+
+
+@register_aggregation(name='pass_at_k')
+class PassAtK(Aggregator):
+
+    def __init__(self, k: int = 1):
+        self.k = k
+        self.name = f'pass_at_{k}'
+
+    def __call__(self, scores: List[SampleScore]) -> List[AggScore]:
+        """Aggregate scores by computing the pass@k for each metric using group_id.
+
+        Args:
+            scores: List of sample scores to aggregate
+
+        Returns:
+            List of aggregated scores with pass@k values
+        """
+        if not scores:
+            return []
+
+        import numpy as np
+
+        from .metrics import calculate_pass_at_k
+
+        # Group scores by metric name and group_id
+        metric_groups = defaultdict(lambda: defaultdict(list))
+
+        for score in scores:
+            group_id = getattr(score, 'group_id', score.sample_id)  # fallback to sample_id if no group_id
+
+            for metric_name, value in score.score.value.items():
+                metric_groups[metric_name][group_id].append(float(value))
+
+        # Calculate pass@k for each metric
+        aggregated_scores = []
+        for metric_name, groups in metric_groups.items():
+            if not groups:
+                continue
+
+            # Calculate pass@k for each group (problem)
+            num_samples = []
+            num_correct = []
+            all_sample_ids = []
+
+            for group_id, group_values in groups.items():
+                num_samples.append(len(group_values))
+                num_correct.append(sum(group_values))  # count how many passed in this group
+                all_sample_ids.extend([f'{group_id}_{i}' for i in range(len(group_values))])
+
+            if num_samples:
+                # Use the calculate_pass_at_k function from metrics
+                pass_at_k_values = calculate_pass_at_k(num_samples, num_correct, self.k)
+                overall_pass_at_k = float(np.mean(pass_at_k_values))
+
+                aggregated_scores.append(
+                    AggScore(
+                        score=overall_pass_at_k,
+                        metric_name=f'pass@{self.k}',
+                        aggregation_name='',
+                        num=len(scores),
+                        ids=all_sample_ids
+                    )
+                )
+
+        return aggregated_scores
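Note: the new metric.py above defines named, registry-backed metric and aggregation classes. A minimal usage sketch of two of the string-matching metrics defined above (this assumes the module is importable as evalscope.metrics.metric and that these Metric subclasses need no extra constructor arguments; it is not taken from the package documentation):

# Hedged sketch: exercises two metric classes from the diff above.
# Assumes evalscope 1.0.x is installed and the classes work standalone as shown.
from evalscope.metrics.metric import Accuracy, MultiChoiceAcc

inclusive_acc = Accuracy(allow_inclusion=True)
print(inclusive_acc.apply(['B'], ['A, B or C']))   # [1.0]  substring containment counts as correct

mc_acc = MultiChoiceAcc()
print(mc_acc.apply(['AB', 'AD'], ['ABC', 'BC']))   # [0.666..., 0.0]  partial credit, but any wrong option scores 0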
evalscope/metrics/metrics.py
CHANGED
@@ -191,7 +191,7 @@ def bleu(items):
     return sacrebleu.corpus_bleu(preds, refs).score
 
 
-def bleu_ngram_one_sample(predict, reference):
+def bleu_ngram_one_sample(predict: str, reference: str):
     """
     Calculate BLEU-1, BLEU-2, BLEU-3, and BLEU-4 scores
 
@@ -322,11 +322,11 @@ def bootstrap_stderr(f, xs, iters):
 
     print('bootstrapping for stddev:', f.__name__)
     for bootstrap in tqdm(
-
-
-
-
-
+        pool.imap(
+            _bootstrap_internal(f, chunk_size),
+            [(i, xs) for i in range(iters // chunk_size)],
+        ),
+        total=iters // chunk_size,
     ):
         # sample w replacement
         res.extend(bootstrap)
@@ -361,15 +361,17 @@ def yesno(x):
     return 'no'
 
 
-def compute_elo(
-
-
-
-
-
-
-
-
+def compute_elo(
+    battles,
+    col_model_a='model_a',
+    col_model_b='model_b',
+    col_win='win',
+    tie_values=['tie', 'tie (bothbad)'],
+    k=32,
+    scale=400,
+    base=10,
+    init_rating=1000
+):
     rating = defaultdict(lambda: init_rating)
 
     for rd, model_a, model_b, win in battles[[col_model_a, col_model_b, col_win]].itertuples():
@@ -434,9 +436,11 @@ def calculate_arc_accuracy(question_answers: Dict[str, str], predictions: Dict[s
     return score / len(question_answers)
 
 
-def calculate_pass_at_k(
-
-
+def calculate_pass_at_k(
+    num_samples: Union[int, List[int], np.ndarray],
+    num_correct: Union[List[int], np.ndarray],
+    k: int = 1
+) -> np.ndarray:
     """
     Estimates pass@k of each problem and returns them in an array.
     Examples:
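For context, calculate_pass_at_k (given type hints above, and consumed by the PassAtK aggregator in metric.py) estimates pass@k per problem; the quantity it estimates is the usual unbiased estimator 1 - C(n-c, k) / C(n, k) from the HumanEval setup. A standalone sketch of that estimator for a single problem, not the package's own implementation:

import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k for one problem: n generated samples, c of them correct."""
    if n - c < k:
        return 1.0  # every size-k draw must contain at least one correct sample
    # 1 - C(n-c, k) / C(n, k), expanded as a numerically stable running product
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

print(pass_at_k(n=10, c=3, k=1))  # 0.3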
evalscope/metrics/t2v_metrics/__init__.py
CHANGED
@@ -1,52 +0,0 @@
-def clip_flant5_score():
-    from .vqascore import VQAScore
-    clip_flant5_score = VQAScore(model='clip-flant5-xxl')
-    return clip_flant5_score
-
-
-def pick_score():
-    from .clipscore import CLIPScore
-    pick_score = CLIPScore(model='pickscore-v1')
-    return pick_score
-
-
-def clip_score():
-    from .clipscore import CLIPScore
-    clip_score = CLIPScore(model='openai:ViT-L-14-336')
-    return clip_score
-
-
-def blip2_score():
-    from .itmscore import ITMScore
-    blip_itm_score = ITMScore(model='blip2-itm')
-    return blip_itm_score
-
-
-def hpsv2_score():
-    from .clipscore import CLIPScore
-    hpsv2_score = CLIPScore(model='hpsv2')
-    return hpsv2_score
-
-
-def hpsv2_1_score():
-    from .clipscore import CLIPScore
-    hpsv2_1_score = CLIPScore(model='hpsv2.1')
-    return hpsv2_1_score
-
-
-def image_reward_score():
-    from .itmscore import ITMScore
-    image_reward_score = ITMScore(model='image-reward-v1')
-    return image_reward_score
-
-
-def fga_blip2_score():
-    from .itmscore import ITMScore
-    fga_blip2_score = ITMScore(model='fga_blip2')
-    return fga_blip2_score
-
-
-def mps_score():
-    from .clipscore import CLIPScore
-    mps_score = CLIPScore(model='mps')
-    return mps_score
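The factory helpers deleted above have direct counterparts among the registered T2IMetric classes added in evalscope/metrics/metric.py (clip_score maps to CLIPScore, hpsv2_score to HPSv2Score, and so on). A hedged sketch of the call-pattern change implied by these two diffs; class and model names come from the code above, but it assumes T2IMetric builds its backbone lazily via _init_once and needs no constructor arguments, which this diff does not show:

# 0.17.x style: module-level factory functions (removed above)
# from evalscope.metrics.t2v_metrics import clip_score
# scorer = clip_score()

# 1.0.x style: registered metric classes (added in evalscope/metrics/metric.py)
from evalscope.metrics.metric import CLIPScore

scorer = CLIPScore()  # assumption: the 'openai:ViT-L-14-336' backbone is loaded lazily on first use
scores = scorer.apply(images=['cat.png'], texts=['a photo of a cat'])  # one score per image/text pair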
evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py
CHANGED
@@ -27,7 +27,8 @@ class XCLIPModel(HFCLIPModel):
         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         text_outputs = self.text_model(
@@ -63,7 +64,8 @@ class XCLIPModel(HFCLIPModel):
         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         vision_outputs = self.vision_model(
evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py
CHANGED
@@ -178,15 +178,9 @@ class ParallelTransformerBlock(nn.Module):
 
 
 class CrossAttention(nn.Module):
-    def __init__(
-
-
-        context_dim=None,
-        dim_head=64,
-        heads=12,
-        parallel_ff=False,
-        ff_mult=4,
-        norm_context=False):
+    def __init__(
+        self, dim, *, context_dim=None, dim_head=64, heads=12, parallel_ff=False, ff_mult=4, norm_context=False
+    ):
         super().__init__()
         self.heads = heads
         self.scale = dim_head**-0.5
@@ -205,8 +199,8 @@ class CrossAttention(nn.Module):
         ff_inner_dim = ff_mult * dim
 
         self.ff = nn.Sequential(
-            nn.Linear(dim, ff_inner_dim
-
+            nn.Linear(dim, ff_inner_dim * 2, bias=False), SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False)
+        ) if parallel_ff else None
 
     def forward(self, x, context, mask):
         """
@@ -273,9 +267,11 @@ class Cross_model(nn.Module):
             self.layers.append(
                 nn.ModuleList([
                     Residual(
-                        CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)
+                        CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)
+                    ),
                     Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
-                ])
+                ])
+            )
 
     def forward(self, query_tokens, context_tokens, mask):
 
evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py
CHANGED
@@ -86,7 +86,8 @@ class CLIPScoreModel(ScoreModel):
         model_file_path = download_open_clip_model(self.arch, self.pretrained, self.cache_dir)
 
         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
-            self.arch, pretrained=model_file_path, device=self.device
+            self.arch, pretrained=model_file_path, device=self.device
+        )
         self.tokenizer = open_clip.get_tokenizer(self.arch)
         self.model.eval()
 
evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py
CHANGED
@@ -44,11 +44,12 @@ class HPSV2ScoreModel(ScoreModel):
             image_std=None,
             image_resize_mode='longest',
             aug_cfg={},
-            output_dict=True
+            output_dict=True
+        )
 
         # update weight
         model_weight_path = download_file('AI-ModelScope/HPSv2', HPS_VERSION_MAP[self.model_name], self.cache_dir)
-        checkpoint = torch.load(model_weight_path, map_location=self.device)
+        checkpoint = torch.load(model_weight_path, map_location=self.device, weights_only=False)
         self.model.load_state_dict(checkpoint['state_dict'])
         self.tokenizer = open_clip.get_tokenizer(self.arch)
         self.model.eval()
evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py
CHANGED
@@ -29,7 +29,8 @@ class MPSModel(ScoreModel):
 
         config = download_file('AI-ModelScope/MPS', file_name='config.json', cache_dir=self.cache_dir)
         model_pretrained_path = download_file(
-            'AI-ModelScope/MPS', file_name='MPS_overall_state_dict.pt', cache_dir=self.cache_dir
+            'AI-ModelScope/MPS', file_name='MPS_overall_state_dict.pt', cache_dir=self.cache_dir
+        ) # modelscope model
         model_weight = torch.load(model_pretrained_path, weights_only=True, map_location='cpu')
 
         self.model = CLIPModel(config=CLIPConfig.from_json_file(config))
evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py
CHANGED
@@ -31,8 +31,8 @@ class PickScoreModel(ScoreModel):
         """Load the image(s), and return a tensor (no preprocessing!!) put on self.device
         """
         image = [self.image_loader(x) for x in image]
-        image = self.processor(
-
+        image = self.processor(images=image, padding=True, truncation=True, max_length=77,
+                               return_tensors='pt').to(self.device)
         # image = torch.stack(image, dim=0).to(self.device)
         return image
 
evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py
CHANGED
@@ -66,7 +66,8 @@ class BLIP2ITMScoreModel(ScoreModel):
         query_att = torch.ones(query_token.size()[:-1], dtype=torch.long).to(query_token.device)
 
         text_input = self.model.tokenizer(
-            texts, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+            texts, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+        ).to(self.device)
 
         attention_mask_all = torch.cat([query_att, text_input.attention_mask], dim=1)
         output_itm = self.model.Qformer.bert(
evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py
CHANGED
@@ -42,10 +42,12 @@ class FGA_BLIP2ScoreModel(ScoreModel):
         # load model
         self.variant = FGA_BLIP2_MODELS[self.model_name]['variant']
         self.model, self.vis_processors, self.text_processors = load_model_and_preprocess(
-            'fga_blip2', self.variant, is_eval=True, device=self.device
+            'fga_blip2', self.variant, is_eval=True, device=self.device
+        )
         # load pretrained weights
         model_weight_path = download_file(
-            'AI-ModelScope/FGA-BLIP2', file_name='fga_blip2.pth', cache_dir=self.cache_dir
+            'AI-ModelScope/FGA-BLIP2', file_name='fga_blip2.pth', cache_dir=self.cache_dir
+        )
         self.model.load_checkpoint(model_weight_path)
         self.model.eval()
 
evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py
CHANGED
@@ -47,7 +47,8 @@ class MLP(nn.Module):
             nn.Dropout(0.1),
             nn.Linear(64, 16),
             #nn.ReLU(),
-            nn.Linear(16, 1)
+            nn.Linear(16, 1)
+        )
 
         # initial MLP param
         for name, param in self.layers.named_parameters():
@@ -100,7 +101,8 @@ class ImageReward(nn.Module):
 
         # text encode
         text_input = self.blip.tokenizer(
-            prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+            prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+        ).to(self.device)
 
         # image encode
         if isinstance(image, Image.Image):
@@ -109,7 +111,8 @@ class ImageReward(nn.Module):
             pil_image = Image.open(image)
         else:
             raise TypeError(
-                r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.'
+                r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.'
+            )
 
         image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
         image_embeds = self.blip.visual_encoder(image)
@@ -133,7 +136,8 @@ class ImageReward(nn.Module):
     def inference_rank(self, prompt, generations_list):
 
         text_input = self.blip.tokenizer(
-            prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+            prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+        ).to(self.device)
 
         txt_set = []
         for generation in generations_list:
@@ -145,7 +149,8 @@ class ImageReward(nn.Module):
            pil_image = Image.open(generation)
         else:
             raise TypeError(
-                r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.'
+                r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.'
+            )
             image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
             image_embeds = self.blip.visual_encoder(image)
 
evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py
CHANGED
@@ -30,7 +30,8 @@ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop
             num_heads=12,
             use_grad_checkpointing=use_grad_checkpointing,
             ckpt_layer=ckpt_layer,
-            drop_path_rate=0 or drop_path_rate
+            drop_path_rate=0 or drop_path_rate
+        )
     elif vit == 'large':
         vision_width = 1024
         visual_encoder = VisionTransformer(
@@ -41,7 +42,8 @@ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop
             num_heads=16,
             use_grad_checkpointing=use_grad_checkpointing,
             ckpt_layer=ckpt_layer,
-            drop_path_rate=0.1 or drop_path_rate
+            drop_path_rate=0.1 or drop_path_rate
+        )
     return visual_encoder, vision_width
 
 
evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py
CHANGED
@@ -53,7 +53,8 @@ class ImageRewardScoreModel(ScoreModel):
         images = self.load_images(images)
         for index in range(len(texts)):
             text_input = self.model.blip.tokenizer(
-                texts[index], padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+                texts[index], padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+            ).to(self.device)
             image_embeds = self.model.blip.visual_encoder(images[index].unsqueeze(0))
             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
             text_output = self.model.blip.text_encoder(