evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic. Click here for more details.
- evalscope/arguments.py +2 -1
- evalscope/benchmarks/__init__.py +2 -2
- evalscope/benchmarks/aigc/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/base.py +56 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
- evalscope/benchmarks/aime/aime24_adapter.py +1 -1
- evalscope/benchmarks/aime/aime25_adapter.py +4 -4
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
- evalscope/benchmarks/arc/arc_adapter.py +1 -1
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
- evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
- evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
- evalscope/benchmarks/data_adapter.py +16 -9
- evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
- evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
- evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
- evalscope/benchmarks/utils.py +7 -16
- evalscope/cli/start_app.py +1 -1
- evalscope/collections/evaluator.py +16 -4
- evalscope/config.py +7 -3
- evalscope/constants.py +11 -0
- evalscope/evaluator/evaluator.py +9 -3
- evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
- evalscope/metrics/__init__.py +49 -4
- evalscope/metrics/llm_judge.py +1 -1
- evalscope/metrics/named_metrics.py +13 -0
- evalscope/metrics/t2v_metrics/__init__.py +66 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +50 -14
- evalscope/models/adapters/__init__.py +17 -0
- evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
- evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
- evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
- evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
- evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
- evalscope/models/adapters/t2i_adapter.py +76 -0
- evalscope/models/custom/__init__.py +2 -1
- evalscope/models/custom/dummy_model.py +11 -13
- evalscope/models/local_model.py +82 -33
- evalscope/models/model.py +2 -42
- evalscope/models/register.py +26 -0
- evalscope/perf/benchmark.py +4 -3
- evalscope/perf/main.py +4 -2
- evalscope/perf/plugin/datasets/flickr8k.py +2 -1
- evalscope/perf/utils/benchmark_util.py +2 -2
- evalscope/perf/utils/db_util.py +16 -8
- evalscope/report/__init__.py +1 -0
- evalscope/report/app.py +117 -67
- evalscope/report/app_arguments.py +11 -0
- evalscope/report/generator.py +1 -1
- evalscope/run.py +3 -3
- evalscope/third_party/thinkbench/eval.py +19 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/import_utils.py +66 -0
- evalscope/utils/utils.py +12 -4
- evalscope/version.py +2 -2
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
- tests/aigc/__init__.py +1 -0
- tests/aigc/test_t2i.py +87 -0
- tests/cli/test_run.py +20 -7
- tests/perf/test_perf.py +6 -3
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
|
|
11
|
+
from ...common.registry import registry
|
|
12
|
+
from ..base_model import tile
|
|
13
|
+
from ..med import XBertEncoder, XBertLMHeadDecoder
|
|
14
|
+
from ..vit import VisionTransformerEncoder
|
|
15
|
+
from .blip import BlipBase
|
|
16
|
+
from .blip_outputs import BlipIntermediateOutput, BlipOutput
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@registry.register_model('blip_vqa')
class BlipVQA(BlipBase):
    """
    BLIP VQA models.

    Supported model types (the keys of ``PRETRAINED_MODEL_CONFIG_DICT``):
        - vqav2: BLIP base model fine-tuned on the VQA v2.0 dataset.
        - okvqa: BLIP VQA model for OK-VQA (per its config file name).
        - aokvqa: BLIP VQA model for A-OKVQA (per its config file name).

    NOTE(review): the YAML files referenced below (configs/models/blip_vqa*.yaml) do not
    appear among the lavis/configs/models files listed for this release; confirm they are
    shipped before relying on load_model("blip_vqa", ...).

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_vqa", "vqav2")
        >>> model = load_model("blip_vqa", "okvqa")
        >>> model = load_model("blip_vqa", "aokvqa")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        'vqav2': 'configs/models/blip_vqav2.yaml',
        'okvqa': 'configs/models/blip_vqa_okvqa.yaml',
        'aokvqa': 'configs/models/blip_vqa_aokvqa.yaml',
    }

    def __init__(self, image_encoder, text_encoder, text_decoder, max_txt_len=35):
        """
        Args:
            image_encoder: Visual backbone; must provide ``forward_features(image)``.
            text_encoder: Multimodal question encoder; must provide ``forward_automask``.
            text_decoder: Answer decoder, called with encoder hidden states and used with ``generate``.
            max_txt_len (int): Maximum number of question tokens kept after truncation.
        """
        super().__init__()
        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = image_encoder

        self.text_encoder = text_encoder
        self.text_decoder = text_decoder

        self.max_txt_len = max_txt_len

    def forward(self, samples):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.
                - text_input (list): A list of strings, each string is a question
                - answer (list): A list of strings, each string is an answer
                - weight (torch.Tensor): A tensor used to weigh each answer in the loss computation.
                   The shape of the tensor is (sum(n_answers),)
                - n_answers (torch.Tensor): A tensor of shape (batch_size,) containing the number of answers
                     for each question in the batch.
                Note: ``samples`` is updated in place with a ``tokenized_text`` entry.

        Returns:
            A BlipOutput object containing loss and intermediate outputs,
            see :class:`lavis.models.blip_outputs.BlipOutput` for more details.

        Examples:
        ```python
            >>> import torch
            >>> from lavis.models import load_model
            >>> model = load_model("blip_vqa")
            >>> samples = {
            ...     "image": torch.rand(2, 3, 480, 480),
            ...     "text_input": ["What is this?", "What is that?"],
            ...     "answer": ["cat", "cat", "dog"],
            ...     "weight": torch.tensor([1.0, 1.0, 1.0]),
            ...     "n_answers": torch.tensor([2, 1]),
            ... }
            >>> output = model(samples)
            >>> output.keys()
            odict_keys(['intermediate_output', 'loss'])
            >>> output.intermediate_output.keys()
            odict_keys(['image_embeds', 'encoder_output', 'decoder_output', 'decoder_labels'])
        ```
        """
        encoder_output, image_embeds = self.forward_encoder(samples)
        loss, decoder_output, decoder_targets = self.forward_decoder(samples=samples, encoder_out=encoder_output)

        return BlipOutput(
            loss=loss,
            intermediate_output=BlipIntermediateOutput(
                image_embeds=image_embeds,
                encoder_output=encoder_output,
                decoder_output=decoder_output,
                decoder_labels=decoder_targets,
            ),
        )

    def forward_encoder(self, samples):
        """
        Encode the image and the questions.

        Tokenizes ``samples['text_input']`` (padded to the longest question, truncated to
        ``max_txt_len``), overwrites the first token with the tokenizer's ``enc_token_id`` and
        stores the result in ``samples['tokenized_text']`` (mutating ``samples``).

        Returns:
            tuple: ``(encoder_output, image_embeds)`` from the text encoder and the visual encoder.
        """
        questions = samples['text_input']
        questions = self.tokenizer(
            questions,
            padding='longest',
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors='pt',
        ).to(self.device)
        questions.input_ids[:, 0] = self.tokenizer.enc_token_id
        samples.update({'tokenized_text': questions})

        image_embeds = self.visual_encoder.forward_features(samples['image'])
        encoder_output = self.text_encoder.forward_automask(
            tokenized_text=samples['tokenized_text'], visual_embeds=image_embeds)

        return encoder_output, image_embeds

    def forward_decoder(self, samples, encoder_out, **kwargs):
        """
        Compute the weighted answer-generation loss.

        Each question ``b`` has ``samples['n_answers'][b]`` answers; its encoder states and
        attention mask are repeated that many times so that row ``i`` lines up with answer ``i``.

        Args:
            samples (dict): Must contain ``answer``, ``n_answers``, ``weight``, ``image`` and the
                ``tokenized_text`` entry written by :meth:`forward_encoder`.
            encoder_out: Output of the text encoder (``last_hidden_state`` is read).
            **kwargs: Ignored.

        Returns:
            tuple: ``(loss, answer_output, answer_targets)`` where ``loss`` is the weighted per-answer
            loss summed and divided by the image batch size.
        """
        answers = self.tokenizer(samples['answer'], padding='longest', return_tensors='pt').to(self.device)
        answers.input_ids[:, 0] = self.tokenizer.bos_token_id
        # Padding positions are excluded from the LM loss through the -100 ignore index.
        answer_targets = answers.input_ids.masked_fill(answers.input_ids == self.tokenizer.pad_token_id, -100)

        question = samples['tokenized_text']
        question_hidden = encoder_out.last_hidden_state

        # Repeat each question's states/mask once per answer (vectorized form of [x] * n + stack).
        n_answers = torch.as_tensor(samples['n_answers'], dtype=torch.long, device=question_hidden.device)
        question_states = question_hidden.repeat_interleave(n_answers, dim=0)
        question_atts = question.attention_mask.repeat_interleave(n_answers.to(question.attention_mask.device), dim=0)

        answer_output = self.text_decoder(
            answers.input_ids,
            attention_mask=answers.attention_mask,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=answer_targets,
            return_dict=True,
            reduction='none',
        )

        loss = samples['weight'] * answer_output.loss
        bsz = samples['image'].size(0)

        loss = loss.sum() / bsz

        return loss, answer_output, answer_targets

    def predict_answers(self,
                        samples,
                        num_beams=3,
                        inference_method='rank',
                        max_len=10,
                        min_len=1,
                        num_ans_candidates=128,
                        answer_list=None,
                        **kwargs):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.
                - text_input (str or [str]): String or a list of strings, each string is a question.
                                             The number of questions must be equal to the batch size. If a single
                                             string, it is converted to a list of one string first (in place).
            num_beams (int): Number of beams for beam search. 1 means no beam search. Used only by "generate".
            inference_method (str): Inference method. One of "rank", "generate".
                - If "rank" (the default), the model returns, for each question, the entry of
                  ``answer_list`` with the highest likelihood. ``answer_list`` is then required.
                - If "generate", the model generates free-form answers.
            max_len (int): Maximum length of generated answers ("generate" only).
            min_len (int): Minimum length of generated answers ("generate" only).
            num_ans_candidates (int): Number of answer candidates kept after scoring the first answer
                token ("rank" only); clipped to ``len(answer_list)``.
            answer_list (list): A list of strings, each string is an answer ("rank" only).

        Returns:
            List: A list of strings, each string is an answer.

        Raises:
            AssertionError: If ``inference_method`` is invalid, the number of questions differs from
                the image batch size, or ``answer_list`` is missing for "rank".

        Examples:
        ```python
            >>> from PIL import Image
            >>> from lavis.models import load_model_and_preprocess
            >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_vqa", "vqav2")
            >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
            >>> question = "Which city is this photo taken?"
            >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
            >>> question = txt_processors["eval"](question)
            >>> samples = {"image": image, "text_input": [question]}
            >>> answers = model.predict_answers(samples, inference_method="generate")
            >>> answers
            ['singapore']
            >>> answer_list = ["Singapore", "London", "Palo Alto", "Tokyo"]
            >>> answers = model.predict_answers(samples, answer_list=answer_list)
            >>> answers
            ['Singapore']
        ```
        """
        assert inference_method in [
            'rank',
            'generate',
        ], "Inference method must be one of 'rank' or 'generate', got {}.".format(inference_method)

        if isinstance(samples['text_input'], str):
            samples['text_input'] = [samples['text_input']]

        assert len(samples['text_input']) == samples['image'].size(
            0), 'The number of questions must be equal to the batch size.'

        if inference_method == 'generate':
            return self._generate_answers(samples, num_beams=num_beams, max_length=max_len, min_length=min_len)
        elif inference_method == 'rank':
            assert answer_list is not None, 'answer_list must be provided for ranking'

            num_ans_candidates = min(num_ans_candidates, len(answer_list))

            return self._rank_answers(samples, answer_list=answer_list, num_ans_candidates=num_ans_candidates)

    def _generate_answers(self, samples, num_beams=3, max_length=10, min_length=1):
        """
        Generate free-form answers with the text decoder, starting from a BOS token.

        Returns:
            list[str]: One decoded answer per output sequence of ``text_decoder.generate``.
        """
        encoder_out, _ = self.forward_encoder(samples)

        # Expand each question num_beams times along the batch axis.
        question_states = encoder_out.last_hidden_state.repeat_interleave(num_beams, dim=0)
        # Use the real question padding mask (as forward_decoder and _rank_answers do) so padded
        # positions of shorter questions in a batch are not attended by the decoder's cross-attention.
        question_atts = samples['tokenized_text'].attention_mask.repeat_interleave(num_beams, dim=0)

        model_kwargs = {
            'encoder_hidden_states': question_states,
            'encoder_attention_mask': question_atts,
        }

        bsz = samples['image'].size(0)
        bos_ids = torch.full((bsz, 1), fill_value=self.tokenizer.bos_token_id, device=self.device)

        outputs = self.text_decoder.generate(
            input_ids=bos_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            **model_kwargs)

        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

    def _rank_answers(self, samples, answer_list, num_ans_candidates):
        """
        Generate the first token of answers using decoder and select ${num_ans_candidates}
        most probable ones. Then select answers from answer list, which start with the probable tokens.
        Lastly, use the selected answers as the ground-truth labels for decoding and calculating LM loss.
        Return the answers that minimize the losses as result.

        Returns:
            list[str]: One entry of ``answer_list`` per question.
        """
        answer_candidates = self.tokenizer(answer_list, padding='longest', return_tensors='pt').to(self.device)
        answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id

        answer_ids = answer_candidates.input_ids
        answer_atts = answer_candidates.attention_mask

        question_output, _ = self.forward_encoder(samples)
        question_states = question_output.last_hidden_state

        tokenized_question = samples['tokenized_text']
        question_atts = tokenized_question.attention_mask

        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(
            start_ids,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            return_dict=True,
            reduction='none',
        )
        logits = start_output.logits[:, 0, :]  # first token's logit

        # Probability of each candidate's first real token (column 0 is BOS).
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
        # topk_ids: [num_ques, k], indices into answer_list
        _, topk_ids = prob_first_token.topk(num_ans_candidates, dim=1)

        # Candidate answers in question-major order: row b * k + j is answer topk_ids[b, j].
        flat_topk_ids = topk_ids.reshape(-1)
        input_ids = answer_ids.index_select(dim=0, index=flat_topk_ids)
        input_atts = answer_atts.index_select(dim=0, index=flat_topk_ids)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, num_ans_candidates)
        question_atts = tile(question_atts, 0, num_ans_candidates)

        output = self.text_decoder(
            input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=targets_ids,
            return_dict=True,
            reduction='none',
        )

        # Assumes output.loss is one loss value per candidate sequence when reduction='none'
        # (the .view below requires num_ques * k elements); lower loss = more likely answer.
        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques, num_ans_candidates)

        best_candidate = log_probs_sum.argmax(dim=1)
        # For each question b, pick topk_ids[b, best_candidate[b]].
        max_ids = topk_ids.gather(1, best_candidate.unsqueeze(1)).squeeze(1)

        return [answer_list[max_id] for max_id in max_ids.tolist()]

    @classmethod
    def from_config(cls, cfg=None):
        """
        Build the model from a config: a vision transformer encoder, an XBert text/multimodal
        encoder and an XBert LM-head decoder, then load the checkpoint referenced by ``cfg``.

        ``max_txt_len`` is read from ``cfg`` (default 35).
        """
        image_encoder = VisionTransformerEncoder.from_config(cfg)

        # text encoder + multimodal encoder
        text_encoder = XBertEncoder.from_config(cfg)
        text_decoder = XBertLMHeadDecoder.from_config(cfg)

        max_txt_len = cfg.get('max_txt_len', 35)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            text_decoder=text_decoder,
            max_txt_len=max_txt_len,
        )

        model.load_checkpoint_from_config(cfg)

        return model
|