evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (181)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
  24. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  25. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  26. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  27. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  28. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  29. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  30. evalscope/benchmarks/utils.py +7 -16
  31. evalscope/cli/start_app.py +1 -1
  32. evalscope/collections/evaluator.py +16 -4
  33. evalscope/config.py +7 -3
  34. evalscope/constants.py +11 -0
  35. evalscope/evaluator/evaluator.py +9 -3
  36. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  37. evalscope/metrics/__init__.py +49 -4
  38. evalscope/metrics/llm_judge.py +1 -1
  39. evalscope/metrics/named_metrics.py +13 -0
  40. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  41. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  42. evalscope/metrics/t2v_metrics/constants.py +12 -0
  43. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  44. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  53. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  60. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  61. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  62. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  139. evalscope/metrics/t2v_metrics/score.py +78 -0
  140. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  141. evalscope/models/__init__.py +50 -14
  142. evalscope/models/adapters/__init__.py +17 -0
  143. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  144. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  145. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  146. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  147. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  148. evalscope/models/adapters/t2i_adapter.py +76 -0
  149. evalscope/models/custom/__init__.py +2 -1
  150. evalscope/models/custom/dummy_model.py +11 -13
  151. evalscope/models/local_model.py +82 -33
  152. evalscope/models/model.py +2 -42
  153. evalscope/models/register.py +26 -0
  154. evalscope/perf/benchmark.py +4 -3
  155. evalscope/perf/main.py +4 -2
  156. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  157. evalscope/perf/utils/benchmark_util.py +2 -2
  158. evalscope/perf/utils/db_util.py +16 -8
  159. evalscope/report/__init__.py +1 -0
  160. evalscope/report/app.py +117 -67
  161. evalscope/report/app_arguments.py +11 -0
  162. evalscope/report/generator.py +1 -1
  163. evalscope/run.py +3 -3
  164. evalscope/third_party/thinkbench/eval.py +19 -7
  165. evalscope/utils/chat_service.py +2 -2
  166. evalscope/utils/import_utils.py +66 -0
  167. evalscope/utils/utils.py +12 -4
  168. evalscope/version.py +2 -2
  169. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
  170. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
  171. tests/aigc/__init__.py +1 -0
  172. tests/aigc/test_t2i.py +87 -0
  173. tests/cli/test_run.py +20 -7
  174. tests/perf/test_perf.py +6 -3
  175. evalscope/metrics/code_metric.py +0 -98
  176. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  177. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
  179. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
  180. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
  181. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
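Most of the new surface area in this release is a text-to-image (AIGC) evaluation path: the benchmarks/aigc/t2i adapters, the models/adapters/t2i_adapter.py model adapter, and the vendored t2v_metrics package (CLIPScore, ITMScore, VQAScore and their BLIP/BLIP2 backbones). The following is only a rough sketch of how that path might be driven through the existing TaskConfig/run_task entry points; the dataset name and model id are assumptions inferred from the file names above, and tests/aigc/test_t2i.py in this wheel remains the authoritative reference.

# Hypothetical sketch: 'general_t2i' and the model id below are assumptions inferred from
# the changed-file list; see tests/aigc/test_t2i.py in this wheel for the real usage.
from evalscope import TaskConfig, run_task

task_cfg = TaskConfig(
    model='stabilityai/stable-diffusion-xl-base-1.0',  # assumed: a locally runnable text-to-image model
    datasets=['general_t2i'],                          # assumed dataset name, mirroring general_t2i_adapter.py
)
run_task(task_cfg=task_cfg)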
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py (new file)
@@ -0,0 +1,452 @@
+"""
+Copyright (c) 2023, salesforce.com, inc.
+All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
+For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+import logging
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.cuda.amp import autocast as autocast
+from torch.nn import functional as F
+
+from ...common.registry import registry
+from ..base_model import all_gather_with_grad, concat_all_gather
+from ..blip_models.blip_outputs import BlipOutput, BlipOutputFeatures
+from .blip2 import Blip2Base, compute_sim_matrix, disabled_train
+
+
+@registry.register_model('blip2')
+@registry.register_model('blip2_feature_extractor')
+class Blip2Qformer(Blip2Base):
+    """
+    BLIP2 stage-1 model with Q-former and ViT.
+    Supported model types:
+        - pretrained: pretrained model with vit-g
+        - pretrain_vitL: pretrained model with vit-large
+        - coco: fintuned model on coco
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip2", "pretrain")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        'pretrain': 'configs/models/blip2/blip2_pretrain.yaml',
+        'pretrain_vitL': 'configs/models/blip2/blip2_pretrain_vitL.yaml',
+        'coco': 'configs/models/blip2/blip2_coco.yaml',
+    }
+
+    def __init__(
+        self,
+        vit_model='eva_clip_g',
+        img_size=224,
+        drop_path_rate=0,
+        use_grad_checkpoint=False,
+        vit_precision='fp16',
+        freeze_vit=True,
+        num_query_token=32,
+        cross_attention_freq=2,
+        embed_dim=256,
+        max_txt_len=32,
+    ):
+        super().__init__()
+
+        self.tokenizer = self.init_tokenizer()
+
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size, drop_path_rate,
+                                                                       use_grad_checkpoint, vit_precision)
+        if freeze_vit:
+            for name, param in self.visual_encoder.named_parameters():
+                param.requires_grad = False
+            self.visual_encoder = self.visual_encoder.eval()
+            self.visual_encoder.train = disabled_train
+            logging.info('freeze vision encoder')
+        self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, self.visual_encoder.num_features,
+                                                            cross_attention_freq)
+        self.Qformer.resize_token_embeddings(len(self.tokenizer))
+        state_dict = self.Qformer.state_dict()
+        for name, param in self.Qformer.named_parameters():
+            if '_query' in name:
+                key_orig = name.replace('_query', '')
+                param.data.copy_(state_dict[key_orig])
+
+        self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
+        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
+
+        self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2)
+
+        self.temp = nn.Parameter(0.07 * torch.ones([]))
+
+        self.max_txt_len = max_txt_len
+
+    def forward(self, samples):
+        image = samples['image']
+        text = samples['text_input']
+
+        image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+        query_output = self.Qformer.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            use_cache=True,
+            return_dict=True,
+        )
+
+        image_feats = F.normalize(self.vision_proj(query_output.last_hidden_state), dim=-1)
+
+        text_tokens = self.tokenizer(
+            text,
+            padding='max_length',
+            truncation=True,
+            max_length=self.max_txt_len,
+            return_tensors='pt',
+        ).to(image.device)
+        text_output = self.Qformer.bert(
+            text_tokens.input_ids,
+            attention_mask=text_tokens.attention_mask,
+            return_dict=True,
+        )
+        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+        ###============== Image-text Contrastive ===================###
+        image_feats_all = concat_all_gather(image_feats)  # [batch_size*num_gpu, num_query_tokens, embed_dim]
+        text_feat_all = concat_all_gather(text_feat)  # [batch_size*num_gpu, embed_dim]
+
+        sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
+        # [batch_size, batch_size*num_gpu, num_query_tokens]
+
+        # image-text similarity: aggregate across all query tokens
+        sim_i2t, _ = sim_q2t.max(-1)
+        sim_i2t = sim_i2t / self.temp
+
+        # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
+        sim_t2q = torch.matmul(text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)).squeeze()
+
+        # text-image similarity: aggregate across all query tokens
+        sim_t2i, _ = sim_t2q.max(-1)
+        sim_t2i = sim_t2i / self.temp  # [batch_size, batch_size*num_gpu]
+
+        rank = dist.get_rank()
+        bs = image.size(0)
+        targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(image.device)
+
+        loss_itc = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
+                    + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2
+
+        ###============== Image-text Matching ===================###
+        text_input_ids_world = concat_all_gather(text_tokens.input_ids)
+        text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
+        image_embeds_world = all_gather_with_grad(image_embeds)
+        with torch.no_grad():
+            weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4
+            weights_t2i[:, rank * bs:rank * bs + bs].fill_diagonal_(0)
+            weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4
+            weights_i2t[:, rank * bs:rank * bs + bs].fill_diagonal_(0)
+
+        # select a negative image for each text
+        image_embeds_neg = []
+        for b in range(bs):
+            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+            image_embeds_neg.append(image_embeds_world[neg_idx])
+        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
+
+        # select a negative text for each image
+        text_ids_neg = []
+        text_atts_neg = []
+        for b in range(bs):
+            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+            text_ids_neg.append(text_input_ids_world[neg_idx])
+            text_atts_neg.append(text_attention_mask_world[neg_idx])
+
+        text_ids_neg = torch.stack(text_ids_neg, dim=0)
+        text_atts_neg = torch.stack(text_atts_neg, dim=0)
+
+        text_ids_all = torch.cat([text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0)  # pos, pos, neg
+        text_atts_all = torch.cat(
+            [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
+            dim=0,
+        )
+
+        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
+        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(image.device)
+        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
+
+        image_embeds_all = torch.cat([image_embeds, image_embeds_neg, image_embeds], dim=0)  # pos, neg, pos
+        image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(image.device)
+
+        output_itm = self.Qformer.bert(
+            text_ids_all,
+            query_embeds=query_tokens_itm,
+            attention_mask=attention_mask_all,
+            encoder_hidden_states=image_embeds_all,
+            encoder_attention_mask=image_atts_all,
+            return_dict=True,
+        )
+
+        vl_embeddings = output_itm.last_hidden_state[:, :query_tokens_itm.size(1), :]
+        vl_output = self.itm_head(vl_embeddings)
+        logits = vl_output.mean(dim=1)
+
+        itm_labels = torch.cat(
+            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
+            dim=0,
+        ).to(image.device)
+        loss_itm = F.cross_entropy(logits, itm_labels)
+
+        ##================= Image Captioning ========================##
+        decoder_input_ids = text_tokens.input_ids.clone()
+        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
+        labels = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
+
+        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+        attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
+        lm_output = self.Qformer(
+            decoder_input_ids,
+            attention_mask=attention_mask,
+            past_key_values=query_output.past_key_values,
+            return_dict=True,
+            labels=labels,
+        )
+
+        loss_lm = lm_output.loss
+
+        return BlipOutput(
+            loss=loss_itc + loss_itm + loss_lm,
+            loss_itc=loss_itc,
+            loss_itm=loss_itm,
+            loss_lm=loss_lm,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        samples,
+        use_nucleus_sampling=False,
+        num_beams=3,
+        max_length=30,
+        min_length=10,
+        top_p=0.9,
+        repetition_penalty=1.0,
+    ):
+        """
+        Args:
+            samples (dict): A dictionary containing the following keys:
+                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
+            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.
+            num_beams (int): Number of beams for beam search. 1 means no beam search.
+            max_length (int): The maximum length of the sequence to be generated.
+            min_length (int): The minimum length of the sequence to be generated.
+            top_p (float): The cumulative probability for nucleus sampling.
+            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
+            num_captions (int): Number of captions to be generated for each image.
+        Returns:
+            captions (list): A list of strings of length batch_size * num_captions.
+        """
+        image = samples['image']
+        image_embeds = self.ln_vision(self.visual_encoder(image))
+
+        if not use_nucleus_sampling:
+            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
+        else:
+            num_beams = 1
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        model_kwargs = {
+            'encoder_hidden_states': image_embeds,
+            'encoder_attention_mask': image_atts,
+        }
+
+        input_ids = (torch.LongTensor(image.size(0), 1).fill_(self.tokenizer.bos_token_id).to(image.device))
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+        outputs = self.Qformer.generate(
+            input_ids=input_ids,
+            query_embeds=query_tokens,
+            max_length=max_length,
+            min_length=min_length,
+            num_beams=num_beams,
+            do_sample=use_nucleus_sampling,
+            top_p=top_p,
+            eos_token_id=self.tokenizer.sep_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            **model_kwargs)
+        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return captions
+
+    def forward_image(self, image):
+        image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+        query_output = self.Qformer.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        return query_output.last_hidden_state, image_embeds
+
+    def forward_text(self, text_tokens):
+        text_output = self.Qformer.bert(
+            text_tokens.input_ids,
+            attention_mask=text_tokens.attention_mask,
+            return_dict=True,
+        )
+        return text_output.last_hidden_state[:, 0, :]
+
+    def compute_itm(self, image_inputs, text_ids, text_atts):
+        image_atts = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to(image_inputs.device)
+        query_tokens = self.query_tokens.expand(image_inputs.shape[0], -1, -1)
+        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_inputs.device)
+        attention_mask = torch.cat([query_atts, text_atts], dim=1)
+        output_itm = self.Qformer.bert(
+            text_ids,
+            query_embeds=query_tokens,
+            attention_mask=attention_mask,
+            encoder_hidden_states=image_inputs,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        vl_embeddings = output_itm.last_hidden_state[:, :query_tokens.size(1), :]
+        itm_logit = self.itm_head(vl_embeddings)
+        itm_logit = itm_logit[:, :, 1].mean(dim=1)
+        return itm_logit
+
+    @torch.no_grad()
+    def extract_features(self, samples, mode='multimodal'):
+        """
+        Extract features for multimodal or unimodal samples.
+        Args:
+            samples (dict): A dictionary of samples, containing the following keys:
+                - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image.
+                    Raw images should be preprocessed before being passed to feature extractor.
+                - text_input (list): A list of strings containing the text, length B.
+            mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image".
+                If "multimodal", return image features and multimodal features;
+                if "text", return text features;
+                if "image", return image features.
+                Default: "multimodal".
+        Returns:
+            BlipOutputFeatures: A BlipOutputFeatures object containing the features.
+                See lavis/models/blip_models/blip_outputs.py for more details.
+        """
+        image = samples.get('image')
+        caption = samples.get('text_input')
+
+        # assert mode is one of "image", "text", "multimodal"
+        assert mode in [
+            'image',
+            'text',
+            'multimodal',
+        ], "mode must be one of 'image', 'text', 'multimodal'"
+
+        # initalize output
+        image_embeds, text_embeds, multimodal_embeds = None, None, None
+        image_features, text_features = None, None
+
+        if mode == 'image':
+            assert (image is not None), "Image is not provided for mode 'image' or 'multimodal'"
+            # return query features
+            with self.maybe_autocast():
+                image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
+            image_embeds_frozen = image_embeds_frozen.float()
+            image_atts = torch.ones(image_embeds_frozen.size()[:-1], dtype=torch.long).to(self.device)
+            query_tokens = self.query_tokens.expand(image_embeds_frozen.shape[0], -1, -1)
+
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds_frozen,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+            image_embeds = query_output.last_hidden_state
+            image_features = F.normalize(self.vision_proj(image_embeds), dim=-1)
+
+        elif mode == 'text':
+            assert (caption is not None), "text input is None for mode 'text' or 'multimodal'"
+
+            # return text features
+            text = self.tokenizer(caption, return_tensors='pt', padding=True).to(self.device)
+
+            text_output = self.Qformer.bert(
+                text.input_ids,
+                attention_mask=text.attention_mask,
+                return_dict=True,
+            )
+            text_embeds = text_output.last_hidden_state
+            text_features = self.text_proj(text_embeds)
+            text_features = F.normalize(text_features, dim=-1)
+
+        elif mode == 'multimodal':
+            # return multimodel query features
+            with self.maybe_autocast():
+                image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
+            image_embeds_frozen = image_embeds_frozen.float()
+            image_atts = torch.ones(image_embeds_frozen.size()[:-1], dtype=torch.long).to(self.device)
+            query_tokens = self.query_tokens.expand(image_embeds_frozen.shape[0], -1, -1)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(self.device)
+
+            text = self.tokenizer(caption, return_tensors='pt', padding=True).to(self.device)
+            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
+
+            output = self.Qformer.bert(
+                text.input_ids,
+                query_embeds=query_tokens,
+                attention_mask=attention_mask,
+                encoder_hidden_states=image_embeds_frozen,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+            multimodal_embeds = output.last_hidden_state[:, :query_tokens.size(1), :]
+
+        return BlipOutputFeatures(
+            image_embeds=image_embeds,
+            image_embeds_proj=image_features,
+            text_embeds=text_embeds,
+            text_embeds_proj=text_features,
+            multimodal_embeds=multimodal_embeds,
+        )
+
+    @classmethod
+    def from_config(cls, cfg):
+        vit_model = cfg.get('vit_model', 'eva_clip_g')
+        img_size = cfg.get('image_size')
+        num_query_token = cfg.get('num_query_token')
+        cross_attention_freq = cfg.get('cross_attention_freq', 2)

+        drop_path_rate = cfg.get('drop_path_rate', 0)
+        use_grad_checkpoint = cfg.get('use_grad_checkpoint', False)
+        vit_precision = cfg.get('vit_precision', 'fp16')
+        freeze_vit = cfg.get('freeze_vit', True)
+
+        max_txt_len = cfg.get('max_txt_len', 32)
+
+        model = cls(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            num_query_token=num_query_token,
+            cross_attention_freq=cross_attention_freq,
+            max_txt_len=max_txt_len,
+        )
+        model.load_checkpoint_from_config(cfg)
+
+        return model
+
+    def compute_sim_matrix(self, data_loader, task_cfg):
+        """
+        Compute similarity i2t, t2i matrix for the given data loader.
+        """
+        k_test = task_cfg.k_test
+
+        return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)
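A minimal usage sketch of the class above follows. The import path is assumed from the vendored package layout in the file list, the extra load_model keyword arguments are assumptions, and a random tensor stands in for a properly preprocessed 224x224 image batch; the class docstring itself only documents load_model('blip2', 'pretrain').

# Sketch only: the vendored import path and the extra load_model kwargs are assumptions;
# the class docstring above confirms load_model('blip2', 'pretrain') as the entry point.
import torch

from evalscope.metrics.t2v_metrics.models.vqascore_models.lavis.models import load_model  # assumed path

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model('blip2', 'pretrain', is_eval=True, device=device)

image = torch.randn(1, 3, 224, 224, device=device)  # placeholder for a preprocessed image batch
caption = ['a photo of a cat wearing sunglasses']

with torch.no_grad():
    # ITC-style score: cosine similarity between the projected query features and the [CLS]
    # text feature, max-pooled over the 32 query tokens (mirrors sim_q2t in forward()).
    img_proj = model.extract_features({'image': image}, mode='image').image_embeds_proj       # (1, 32, 256)
    txt_proj = model.extract_features({'text_input': caption}, mode='text').text_embeds_proj  # (1, L, 256)
    itc_score = (img_proj[0] @ txt_proj[0, 0]).max().item()

print(f'ITC-style image-text score: {itc_score:.4f}')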