evalscope 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of evalscope might be problematic.

Files changed (178)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  24. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  25. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  26. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  27. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  28. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  29. evalscope/benchmarks/utils.py +7 -16
  30. evalscope/cli/start_app.py +1 -1
  31. evalscope/collections/evaluator.py +16 -4
  32. evalscope/config.py +7 -3
  33. evalscope/constants.py +11 -0
  34. evalscope/evaluator/evaluator.py +2 -2
  35. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  36. evalscope/metrics/__init__.py +49 -4
  37. evalscope/metrics/llm_judge.py +1 -1
  38. evalscope/metrics/named_metrics.py +13 -0
  39. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  40. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  41. evalscope/metrics/t2v_metrics/constants.py +12 -0
  42. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  43. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  44. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  53. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  60. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  61. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  62. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  138. evalscope/metrics/t2v_metrics/score.py +78 -0
  139. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  140. evalscope/models/__init__.py +50 -14
  141. evalscope/models/adapters/__init__.py +17 -0
  142. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  143. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  144. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  145. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  146. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  147. evalscope/models/adapters/t2i_adapter.py +76 -0
  148. evalscope/models/custom/__init__.py +2 -1
  149. evalscope/models/custom/dummy_model.py +11 -13
  150. evalscope/models/local_model.py +82 -33
  151. evalscope/models/model.py +2 -42
  152. evalscope/models/register.py +26 -0
  153. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  154. evalscope/perf/utils/benchmark_util.py +2 -2
  155. evalscope/perf/utils/db_util.py +8 -2
  156. evalscope/report/__init__.py +1 -0
  157. evalscope/report/app.py +117 -67
  158. evalscope/report/app_arguments.py +11 -0
  159. evalscope/report/generator.py +1 -1
  160. evalscope/run.py +3 -3
  161. evalscope/third_party/thinkbench/eval.py +19 -7
  162. evalscope/utils/chat_service.py +2 -2
  163. evalscope/utils/import_utils.py +66 -0
  164. evalscope/utils/utils.py +12 -4
  165. evalscope/version.py +2 -2
  166. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/METADATA +18 -1
  167. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/RECORD +175 -63
  168. tests/aigc/__init__.py +1 -0
  169. tests/aigc/test_t2i.py +87 -0
  170. tests/cli/test_run.py +11 -5
  171. tests/perf/test_perf.py +2 -1
  172. evalscope/metrics/code_metric.py +0 -98
  173. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  174. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  175. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/LICENSE +0 -0
  176. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/WHEEL +0 -0
  177. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/entry_points.txt +0 -0
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py
@@ -0,0 +1,344 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import torch
+ import torch.nn.functional as F
+
+ from ...common.registry import registry
+ from ..base_model import tile
+ from ..med import XBertEncoder, XBertLMHeadDecoder
+ from ..vit import VisionTransformerEncoder
+ from .blip import BlipBase
+ from .blip_outputs import BlipIntermediateOutput, BlipOutput
+
+
+ @registry.register_model('blip_vqa')
+ class BlipVQA(BlipBase):
+     """
+     BLIP VQA models.
+
+     Supported model types:
+         - base: vqa model initialized with pre-trained BLIP base model on 115M image-text pairs after CapFilt; not fine-tuned.
+         - vqav2: fine-tuned BLIP base model on VQA v2.0 dataset.
+
+     Usage:
+         >>> from lavis.models import load_model
+         >>> model = load_model("blip_vqa", "vqav2")
+         >>> model = load_model("blip_vqa", "okvqa")
+         >>> model = load_model("blip_vqa", "aokvqa")
+     """
+
+     PRETRAINED_MODEL_CONFIG_DICT = {
+         'vqav2': 'configs/models/blip_vqav2.yaml',
+         'okvqa': 'configs/models/blip_vqa_okvqa.yaml',
+         'aokvqa': 'configs/models/blip_vqa_aokvqa.yaml',
+     }
+
+     def __init__(self, image_encoder, text_encoder, text_decoder, max_txt_len=35):
+         super().__init__()
+         self.tokenizer = self.init_tokenizer()
+
+         self.visual_encoder = image_encoder
+
+         self.text_encoder = text_encoder
+         self.text_decoder = text_decoder
+
+         self.max_txt_len = max_txt_len
+
+     def forward(self, samples):
+         """
+         Args:
+             samples (dict): A dictionary containing the following keys:
+                 - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.
+                 - text_input (list): A list of strings, each string is a question
+                 - answer (list): A list of strings, each string is an answer
+                 - weight (torch.Tensor): A tensor used to weigh each answer in the loss computation.
+                     The shape of the tensor is (sum(n_answers),)
+                 - n_answers (torch.Tensor): A tensor shape (batch_size,) containing the number of answers
+                     for each question in the batch.
+
+         Returns:
+             A BlipOutput object containing loss and intermediate outputs,
+             see :class:`lavis.models.blip_outputs.BlipOutput` for more details.
+
+         Examples:
+         ```python
+             >>> import torch
+             >>> from lavis.models import load_model
+             >>> model = load_model("blip_vqa")
+             >>> samples = {
+             ...     "image": torch.rand(2, 3, 480, 480),
+             ...     "text_input": ["What is this?", "What is that?"],
+             ...     "answer": ["cat", "cat", "dog"],
+             ...     "weight": torch.tensor([1.0, 1.0, 1.0]),
+             ...     "n_answers": torch.tensor([2, 1]),
+             ... }
+             >>> output = model(samples)
+             >>> output.keys()
+             odict_keys(['intermediate_output', 'loss'])
+             >>> output.intermediate_output.keys()
+             odict_keys(['image_embeds', 'encoder_output', 'decoder_output', 'decoder_labels'])
+         ```
+         """
+         encoder_output, image_embeds = self.forward_encoder(samples)
+         loss, decoder_output, decoder_targets = self.forward_decoder(samples=samples, encoder_out=encoder_output)
+
+         return BlipOutput(
+             loss=loss,
+             intermediate_output=BlipIntermediateOutput(
+                 image_embeds=image_embeds,
+                 encoder_output=encoder_output,
+                 decoder_output=decoder_output,
+                 decoder_labels=decoder_targets,
+             ),
+         )
+
+     def forward_encoder(self, samples):
+         questions = samples['text_input']
+         questions = self.tokenizer(
+             questions,
+             padding='longest',
+             truncation=True,
+             max_length=self.max_txt_len,
+             return_tensors='pt',
+         ).to(self.device)
+         questions.input_ids[:, 0] = self.tokenizer.enc_token_id
+         samples.update({'tokenized_text': questions})
+
+         image_embeds = self.visual_encoder.forward_features(samples['image'])
+         encoder_output = self.text_encoder.forward_automask(
+             tokenized_text=samples['tokenized_text'], visual_embeds=image_embeds)
+
+         return encoder_output, image_embeds
+
+     def forward_decoder(self, samples, encoder_out, **kwargs):
+         answers = self.tokenizer(samples['answer'], padding='longest', return_tensors='pt').to(self.device)
+         answers.input_ids[:, 0] = self.tokenizer.bos_token_id
+         answer_targets = answers.input_ids.masked_fill(answers.input_ids == self.tokenizer.pad_token_id, -100)
+
+         question_states = []
+         question_atts = []
+
+         question = samples['tokenized_text']
+         question_output = encoder_out
+
+         for b, n in enumerate(samples['n_answers']):
+             question_states += [question_output.last_hidden_state[b]] * n
+             question_atts += [question.attention_mask[b]] * n
+
+         question_states = torch.stack(question_states, dim=0)
+         question_atts = torch.stack(question_atts, dim=0)
+
+         answer_output = self.text_decoder(
+             answers.input_ids,
+             attention_mask=answers.attention_mask,
+             encoder_hidden_states=question_states,
+             encoder_attention_mask=question_atts,
+             labels=answer_targets,
+             return_dict=True,
+             reduction='none',
+         )
+
+         loss = samples['weight'] * answer_output.loss
+         bsz = samples['image'].size(0)
+
+         loss = loss.sum() / bsz
+
+         return loss, answer_output, answer_targets
+
+     def predict_answers(self,
+                         samples,
+                         num_beams=3,
+                         inference_method='rank',
+                         max_len=10,
+                         min_len=1,
+                         num_ans_candidates=128,
+                         answer_list=None,
+                         **kwargs):
+         """
+         Args:
+             samples (dict): A dictionary containing the following keys:
+                 - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.
+                 - text_input (str or [str]): String or a list of strings, each string is a question.
+                     The number of questions must be equal to the batch size. If a single string, will be converted to a list of string, with length 1 first.
+             num_beams (int): Number of beams for beam search. 1 means no beam search.
+             inference_method (str): Inference method. One of "rank", "generate".
+                 - If "rank", the model will return answers with the highest probability from the answer list.
+                 - If "generate", the model will generate answers.
+             max_len (int): Maximum length of generated answers.
+             min_len (int): Minimum length of generated answers.
+             num_ans_candidates (int): Number of answer candidates, used to filter out answers with low probability.
+             answer_list (list): A list of strings, each string is an answer.
+
+         Returns:
+             List: A list of strings, each string is an answer.
+
+         Examples:
+         ```python
+             >>> from PIL import Image
+             >>> from lavis.models import load_model_and_preprocess
+             >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_vqa", "vqav2")
+             >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
+             >>> question = "Which city is this photo taken?"
+             >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
+             >>> question = txt_processors["eval"](question)
+             >>> samples = {"image": image, "text_input": [question]}
+             >>> answers = model.predict_answers(samples)
+             >>> answers
+             ['singapore']
+             >>> answer_list = ["Singapore", "London", "Palo Alto", "Tokyo"]
+             >>> answers = model.predict_answers(samples, answer_list=answer_list)
+             >>> answers
+             ['Singapore']
+         ```
+         """
+         assert inference_method in [
+             'rank',
+             'generate',
+         ], "Inference method must be one of 'rank' or 'generate', got {}.".format(inference_method)
+
+         if isinstance(samples['text_input'], str):
+             samples['text_input'] = [samples['text_input']]
+
+         assert len(samples['text_input']) == samples['image'].size(
+             0), 'The number of questions must be equal to the batch size.'
+
+         if inference_method == 'generate':
+             return self._generate_answers(samples, num_beams=num_beams, max_length=max_len, min_length=min_len)
+         elif inference_method == 'rank':
+             assert answer_list is not None, 'answer_list must be provided for ranking'
+
+             num_ans_candidates = min(num_ans_candidates, len(answer_list))
+
+             return self._rank_answers(samples, answer_list=answer_list, num_ans_candidates=num_ans_candidates)
+
+     def _generate_answers(self, samples, num_beams=3, max_length=10, min_length=1):
+         encoder_out, _ = self.forward_encoder(samples)
+
+         question_output = encoder_out
+
+         question_states = question_output.last_hidden_state.repeat_interleave(num_beams, dim=0)
+         question_atts = torch.ones(question_states.size()[:-1], dtype=torch.long).to(self.device)
+
+         model_kwargs = {
+             'encoder_hidden_states': question_states,
+             'encoder_attention_mask': question_atts,
+         }
+
+         bsz = samples['image'].size(0)
+         bos_ids = torch.full((bsz, 1), fill_value=self.tokenizer.bos_token_id, device=self.device)
+
+         outputs = self.text_decoder.generate(
+             input_ids=bos_ids,
+             max_length=max_length,
+             min_length=min_length,
+             num_beams=num_beams,
+             eos_token_id=self.tokenizer.sep_token_id,
+             pad_token_id=self.tokenizer.pad_token_id,
+             **model_kwargs)
+
+         # collect answers
+         answers = []
+         for output in outputs:
+             answer = self.tokenizer.decode(output, skip_special_tokens=True)
+             answers.append(answer)
+
+         return answers
+
+     def _rank_answers(self, samples, answer_list, num_ans_candidates):
+         """
+         Generate the first token of answers using decoder and select ${num_ans_candidates}
+         most probable ones. Then select answers from answer list, which start with the probable tokens.
+         Lastly, use the selected answers as the ground-truth labels for decoding and calculating LM loss.
+         Return the answers that minimize the losses as result.
+
+         """
+         answer_candidates = self.tokenizer(answer_list, padding='longest', return_tensors='pt').to(self.device)
+         answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id
+
+         answer_ids = answer_candidates.input_ids
+         answer_atts = answer_candidates.attention_mask
+
+         question_output, _ = self.forward_encoder(samples)
+         question_states = question_output.last_hidden_state
+
+         tokenized_question = samples['tokenized_text']
+         question_atts = tokenized_question.attention_mask
+
+         num_ques = question_states.size(0)
+         start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token
+
+         start_output = self.text_decoder(
+             start_ids,
+             encoder_hidden_states=question_states,
+             encoder_attention_mask=question_atts,
+             return_dict=True,
+             reduction='none',
+         )
+         logits = start_output.logits[:, 0, :]  # first token's logit
+
+         # topk_probs: top-k probability
+         # topk_ids: [num_question, k]
+         answer_first_token = answer_ids[:, 1]
+         prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
+         topk_probs, topk_ids = prob_first_token.topk(num_ans_candidates, dim=1)
+
+         # answer input: [num_question*k, answer_len]
+         input_ids = []
+         input_atts = []
+         for b, topk_id in enumerate(topk_ids):
+             input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
+             input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
+         input_ids = torch.cat(input_ids, dim=0)
+         input_atts = torch.cat(input_atts, dim=0)
+
+         targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
+
+         # repeat encoder's output for top-k answers
+         question_states = tile(question_states, 0, num_ans_candidates)
+         question_atts = tile(question_atts, 0, num_ans_candidates)
+
+         output = self.text_decoder(
+             input_ids,
+             attention_mask=input_atts,
+             encoder_hidden_states=question_states,
+             encoder_attention_mask=question_atts,
+             labels=targets_ids,
+             return_dict=True,
+             reduction='none',
+         )
+
+         log_probs_sum = -output.loss
+         log_probs_sum = log_probs_sum.view(num_ques, num_ans_candidates)
+
+         max_topk_ids = log_probs_sum.argmax(dim=1)
+         max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]
+
+         answers = [answer_list[max_id] for max_id in max_ids]
+
+         return answers
+
+     @classmethod
+     def from_config(cls, cfg=None):
+         image_encoder = VisionTransformerEncoder.from_config(cfg)
+
+         # text encoder + multimodal encoder
+         text_encoder = XBertEncoder.from_config(cfg)
+         text_decoder = XBertLMHeadDecoder.from_config(cfg)
+
+         max_txt_len = cfg.get('max_txt_len', 35)
+
+         model = cls(
+             image_encoder=image_encoder,
+             text_encoder=text_encoder,
+             text_decoder=text_decoder,
+             max_txt_len=max_txt_len,
+         )
+
+         model.load_checkpoint_from_config(cfg)
+
+         return model
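
Note on the new `_rank_answers` method in the diff above: it ranks closed-set answers in two stages, first by the probability the decoder assigns to each answer's first token, then by the full-sequence LM loss of the surviving top-k candidates. The sketch below is a minimal, self-contained illustration of that pattern; it is not part of evalscope or LAVIS, and the tensors and the `sequence_loss` helper are hypothetical stand-ins for the decoder calls.

```python
# Illustrative sketch (hypothetical data) of first-token candidate pruning
# followed by full-sequence re-scoring, as in BlipVQA._rank_answers.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_questions, vocab_size, k = 2, 50, 3
answer_list = ["cat", "dog", "bird", "car", "tree"]
# Hypothetical first-token id of each candidate answer under a mock tokenizer.
answer_first_token = torch.tensor([7, 11, 23, 31, 42])

# Stage 1: probability of each answer's first token per question (mock decoder logits).
first_token_logits = torch.randn(num_questions, vocab_size)
prob_first_token = F.softmax(first_token_logits, dim=1).index_select(1, answer_first_token)

# Keep only the k most probable candidates per question.
topk_probs, topk_ids = prob_first_token.topk(k, dim=1)  # topk_ids: [num_questions, k]

# Stage 2: re-score each surviving candidate with a full-sequence loss.
# Mocked here; in BlipVQA this is the text decoder run with labels and reduction='none'.
def sequence_loss(question_idx: int, answer_idx: int) -> torch.Tensor:
    return torch.rand(())

losses = torch.stack([
    torch.stack([sequence_loss(q, int(a)) for a in topk_ids[q]])
    for q in range(num_questions)
])  # [num_questions, k]

# Lowest loss == highest summed log-probability; pick that candidate per question.
best_slot = (-losses).argmax(dim=1)
best_answer_ids = topk_ids[torch.arange(num_questions), best_slot]
print([answer_list[int(i)] for i in best_answer_ids])
```

In the packaged method, the second stage corresponds to calling `self.text_decoder` with `labels=targets_ids` and `reduction='none'`, so each candidate's loss is the negative summed log-probability of its answer tokens, and `argmax` over `-loss` picks the most likely answer.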