evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.
Files changed (181)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
  24. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  25. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  26. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  27. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  28. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  29. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  30. evalscope/benchmarks/utils.py +7 -16
  31. evalscope/cli/start_app.py +1 -1
  32. evalscope/collections/evaluator.py +16 -4
  33. evalscope/config.py +7 -3
  34. evalscope/constants.py +11 -0
  35. evalscope/evaluator/evaluator.py +9 -3
  36. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  37. evalscope/metrics/__init__.py +49 -4
  38. evalscope/metrics/llm_judge.py +1 -1
  39. evalscope/metrics/named_metrics.py +13 -0
  40. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  41. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  42. evalscope/metrics/t2v_metrics/constants.py +12 -0
  43. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  44. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  53. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  60. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  61. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  62. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  139. evalscope/metrics/t2v_metrics/score.py +78 -0
  140. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  141. evalscope/models/__init__.py +50 -14
  142. evalscope/models/adapters/__init__.py +17 -0
  143. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  144. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  145. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  146. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  147. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  148. evalscope/models/adapters/t2i_adapter.py +76 -0
  149. evalscope/models/custom/__init__.py +2 -1
  150. evalscope/models/custom/dummy_model.py +11 -13
  151. evalscope/models/local_model.py +82 -33
  152. evalscope/models/model.py +2 -42
  153. evalscope/models/register.py +26 -0
  154. evalscope/perf/benchmark.py +4 -3
  155. evalscope/perf/main.py +4 -2
  156. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  157. evalscope/perf/utils/benchmark_util.py +2 -2
  158. evalscope/perf/utils/db_util.py +16 -8
  159. evalscope/report/__init__.py +1 -0
  160. evalscope/report/app.py +117 -67
  161. evalscope/report/app_arguments.py +11 -0
  162. evalscope/report/generator.py +1 -1
  163. evalscope/run.py +3 -3
  164. evalscope/third_party/thinkbench/eval.py +19 -7
  165. evalscope/utils/chat_service.py +2 -2
  166. evalscope/utils/import_utils.py +66 -0
  167. evalscope/utils/utils.py +12 -4
  168. evalscope/version.py +2 -2
  169. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
  170. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
  171. tests/aigc/__init__.py +1 -0
  172. tests/aigc/test_t2i.py +87 -0
  173. tests/cli/test_run.py +20 -7
  174. tests/perf/test_perf.py +6 -3
  175. evalscope/metrics/code_metric.py +0 -98
  176. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  177. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
  179. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
  180. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
  181. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py
@@ -0,0 +1,364 @@
+"""
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+import logging
+
+import torch
+import torch.nn as nn
+from torch.cuda.amp import autocast as autocast
+from transformers import T5TokenizerFast
+
+from ...common.registry import registry
+from .blip2 import Blip2Base, disabled_train
+from .modeling_t5 import T5Config, T5ForConditionalGeneration
+
+
+@registry.register_model('blip2_t5')
+class Blip2T5(Blip2Base):
+    """
+    BLIP2 T5 model.
+    Supported model types:
+        - pretrain_flant5xl: pretrained model with FlanT5-XL
+        - pretrain_flant5xl_vitL: pretrained model with FlanT5-XL
+        - pretrain_flant5xxl: pretrained model with FlanT5-XXL
+        - caption_coco_flant5xl: fintuned image captioning model with FlanT5-XL
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip2_t5", "pretrain_flant5xl")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        'pretrain_flant5xl':
+        'configs/models/blip2/blip2_pretrain_flant5xl.yaml',
+        'pretrain_flant5xl_vitL':
+        'configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml',
+        'pretrain_flant5xxl':
+        'configs/models/blip2/blip2_pretrain_flant5xxl.yaml',
+        'caption_coco_flant5xl':
+        'configs/models/blip2/blip2_caption_flant5xl.yaml',
+        # Added by ZQ
+        'pretrain_flant5xl_iter_80k_total_100k_prefix':
+        'configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml',
+        'pretrain_flant5xl_iter_80k_total_100k_no_prefix':
+        'configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml',
+    }
+
+    def __init__(
+        self,
+        vit_model='eva_clip_g',
+        img_size=224,
+        drop_path_rate=0,
+        use_grad_checkpoint=False,
+        vit_precision='fp16',
+        freeze_vit=True,
+        num_query_token=32,
+        t5_model='google/flan-t5-xl',
+        prompt='',
+        max_txt_len=32,
+        apply_lemmatizer=False,
+    ):
+        """
+        apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.
+        """
+        super().__init__()
+
+        self.tokenizer = self.init_tokenizer()
+
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size, drop_path_rate,
+                                                                       use_grad_checkpoint, vit_precision)
+        if freeze_vit:
+            for name, param in self.visual_encoder.named_parameters():
+                param.requires_grad = False
+            self.visual_encoder = self.visual_encoder.eval()
+            self.visual_encoder.train = disabled_train
+            logging.info('freeze vision encoder')
+
+        self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, self.visual_encoder.num_features)
+        self.Qformer.cls = None
+        self.Qformer.bert.embeddings.word_embeddings = None
+        self.Qformer.bert.embeddings.position_embeddings = None
+        for layer in self.Qformer.bert.encoder.layer:
+            layer.output = None
+            layer.intermediate = None
+
+        self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model)
+        t5_config = T5Config.from_pretrained(t5_model)
+        t5_config.dense_act_fn = 'gelu'
+        self.t5_model = T5ForConditionalGeneration.from_pretrained(t5_model, config=t5_config)
+
+        for name, param in self.t5_model.named_parameters():
+            param.requires_grad = False
+            param.data = param.data.bfloat16()
+
+        self.t5_proj = nn.Linear(self.Qformer.config.hidden_size, self.t5_model.config.hidden_size)
+
+        self.max_txt_len = max_txt_len
+        self.prompt = prompt
+
+        self._apply_lemmatizer = apply_lemmatizer
+        self._lemmatizer = None
+
+    def forward(self, samples):
+        image = samples['image']
+
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_output = self.Qformer.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+
+        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
+        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+        with self.maybe_autocast(dtype=torch.bfloat16):
+            input_tokens = self.t5_tokenizer(
+                samples['text_input'],
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors='pt',
+            ).to(image.device)
+            output_tokens = self.t5_tokenizer(
+                samples['text_output'],
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors='pt',
+            ).to(image.device)
+
+            encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+            targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id,
+                                                          -100)
+
+            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+            outputs = self.t5_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=encoder_atts,
+                decoder_attention_mask=output_tokens.attention_mask,
+                return_dict=True,
+                labels=targets,
+            )
+            loss = outputs.loss
+
+            return {'loss': loss}
+
+    @torch.no_grad()
+    def generate(
+        self,
+        samples,
+        use_nucleus_sampling=False,
+        num_beams=5,
+        max_length=30,
+        min_length=1,
+        top_p=0.9,
+        repetition_penalty=1.0,
+        length_penalty=1.0,
+        num_captions=1,
+        temperature=1,
+    ):
+        """
+        Args:
+            samples (dict): A dictionary containing the following keys:
+                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
+            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.
+            num_beams (int): Number of beams for beam search. 1 means no beam search.
+            max_length (int): The maximum length of the sequence to be generated.
+            min_length (int): The minimum length of the sequence to be generated.
+            top_p (float): The cumulative probability for nucleus sampling.
+            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
+            num_captions (int): Number of captions to be generated for each image.
+        Returns:
+            captions (list): A list of strings of length batch_size * num_captions.
+        """
+        image = samples['image']
+
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_embeds = image_embeds.float()
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_output = self.Qformer.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+
+        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
+        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+        if 'prompt' in samples.keys():
+            prompt = samples['prompt']
+        else:
+            prompt = self.prompt
+
+        if isinstance(prompt, str):
+            prompt = [prompt] * image.size(0)
+        else:
+            assert len(prompt) == image.size(0), 'The number of prompts must be equal to the batch size.'
+
+        input_tokens = self.t5_tokenizer(prompt, padding='longest', return_tensors='pt').to(image.device)
+
+        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+        with self.maybe_autocast(dtype=torch.bfloat16):
+            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+            outputs = self.t5_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=encoder_atts,
+                do_sample=use_nucleus_sampling,
+                top_p=top_p,
+                temperature=temperature,
+                num_beams=num_beams,
+                max_new_tokens=max_length,
+                min_length=min_length,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                num_return_sequences=num_captions,
+            )
+            output_text = self.t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        return output_text
+
+    def predict_answers(self,
+                        samples,
+                        num_beams=5,
+                        inference_method='generate',
+                        max_len=10,
+                        min_len=1,
+                        num_ans_candidates=128,
+                        answer_list=None,
+                        prompt='',
+                        length_penalty=-1,
+                        **kwargs):
+        image = samples['image']
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_embeds = image_embeds.float()
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_output = self.Qformer.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+
+        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
+        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+        if isinstance(samples['text_input'], str):
+            samples['text_input'] = [samples['text_input']]
+        if prompt:
+            text_input = [prompt.format(question) for question in samples['text_input']]
+        else:
+            text_input = samples['text_input']
+
+        input_tokens = self.t5_tokenizer(text_input, padding='longest', return_tensors='pt').to(image.device)
+
+        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+        with self.maybe_autocast(dtype=torch.bfloat16):
+            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+            outputs = self.t5_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=encoder_atts,
+                do_sample=False,
+                num_beams=num_beams,
+                max_new_tokens=max_len,
+                min_length=min_len,
+                length_penalty=length_penalty,
+            )
+            output_text = self.t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        if self._apply_lemmatizer:
+            output_text = self._lemmatize(output_text)
+
+        return output_text
+
+    def _lemmatize(self, answers):
+
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ['NOUN', 'VERB']:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = ' '.join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load('en_core_web_sm')
+            except ImportError:
+                logging.error("""
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """)
+                exit(1)
+
+        return self._lemmatizer
+
+    @classmethod
+    def from_config(cls, cfg):
+        vit_model = cfg.get('vit_model', 'eva_clip_g')
+        img_size = cfg.get('image_size')
+        num_query_token = cfg.get('num_query_token')
+        t5_model = cfg.get('t5_model')
+
+        drop_path_rate = cfg.get('drop_path_rate', 0)
+        use_grad_checkpoint = cfg.get('use_grad_checkpoint', False)
+        vit_precision = cfg.get('vit_precision', 'fp16')
+        freeze_vit = cfg.get('freeze_vit', True)
+
+        prompt = cfg.get('prompt', '')
+        max_txt_len = cfg.get('max_txt_len', 32)
+
+        apply_lemmatizer = cfg.get('apply_lemmatizer', False)
+
+        model = cls(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            num_query_token=num_query_token,
+            t5_model=t5_model,
+            prompt=prompt,
+            max_txt_len=max_txt_len,
+            apply_lemmatizer=apply_lemmatizer,
+        )
+        model.load_checkpoint_from_config(cfg)
+
+        return model
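
For context, the vendored Blip2T5 class above is consumed through the LAVIS-style registry rather than imported directly. Below is a minimal usage sketch that mirrors the load_model call from the class docstring; the import path is only inferred from the file list in this diff and the device/dummy-input handling is illustrative, so treat it as an assumption rather than the package's documented API.

    # Hypothetical usage sketch (not part of the diff): load the vendored BLIP-2 T5
    # model via the LAVIS-style helpers bundled under evalscope.metrics.t2v_metrics,
    # then call generate() with the samples dict shape expected by the code above.
    import torch

    # Assumed import path, derived from the file layout shown in this diff.
    from evalscope.metrics.t2v_metrics.models.vqascore_models.lavis.models import load_model

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model('blip2_t5', 'pretrain_flant5xl', is_eval=True, device=device)

    # generate() expects {'image': (batch, 3, H, W) tensor, optional 'prompt': str};
    # a zero tensor stands in for a preprocessed image here.
    captions = model.generate({'image': torch.zeros(1, 3, 224, 224, device=device), 'prompt': 'a photo of'})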