evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic.
- evalscope/__init__.py +4 -1
- evalscope/api/benchmark/__init__.py +3 -0
- evalscope/api/benchmark/adapters/__init__.py +5 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
- evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
- evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
- evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
- evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
- evalscope/api/benchmark/benchmark.py +356 -0
- evalscope/api/benchmark/meta.py +121 -0
- evalscope/api/dataset/__init__.py +2 -0
- evalscope/api/dataset/dataset.py +349 -0
- evalscope/api/dataset/loader.py +262 -0
- evalscope/api/dataset/utils.py +143 -0
- evalscope/api/evaluator/__init__.py +3 -0
- evalscope/api/evaluator/cache.py +378 -0
- evalscope/api/evaluator/evaluator.py +56 -0
- evalscope/api/evaluator/state.py +275 -0
- evalscope/api/filter/__init__.py +1 -0
- evalscope/api/filter/filter.py +72 -0
- evalscope/api/messages/__init__.py +12 -0
- evalscope/api/messages/chat_message.py +243 -0
- evalscope/api/messages/content.py +102 -0
- evalscope/api/messages/utils.py +35 -0
- evalscope/api/metric/__init__.py +2 -0
- evalscope/api/metric/metric.py +55 -0
- evalscope/api/metric/scorer.py +113 -0
- evalscope/api/mixin/__init__.py +1 -0
- evalscope/api/mixin/llm_judge_mixin.py +168 -0
- evalscope/api/model/__init__.py +12 -0
- evalscope/api/model/generate_config.py +155 -0
- evalscope/api/model/model.py +386 -0
- evalscope/api/model/model_output.py +285 -0
- evalscope/api/registry.py +182 -0
- evalscope/api/tool/__init__.py +3 -0
- evalscope/api/tool/tool_call.py +101 -0
- evalscope/api/tool/tool_info.py +173 -0
- evalscope/api/tool/utils.py +64 -0
- evalscope/app/app.py +3 -0
- evalscope/app/ui/app_ui.py +2 -1
- evalscope/app/ui/multi_model.py +50 -25
- evalscope/app/ui/single_model.py +26 -14
- evalscope/app/utils/data_utils.py +43 -27
- evalscope/app/utils/env_utils.py +12 -0
- evalscope/app/utils/text_utils.py +14 -14
- evalscope/app/utils/visualization.py +9 -4
- evalscope/arguments.py +7 -10
- evalscope/backend/opencompass/api_meta_template.py +2 -1
- evalscope/backend/opencompass/backend_manager.py +6 -5
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
- evalscope/backend/rag_eval/ragas/task_template.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
- evalscope/backend/rag_eval/utils/embedding.py +10 -1
- evalscope/backend/rag_eval/utils/llm.py +13 -12
- evalscope/benchmarks/__init__.py +0 -2
- evalscope/benchmarks/aime/aime24_adapter.py +38 -40
- evalscope/benchmarks/aime/aime25_adapter.py +34 -40
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
- evalscope/benchmarks/arc/arc_adapter.py +34 -147
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
- evalscope/benchmarks/arena_hard/utils.py +37 -1
- evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
- evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
- evalscope/benchmarks/bfcl/generation.py +222 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
- evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
- evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
- evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
- evalscope/benchmarks/docmath/utils.py +4 -5
- evalscope/benchmarks/drop/drop_adapter.py +88 -40
- evalscope/benchmarks/frames/frames_adapter.py +136 -52
- evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
- evalscope/benchmarks/general_arena/utils.py +23 -27
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
- evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
- evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
- evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
- evalscope/benchmarks/hle/hle_adapter.py +127 -93
- evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
- evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
- evalscope/benchmarks/ifeval/instructions.py +109 -64
- evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
- evalscope/benchmarks/ifeval/instructions_util.py +2 -3
- evalscope/benchmarks/ifeval/utils.py +6 -7
- evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
- evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
- evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
- evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
- evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
- evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
- evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
- evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
- evalscope/benchmarks/math_vista/__init__.py +0 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
- evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
- evalscope/benchmarks/mmmu/__init__.py +0 -0
- evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
- evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
- evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
- evalscope/benchmarks/musr/musr_adapter.py +33 -64
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
- evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
- evalscope/benchmarks/race/race_adapter.py +33 -119
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
- evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
- evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
- evalscope/benchmarks/super_gpqa/utils.py +2 -1
- evalscope/benchmarks/tau_bench/generation.py +147 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
- evalscope/benchmarks/text2image/__init__.py +0 -0
- evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
- evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
- evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
- evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
- evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
- evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
- evalscope/cli/cli.py +2 -0
- evalscope/cli/start_app.py +7 -1
- evalscope/cli/start_perf.py +7 -1
- evalscope/cli/start_server.py +6 -3
- evalscope/collections/__init__.py +2 -10
- evalscope/collections/sampler.py +10 -10
- evalscope/collections/schema.py +13 -11
- evalscope/config.py +157 -57
- evalscope/constants.py +37 -61
- evalscope/evaluator/__init__.py +1 -1
- evalscope/evaluator/evaluator.py +275 -419
- evalscope/filters/__init__.py +2 -0
- evalscope/filters/extraction.py +126 -0
- evalscope/filters/selection.py +57 -0
- evalscope/metrics/__init__.py +13 -13
- evalscope/metrics/llm_judge.py +47 -33
- evalscope/metrics/math_parser.py +27 -22
- evalscope/metrics/metric.py +307 -0
- evalscope/metrics/metrics.py +22 -18
- evalscope/metrics/t2v_metrics/__init__.py +0 -52
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
- evalscope/models/__init__.py +6 -29
- evalscope/models/image_edit_model.py +125 -0
- evalscope/models/mockllm.py +65 -0
- evalscope/models/model_apis.py +67 -0
- evalscope/models/modelscope.py +455 -0
- evalscope/models/openai_compatible.py +126 -0
- evalscope/models/text2image_model.py +124 -0
- evalscope/models/utils/openai.py +701 -0
- evalscope/perf/benchmark.py +4 -1
- evalscope/perf/http_client.py +4 -2
- evalscope/perf/plugin/api/custom_api.py +5 -4
- evalscope/perf/plugin/api/openai_api.py +11 -9
- evalscope/perf/plugin/datasets/custom.py +2 -1
- evalscope/perf/plugin/datasets/flickr8k.py +1 -1
- evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
- evalscope/perf/plugin/datasets/line_by_line.py +2 -1
- evalscope/perf/plugin/datasets/longalpaca.py +2 -1
- evalscope/perf/plugin/datasets/openqa.py +4 -2
- evalscope/perf/utils/benchmark_util.py +15 -10
- evalscope/perf/utils/db_util.py +9 -6
- evalscope/perf/utils/local_server.py +11 -3
- evalscope/perf/utils/rich_display.py +16 -10
- evalscope/report/__init__.py +2 -3
- evalscope/report/combinator.py +18 -12
- evalscope/report/generator.py +51 -35
- evalscope/report/{utils.py → report.py} +8 -6
- evalscope/run.py +33 -47
- evalscope/summarizer.py +1 -1
- evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
- evalscope/utils/__init__.py +21 -2
- evalscope/utils/chat_service.py +3 -2
- evalscope/utils/deprecation_utils.py +12 -1
- evalscope/utils/function_utils.py +29 -0
- evalscope/utils/import_utils.py +23 -1
- evalscope/utils/io_utils.py +142 -6
- evalscope/utils/json_schema.py +208 -0
- evalscope/utils/logger.py +51 -12
- evalscope/utils/model_utils.py +11 -7
- evalscope/utils/multi_choices.py +288 -0
- evalscope/utils/url_utils.py +65 -0
- evalscope/version.py +2 -2
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
- tests/benchmark/test_eval.py +385 -0
- tests/benchmark/test_image_edit.py +65 -0
- tests/{aigc → benchmark}/test_t2i.py +22 -4
- tests/benchmark/test_vlm.py +80 -0
- tests/cli/test_all.py +85 -47
- tests/cli/test_collection.py +20 -8
- tests/cli/test_custom.py +22 -15
- tests/cli/test_reasoning.py +81 -0
- tests/common.py +73 -0
- tests/perf/test_perf.py +4 -2
- tests/rag/test_clip_benchmark.py +0 -2
- evalscope/benchmarks/aigc/t2i/base.py +0 -56
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
- evalscope/benchmarks/arc/ai2_arc.py +0 -151
- evalscope/benchmarks/benchmark.py +0 -81
- evalscope/benchmarks/ceval/ceval_exam.py +0 -146
- evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
- evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
- evalscope/benchmarks/competition_math/competition_math.py +0 -79
- evalscope/benchmarks/data_adapter.py +0 -528
- evalscope/benchmarks/filters.py +0 -59
- evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
- evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
- evalscope/benchmarks/humaneval/humaneval.py +0 -79
- evalscope/benchmarks/mmlu/mmlu.py +0 -160
- evalscope/benchmarks/mmlu/samples.jsonl +0 -5
- evalscope/benchmarks/process_bench/critique_template.txt +0 -13
- evalscope/benchmarks/race/race.py +0 -104
- evalscope/benchmarks/race/samples.jsonl +0 -5
- evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
- evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
- evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
- evalscope/benchmarks/utils.py +0 -60
- evalscope/collections/evaluator.py +0 -375
- evalscope/metrics/completion_parsers.py +0 -227
- evalscope/metrics/named_metrics.py +0 -55
- evalscope/models/adapters/__init__.py +0 -14
- evalscope/models/adapters/base_adapter.py +0 -84
- evalscope/models/adapters/bfcl_adapter.py +0 -246
- evalscope/models/adapters/chat_adapter.py +0 -207
- evalscope/models/adapters/choice_adapter.py +0 -222
- evalscope/models/adapters/custom_adapter.py +0 -71
- evalscope/models/adapters/server_adapter.py +0 -236
- evalscope/models/adapters/t2i_adapter.py +0 -79
- evalscope/models/adapters/tau_bench_adapter.py +0 -189
- evalscope/models/custom/__init__.py +0 -4
- evalscope/models/custom/custom_model.py +0 -50
- evalscope/models/custom/dummy_model.py +0 -99
- evalscope/models/local_model.py +0 -128
- evalscope/models/register.py +0 -41
- tests/cli/test_run.py +0 -489
- /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
- /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
- /tests/{aigc → benchmark}/__init__.py +0 -0
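The bulk of the added files above form the new `evalscope.api` package (benchmark adapters, dataset loaders, evaluator state and cache, model/tool schemas, and a registry), which replaces the removed `benchmarks/data_adapter.py`, `benchmarks/benchmark.py`, and `models/adapters/*` modules. As a rough orientation only, a smoke test of the 1.0 layout might look like the sketch below; `run_task` and `TaskConfig` follow the project's documented quickstart, while the model id, dataset names, and sample limit are placeholders.

```python
# Minimal sketch, assuming the documented evalscope quickstart entry points
# (run_task / TaskConfig) are unchanged in 1.0.x; model id and datasets are placeholders.
from evalscope import TaskConfig, run_task

task_cfg = TaskConfig(
    model='Qwen/Qwen2.5-0.5B-Instruct',  # placeholder model identifier
    datasets=['gsm8k', 'arc'],           # benchmarks presumably resolved via the new evalscope.api registry
    limit=5,                             # evaluate only a handful of samples as a smoke test
)

run_task(task_cfg=task_cfg)
```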
evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py
CHANGED
(Hunks for this file and for the neighbouring t2v_metrics modules listed above — clip_encoder.py, clip_t5_model.py, gpt4v_model.py, and the lavis common, model and Qformer files — are grouped here; each @@ header names the class or function being changed. Removed lines whose full text is not shown are marked with `...`.)

@@ -58,8 +58,9 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
     def get_model(self):
         return self  # for compatibility with LlavaMetaForCausalLM

-    def prepare_inputs_labels_for_multimodal(
-        ...
+    def prepare_inputs_labels_for_multimodal(
+        self, input_ids, attention_mask, decoder_attention_mask, past_key_values, labels, images
+    ):
         # The labels are now separated from the input_ids.
         vision_tower = self.get_vision_tower()
         if vision_tower is None or images is None or input_ids.shape[1] == 1:

@@ -103,10 +104,12 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
         _input_embeds_lengths = []
         for cur_new_embed in new_input_embeds:
             _input_embeds_lengths.append(cur_new_embed.shape[0])
-            cur_new_embed = torch.cat((
-                ...
+            cur_new_embed = torch.cat((
+                cur_new_embed,
+                torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
+                            dtype=cur_new_embed.dtype,
+                            device=cur_new_embed.device)
+            ),
                 dim=0)
             new_input_embeds_align.append(cur_new_embed)
         new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

@@ -123,7 +126,8 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
                     dtype=attention_mask.dtype,
                     device=attention_mask.device)
                 cur_new_attention_mask = torch.cat(
-                    (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0
+                    (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0
+                )
                 new_attention_mask.append(cur_new_attention_mask)
             attention_mask = torch.stack(new_attention_mask, dim=0)
             assert attention_mask.shape == new_input_embeds.shape[:2]

@@ -135,7 +139,8 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
                 (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]),
                 True,
                 dtype=attention_mask.dtype,
-                device=attention_mask.device
+                device=attention_mask.device
+            )
             attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
             assert attention_mask.shape == new_input_embeds.shape[:2]

@@ -204,7 +209,8 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if inputs_embeds is None:

@@ -44,12 +44,14 @@ class CLIPVisionTower(nn.Module):
             image_features = []
             for image in images:
                 image_forward_out = self.vision_tower(
-                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
+                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
+                )
                 image_feature = self.feature_select(image_forward_out).to(image.dtype)
                 image_features.append(image_feature)
         else:
             image_forward_outs = self.vision_tower(
-                images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
+                images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
+            )
             image_features = self.feature_select(image_forward_outs).to(images.dtype)

         return image_features

@@ -98,7 +98,8 @@ class CLIPT5Model(VQAScoreModel):
             mmprojector_repo=mmprojector_repo,
             mmprojector_name=mmprojector_name,
             device=self.device,
-            cache_dir=self.cache_dir
+            cache_dir=self.cache_dir
+        )

     def load_images(self, image: List[str]) -> torch.Tensor:
         """Load the image(s), and return a tensor (after preprocessing) put on self.device

@@ -115,11 +116,13 @@ class CLIPT5Model(VQAScoreModel):

     @torch.no_grad()
     @torch.autocast(device_type='cuda', dtype=torch.bfloat16)
-    def forward(
-        ...
+    def forward(
+        self,
+        images: List[str],
+        texts: List[str],
+        question_template: str = default_question_template,
+        answer_template: str = default_answer_template
+    ) -> torch.Tensor:
         """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
         """
         assert len(images) == len(texts), 'Number of images and texts must match'

@@ -139,7 +142,8 @@ class CLIPT5Model(VQAScoreModel):
         labels = [t5_tokenizer_image_token(ans, self.tokenizer, return_tensors='pt') for ans in answers]

         input_ids = torch.nn.utils.rnn.pad_sequence(
-            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+        )
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
         input_ids = input_ids[:, :self.tokenizer.model_max_length]
         labels = labels[:, :self.tokenizer.model_max_length]

@@ -169,8 +173,8 @@ class CLIPT5Model(VQAScoreModel):
         lm_prob = torch.zeros(logits.shape[0])
         loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
         for k in range(lm_prob.shape[0]):
-            lm_prob[k] = (
-                ...
+            lm_prob[k] = (-loss_fct(logits[k],
+                                    labels[k])).exp()  # exp to cancel the log and get raw prob between 0 and 1
         return lm_prob

     @torch.no_grad()

@@ -191,7 +195,8 @@ class CLIPT5Model(VQAScoreModel):

         input_ids = [t5_tokenizer_image_token(qs, self.tokenizer, return_tensors='pt') for qs in questions]
         input_ids = torch.nn.utils.rnn.pad_sequence(
-            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+        )
         input_ids = input_ids[:, :self.tokenizer.model_max_length]

         attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

@@ -1,6 +1,5 @@
 import base64
 import os
-import tiktoken
 import torch
 from openai import OpenAI
 from typing import List

@@ -42,6 +41,8 @@ class GPT4VModel(VQAScoreModel):
     def load_model(self):
         """Load the model, tokenizer, image transform
         """
+        import tiktoken
+
         self.tokenizer = tiktoken.encoding_for_model(self.model_name)
         self.client = OpenAI(api_key=self.openai_key)
         # self.candidate_answers = GPT4V_MODELS[self.model_name]['candidate_answers']

@@ -122,11 +123,13 @@ class GPT4VModel(VQAScoreModel):
             print(completion.choices[0].logprobs.content[0].top_logprobs)
             return torch.Tensor([0.0])

-    def forward(
-        ...
+    def forward(
+        self,
+        images: List[str],
+        texts: List[str],
+        question_template: str = default_question_template,
+        answer_template: str = default_answer_template
+    ) -> torch.Tensor:
         """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
         """
         assert len(images) == len(texts), 'Number of images and texts must match'

@@ -227,8 +227,8 @@ class ConfigValidator:
         """
         for k, v in config.items():
             assert (
-                k
-                ...
+                k in self.arguments
+            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""

             if self.arguments[k].type is not None:
                 try:

@@ -17,6 +17,8 @@ def getAttMap(img, attMap, blur=True, overlap=True):
     attMapV = cmap(attMap)
     attMapV = np.delete(attMapV, 3, 2)
     if overlap:
-        attMap = (
-            ...
+        attMap = (
+            1 * (1 - attMap**0.7).reshape(attMap.shape + (1, )) * img +
+            (attMap**0.7).reshape(attMap.shape + (1, )) * attMapV
+        )
     return attMap

@@ -155,7 +155,8 @@ class MetricLogger(object):
                         time=str(iter_time),
                         data=str(data_time),
                         memory=torch.cuda.max_memory_allocated() / MB,
-                    )
+                    )
+                )
             else:
                 print(
                     log_msg.format(

@@ -165,7 +166,8 @@ class MetricLogger(object):
                         meters=str(self),
                         time=str(iter_time),
                         data=str(data_time),
-                    )
+                    )
+                )
             i += 1
             end = time.time()
         total_time = time.time() - start_time

@@ -13,15 +13,9 @@ from . import registry
 @registry.register_lr_scheduler('linear_warmup_step_lr')
 class LinearWarmupStepLRScheduler:

-    def __init__(
-        ...
-        min_lr,
-        init_lr,
-        decay_rate=1,
-        warmup_start_lr=-1,
-        warmup_steps=0,
-        **kwargs):
+    def __init__(
+        self, optimizer, max_epoch, min_lr, init_lr, decay_rate=1, warmup_start_lr=-1, warmup_steps=0, **kwargs
+    ):
         self.optimizer = optimizer

         self.max_epoch = max_epoch

@@ -96,8 +96,9 @@ class Registry:

         assert issubclass(model_cls, BaseModel), 'All models must inherit BaseModel class'
         if name in cls.mapping['model_name_mapping']:
-            raise KeyError(
-                ...
+            raise KeyError(
+                "Name '{}' already registered for {}.".format(name, cls.mapping['model_name_mapping'][name])
+            )
         cls.mapping['model_name_mapping'][name] = model_cls
         return model_cls

@@ -120,8 +121,9 @@ class Registry:

         assert issubclass(processor_cls, BaseProcessor), 'All processors must inherit BaseProcessor class'
         if name in cls.mapping['processor_name_mapping']:
-            raise KeyError(
-                name, cls.mapping['processor_name_mapping'][name])
+            raise KeyError(
+                "Name '{}' already registered for {}.".format(name, cls.mapping['processor_name_mapping'][name])
+            )
         cls.mapping['processor_name_mapping'][name] = processor_cls
         return processor_cls

@@ -141,8 +143,9 @@ class Registry:

         def wrap(lr_sched_cls):
             if name in cls.mapping['lr_scheduler_name_mapping']:
-                raise KeyError(
-                    name, cls.mapping['lr_scheduler_name_mapping'][name])
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(name, cls.mapping['lr_scheduler_name_mapping'][name])
+                )
             cls.mapping['lr_scheduler_name_mapping'][name] = lr_sched_cls
             return lr_sched_cls

@@ -162,8 +165,9 @@ class Registry:

         def wrap(runner_cls):
             if name in cls.mapping['runner_name_mapping']:
-                raise KeyError(
-                    ...
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(name, cls.mapping['runner_name_mapping'][name])
+                )
             cls.mapping['runner_name_mapping'][name] = runner_cls
             return runner_cls

@@ -285,8 +289,10 @@ class Registry:
                 break

         if ('writer' in cls.mapping['state'] and value == default and no_warning is False):
-            cls.mapping['state']['writer'].warning(
-                ...
+            cls.mapping['state']['writer'].warning(
+                'Key {} is not present in registry, returning default value '
+                'of {}'.format(original_name, default)
+            )
         return value

     @classmethod

@@ -178,8 +178,9 @@ class VQA:
         for ann in anns:
             quesId = ann['question_id']
             if res.dataset['task_type'] == 'Multiple Choice':
-                assert (
-                    ...
+                assert (
+                    ann['answer'] in self.qqa[quesId]['multiple_choices']
+                ), 'predicted answer is not one of the multiple choices'
             qaAnn = self.qa[quesId]
             ann['image_id'] = qaAnn['image_id']
             ann['question_type'] = qaAnn['question_type']

@@ -10,6 +10,7 @@
 __author__ = 'aagrawal'

 import re
+
 # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
 # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
 import sys

@@ -312,7 +313,8 @@ class VQAEval:
             progress = 1
             status = 'Done...\r\n'
         block = int(round(barLength * progress))
-        text = '\rFinshed Percent: [{0}] {1}% {2}'.format(
-            ...
+        text = '\rFinshed Percent: [{0}] {1}% {2}'.format(
+            '#' * block + '-' * (barLength - block), int(progress * 100), status
+        )
         sys.stdout.write(text)
         sys.stdout.flush()

@@ -166,10 +166,12 @@ def load_model_and_preprocess(name, model_type, is_eval=False, device='cpu'):
         vis_processors, txt_processors = load_preprocess(preprocess_cfg)
     else:
         vis_processors, txt_processors = None, None
-        logging.info(
+        logging.info(
+            f"""No default preprocess for model {name} ({model_type}).
                 This can happen if the model is not finetuned on downstream datasets,
                 or it is not intended for direct use without finetuning.
-            """
+            """
+        )

     if device == 'cpu' or device == torch.device('cpu'):
         model = model.float()

@@ -195,8 +197,10 @@ class ModelZoo:
     }

     def __str__(self) -> str:
-        return (
-            ...
+        return (
+            '=' * 50 + '\n' + f"{'Architectures':<30} {'Types'}\n" + '=' * 50 + '\n'
+            + '\n'.join([f"{name:<30} {', '.join(types)}" for name, types in self.model_zoo.items()])
+        )

     def __iter__(self):
         return iter(self.model_zoo.items())

@@ -19,13 +19,23 @@ from torch import Tensor, device, dtype, nn
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
 from transformers.file_utils import ModelOutput
-from transformers.modeling_outputs import (
-    ...
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    NextSentencePredictorOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
 from transformers.models.bert.configuration_bert import BertConfig
 from transformers.utils import logging
 from typing import Any, Dict, Optional, Tuple

@@ -89,8 +99,10 @@ class BertSelfAttention(nn.Module):
         super().__init__()
         self.config = config
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, 'embedding_size'):
-            raise ValueError(
-                ...
+            raise ValueError(
+                'The hidden size (%d) is not a multiple of the number of attention '
+                'heads (%d)' % (config.hidden_size, config.num_attention_heads)
+            )

         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)

@@ -366,8 +378,9 @@ class BertLayer(nn.Module):
             query_attention_output = attention_output[:, :query_length, :]

             if self.has_cross_attention:
-                assert (
-                    ...
+                assert (
+                    encoder_hidden_states is not None
+                ), 'encoder_hidden_states must be given for cross-attention layers'
                 cross_attention_outputs = self.crossattention(
                     query_attention_output,
                     attention_mask,

@@ -377,8 +390,9 @@ class BertLayer(nn.Module):
                     output_attentions=output_attentions,
                 )
                 query_attention_output = cross_attention_outputs[0]
-                outputs = (
-                    ...
+                outputs = (
+                    outputs + cross_attention_outputs[1:-1]
+                )  # add cross attentions if we output attention weights

             layer_output = apply_chunking_to_forward(
                 self.feed_forward_chunk_query,

@@ -457,7 +471,8 @@ class BertEncoder(nn.Module):

                 if use_cache:
                     logger.warn(
-                        '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+                        '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+                    )
                     use_cache = False

                 def create_custom_forward(module):

@@ -498,13 +513,15 @@ class BertEncoder(nn.Module):
             all_hidden_states = all_hidden_states + (hidden_states, )

         if not return_dict:
-            return tuple(
-                ...
+            return tuple(
+                v for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ] if v is not None
+            )
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,

@@ -708,8 +725,11 @@ class BertModel(BertPreTrainedModel):
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
         else:
-            raise ValueError(
-                ...
+            raise ValueError(
+                'Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
+                    input_shape, attention_mask.shape
+                )
+            )

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for

@@ -756,7 +776,8 @@ class BertModel(BertPreTrainedModel):
         """
         output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

         # use_cache = use_cache if use_cache is not None else self.config.use_cache

@@ -766,7 +787,8 @@ class BertModel(BertPreTrainedModel):

         # past_key_values_length
         past_key_values_length = (
-            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
+            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
+        )

         query_length = query_embeds.shape[1] if query_embeds is not None else 0

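Nearly every hunk above is a formatting-only re-wrap (closing parentheses moved to their own line, long signatures split or joined, imports expanded). The one spot worth reading twice is the score computation in `CLIPT5Model.forward`, where the per-pair probability is recovered as `exp(-cross_entropy)` over the answer tokens, i.e. the geometric mean of the token probabilities. A self-contained sketch of that arithmetic with made-up tensors:

```python
import torch

# Sketch of the VQAScore-style probability used in CLIPT5Model.forward:
# mean token-level cross-entropy of the answer, negated and exponentiated,
# yields a score in (0, 1]. Shapes and values below are dummies.
logits = torch.randn(4, 7, 32000)          # (batch, answer_len, vocab)
labels = torch.randint(0, 32000, (4, 7))   # dummy answer token ids

loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
lm_prob = torch.zeros(logits.shape[0])
for k in range(lm_prob.shape[0]):
    # CrossEntropyLoss expects (num_tokens, vocab) logits against (num_tokens,) targets
    lm_prob[k] = (-loss_fct(logits[k], labels[k])).exp()

print(lm_prob)  # one probability-like score per (image, text) pair
```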
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py
CHANGED
(Hunks for blip2_qformer.py and for the neighbouring blip2_t5.py are grouped here; each @@ header names the class being changed. Removed lines whose full text is not shown are marked with `...`.)

@@ -54,16 +54,18 @@ class Blip2Qformer(Blip2Base):

         self.tokenizer = self.init_tokenizer()

-        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
-            ...
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+        )
         if freeze_vit:
             for name, param in self.visual_encoder.named_parameters():
                 param.requires_grad = False
             self.visual_encoder = self.visual_encoder.eval()
             self.visual_encoder.train = disabled_train
             logging.info('freeze vision encoder')
-        self.Qformer, self.query_tokens = self.init_Qformer(
-            ...
+        self.Qformer, self.query_tokens = self.init_Qformer(
+            num_query_token, self.visual_encoder.num_features, cross_attention_freq
+        )
         self.Qformer.resize_token_embeddings(len(self.tokenizer))
         state_dict = self.Qformer.state_dict()
         for name, param in self.Qformer.named_parameters():

@@ -135,8 +137,10 @@ class Blip2Qformer(Blip2Base):
         bs = image.size(0)
         targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(image.device)

-        loss_itc = (
-            ...
+        loss_itc = (
+            F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
+            + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
+        ) / 2

         ###============== Image-text Matching ===================###
         text_input_ids_world = concat_all_gather(text_tokens.input_ids)

@@ -274,7 +278,8 @@ class Blip2Qformer(Blip2Base):
             top_p=top_p,
             eos_token_id=self.tokenizer.sep_token_id,
             pad_token_id=self.tokenizer.pad_token_id,
-            **model_kwargs
+            **model_kwargs
+        )
         captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
         return captions

@@ -66,8 +66,9 @@ class Blip2T5(Blip2Base):

         self.tokenizer = self.init_tokenizer()

-        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
-            ...
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+        )
         if freeze_vit:
             for name, param in self.visual_encoder.named_parameters():
                 param.requires_grad = False

@@ -136,8 +137,9 @@ class Blip2T5(Blip2Base):

         encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

-        targets = output_tokens.input_ids.masked_fill(
-            ...
+        targets = output_tokens.input_ids.masked_fill(
+            output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100
+        )

         inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
         inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

@@ -234,17 +236,19 @@ class Blip2T5(Blip2Base):

         return output_text

-    def predict_answers(
-        ...
+    def predict_answers(
+        self,
+        samples,
+        num_beams=5,
+        inference_method='generate',
+        max_len=10,
+        min_len=1,
+        num_ans_candidates=128,
+        answer_list=None,
+        prompt='',
+        length_penalty=-1,
+        **kwargs
+    ):
         image = samples['image']
         with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image))

@@ -318,13 +322,15 @@ class Blip2T5(Blip2Base):

             self._lemmatizer = spacy.load('en_core_web_sm')
         except ImportError:
-            logging.error(
+            logging.error(
+                """
                     Please install spacy and en_core_web_sm model to apply lemmatization.
                     python -m spacy download en_core_web_sm
                     OR
                     import spacy.cli
                     spacy.cli.download("en_core_web_sm")
-                    """
+                    """
+            )
             exit(1)

         return self._lemmatizer
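These hunks are the same style-only re-wrapping. The `masked_fill` in `Blip2T5` is worth a note: it swaps padding token ids for -100, the label value that Hugging Face's default cross-entropy loss ignores, so padded positions do not contribute to the T5 loss. A small illustration with made-up token ids:

```python
import torch

# Illustration of the Blip2T5 target construction: padding positions are replaced
# with -100 so the (T5) cross-entropy loss skips them. Token ids below are made up.
pad_token_id = 0
output_ids = torch.tensor([[37, 52, 1, 0, 0],
                           [14,  9, 8, 1, 0]])

targets = output_ids.masked_fill(output_ids == pad_token_id, -100)
print(targets)
# tensor([[  37,   52,    1, -100, -100],
#         [  14,    9,    8,    1, -100]])
```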