evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (181)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
  24. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  25. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  26. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  27. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  28. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  29. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  30. evalscope/benchmarks/utils.py +7 -16
  31. evalscope/cli/start_app.py +1 -1
  32. evalscope/collections/evaluator.py +16 -4
  33. evalscope/config.py +7 -3
  34. evalscope/constants.py +11 -0
  35. evalscope/evaluator/evaluator.py +9 -3
  36. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  37. evalscope/metrics/__init__.py +49 -4
  38. evalscope/metrics/llm_judge.py +1 -1
  39. evalscope/metrics/named_metrics.py +13 -0
  40. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  41. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  42. evalscope/metrics/t2v_metrics/constants.py +12 -0
  43. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  44. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  53. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  60. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  61. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  62. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  139. evalscope/metrics/t2v_metrics/score.py +78 -0
  140. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  141. evalscope/models/__init__.py +50 -14
  142. evalscope/models/adapters/__init__.py +17 -0
  143. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  144. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  145. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  146. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  147. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  148. evalscope/models/adapters/t2i_adapter.py +76 -0
  149. evalscope/models/custom/__init__.py +2 -1
  150. evalscope/models/custom/dummy_model.py +11 -13
  151. evalscope/models/local_model.py +82 -33
  152. evalscope/models/model.py +2 -42
  153. evalscope/models/register.py +26 -0
  154. evalscope/perf/benchmark.py +4 -3
  155. evalscope/perf/main.py +4 -2
  156. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  157. evalscope/perf/utils/benchmark_util.py +2 -2
  158. evalscope/perf/utils/db_util.py +16 -8
  159. evalscope/report/__init__.py +1 -0
  160. evalscope/report/app.py +117 -67
  161. evalscope/report/app_arguments.py +11 -0
  162. evalscope/report/generator.py +1 -1
  163. evalscope/run.py +3 -3
  164. evalscope/third_party/thinkbench/eval.py +19 -7
  165. evalscope/utils/chat_service.py +2 -2
  166. evalscope/utils/import_utils.py +66 -0
  167. evalscope/utils/utils.py +12 -4
  168. evalscope/version.py +2 -2
  169. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
  170. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
  171. tests/aigc/__init__.py +1 -0
  172. tests/aigc/test_t2i.py +87 -0
  173. tests/cli/test_run.py +20 -7
  174. tests/perf/test_perf.py +6 -3
  175. evalscope/metrics/code_metric.py +0 -98
  176. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  177. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
  179. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
  180. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
  181. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py
@@ -0,0 +1,211 @@
+ """
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+ import contextlib
+ import datetime
+ import logging
+ import os
+ import time
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from modelscope import snapshot_download
+ from transformers import BertTokenizer
+
+ from ...common import dist_utils
+ from ...common.dist_utils import download_cached_file
+ from ...common.logger import MetricLogger
+ from ...common.utils import is_url
+ from ..base_model import BaseModel
+ from ..clip_vit import create_clip_vit_L
+ from ..eva_vit import create_eva_vit_g
+ from .Qformer import BertConfig, BertLMHeadModel
+
+
+ class Blip2Base(BaseModel):
+
+     @classmethod
+     def init_tokenizer(cls, truncation_side='right'):
+         bert_path = snapshot_download('AI-ModelScope/bert-base-uncased')
+         tokenizer = BertTokenizer.from_pretrained(bert_path, truncation_side=truncation_side)
+         tokenizer.add_special_tokens({'bos_token': '[DEC]'})
+         return tokenizer
+
+     def maybe_autocast(self, dtype=torch.float16):
+         # if on cpu, don't use autocast
+         # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+         enable_autocast = self.device != torch.device('cpu')
+
+         if enable_autocast:
+             return torch.amp.autocast(device_type=self.device.type, dtype=dtype)
+         else:
+             return contextlib.nullcontext()
+
+     @classmethod
+     def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
+         bert_path = snapshot_download('AI-ModelScope/bert-base-uncased')
+         encoder_config = BertConfig.from_pretrained(bert_path)
+         encoder_config.encoder_width = vision_width
+         encoder_config.vocab_size += 1  # add one for [DEC]
+         # insert cross-attention layer every other block
+         encoder_config.add_cross_attention = True
+         encoder_config.cross_attention_freq = cross_attention_freq
+         encoder_config.query_length = num_query_token
+         Qformer = BertLMHeadModel._from_config(encoder_config)
+         query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
+         query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+         return Qformer, query_tokens
+
+     @classmethod
+     def init_vision_encoder(cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision):
+         assert model_name in [
+             'eva_clip_g',
+             'clip_L',
+         ], 'vit model must be eva_clip_g or clip_L'
+         if model_name == 'eva_clip_g':
+             visual_encoder = create_eva_vit_g(img_size, drop_path_rate, use_grad_checkpoint, precision)
+         elif model_name == 'clip_L':
+             visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
+         ln_vision = LayerNorm(visual_encoder.num_features)
+         return visual_encoder, ln_vision
+
+     def load_from_pretrained(self, url_or_filename):
+         if is_url(url_or_filename):
+             cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+             checkpoint = torch.load(cached_file, map_location='cpu')
+         elif os.path.isfile(url_or_filename):
+             checkpoint = torch.load(url_or_filename, map_location='cpu')
+         else:
+             raise RuntimeError('checkpoint url or path is invalid')
+
+         state_dict = checkpoint['model']
+
+         msg = self.load_state_dict(state_dict, strict=False)
+
+         # logging.info("Missing keys {}".format(msg.missing_keys))
+         logging.info('load checkpoint from %s' % url_or_filename)
+
+         return msg
+
+
+ def disabled_train(self, mode=True):
+     """Overwrite model.train with this function to make sure train/eval mode
+     does not change anymore."""
+     return self
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm to handle fp16."""
+
+     def forward(self, x: torch.Tensor):
+         orig_type = x.dtype
+         ret = super().forward(x.type(torch.float32))
+         return ret.type(orig_type)
+
+
+ def compute_sim_matrix(model, data_loader, **kwargs):
+     k_test = kwargs.pop('k_test')
+
+     metric_logger = MetricLogger(delimiter=' ')
+     header = 'Evaluation:'
+
+     logging.info('Computing features for evaluation...')
+     start_time = time.time()
+
+     texts = data_loader.dataset.text
+     num_text = len(texts)
+     text_bs = 256
+     text_ids = []
+     text_embeds = []
+     text_atts = []
+     for i in range(0, num_text, text_bs):
+         text = texts[i:min(num_text, i + text_bs)]
+         text_input = model.tokenizer(
+             text,
+             padding='max_length',
+             truncation=True,
+             max_length=35,
+             return_tensors='pt',
+         ).to(model.device)
+         text_feat = model.forward_text(text_input)
+         text_embed = F.normalize(model.text_proj(text_feat))
+         text_embeds.append(text_embed)
+         text_ids.append(text_input.input_ids)
+         text_atts.append(text_input.attention_mask)
+
+     text_embeds = torch.cat(text_embeds, dim=0)
+     text_ids = torch.cat(text_ids, dim=0)
+     text_atts = torch.cat(text_atts, dim=0)
+
+     vit_feats = []
+     image_embeds = []
+     for samples in data_loader:
+         image = samples['image']
+
+         image = image.to(model.device)
+         image_feat, vit_feat = model.forward_image(image)
+         image_embed = model.vision_proj(image_feat)
+         image_embed = F.normalize(image_embed, dim=-1)
+
+         vit_feats.append(vit_feat.cpu())
+         image_embeds.append(image_embed)
+
+     vit_feats = torch.cat(vit_feats, dim=0)
+     image_embeds = torch.cat(image_embeds, dim=0)
+
+     sims_matrix = []
+     for image_embed in image_embeds:
+         sim_q2t = image_embed @ text_embeds.t()
+         sim_i2t, _ = sim_q2t.max(0)
+         sims_matrix.append(sim_i2t)
+     sims_matrix = torch.stack(sims_matrix, dim=0)
+
+     score_matrix_i2t = torch.full((len(data_loader.dataset.image), len(texts)), -100.0).to(model.device)
+
+     num_tasks = dist_utils.get_world_size()
+     rank = dist_utils.get_rank()
+     step = sims_matrix.size(0) // num_tasks + 1
+     start = rank * step
+     end = min(sims_matrix.size(0), start + step)
+
+     for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+         topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+         image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
+         score = model.compute_itm(
+             image_inputs=image_inputs,
+             text_ids=text_ids[topk_idx],
+             text_atts=text_atts[topk_idx],
+         ).float()
+         score_matrix_i2t[start + i, topk_idx] = score + topk_sim
+
+     sims_matrix = sims_matrix.t()
+     score_matrix_t2i = torch.full((len(texts), len(data_loader.dataset.image)), -100.0).to(model.device)
+
+     step = sims_matrix.size(0) // num_tasks + 1
+     start = rank * step
+     end = min(sims_matrix.size(0), start + step)
+
+     for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+         topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+         image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
+         score = model.compute_itm(
+             image_inputs=image_inputs,
+             text_ids=text_ids[start + i].repeat(k_test, 1),
+             text_atts=text_atts[start + i].repeat(k_test, 1),
+         ).float()
+         score_matrix_t2i[start + i, topk_idx] = score + topk_sim
+
+     if dist_utils.is_dist_avail_and_initialized():
+         dist.barrier()
+         torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
+         torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     logging.info('Evaluation time {}'.format(total_time_str))
+
+     return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py
@@ -0,0 +1,109 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import torch
+ import torch.nn.functional as F
+
+ from ...common.registry import registry
+ from .blip2_qformer import Blip2Qformer
+
+
+ @registry.register_model('blip2_image_text_matching')
+ class Blip2ITM(Blip2Qformer):
+     """
+     BLIP Image-Text Matching (ITM) model.
+     Supported model types:
+         - pretrained: pretrained model
+         - coco: finetuned model on coco
+     Usage:
+         >>> from lavis.models import load_model
+         >>> model = load_model("blip2_image_text_matching", "pretrained")
+         >>> model = load_model("blip2_image_text_matching", "coco")
+     """
+
+     def __init__(
+         self,
+         vit_model='eva_clip_g',
+         img_size=224,
+         drop_path_rate=0,
+         use_grad_checkpoint=False,
+         vit_precision='fp16',
+         freeze_vit=True,
+         num_query_token=32,
+         cross_attention_freq=2,
+         embed_dim=256,
+         max_txt_len=32,
+     ):
+         super().__init__(
+             vit_model=vit_model,
+             img_size=img_size,
+             drop_path_rate=drop_path_rate,
+             use_grad_checkpoint=use_grad_checkpoint,
+             vit_precision=vit_precision,
+             freeze_vit=freeze_vit,
+             num_query_token=num_query_token,
+             cross_attention_freq=cross_attention_freq,
+             embed_dim=embed_dim,
+             max_txt_len=max_txt_len,
+         )
+
+     def forward(self, samples, match_head='itm'):
+         image = samples['image']
+         caption = samples['text_input']
+
+         with self.maybe_autocast():
+             image_embeds = self.ln_vision(self.visual_encoder(image))
+         image_embeds = image_embeds.float()
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+         text = self.tokenizer(
+             caption,
+             truncation=True,
+             max_length=self.max_txt_len,
+             return_tensors='pt',
+         ).to(image.device)
+
+         if match_head == 'itm':
+             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+             query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+             attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
+             output_itm = self.Qformer.bert(
+                 text.input_ids,
+                 query_embeds=query_tokens,
+                 attention_mask=attention_mask,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_atts,
+                 return_dict=True,
+             )
+             itm_embeddings = output_itm.last_hidden_state[:, :query_tokens.size(1), :]
+             itm_logit = self.itm_head(itm_embeddings)
+             itm_logit = itm_logit.mean(dim=1)
+
+             return itm_logit
+
+         elif match_head == 'itc':
+             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+             query_output = self.Qformer.bert(
+                 query_embeds=query_tokens,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_atts,
+                 return_dict=True,
+             )
+             image_feats = F.normalize(self.vision_proj(query_output.last_hidden_state), dim=-1)
+
+             text_output = self.Qformer.bert(
+                 text.input_ids,
+                 attention_mask=text.attention_mask,
+                 return_dict=True,
+             )
+             text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+             sims = torch.bmm(image_feats, text_feat.unsqueeze(-1))
+             sim, _ = torch.max(sims, dim=1)
+
+             return sim
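
For orientation, a minimal usage sketch of the Blip2ITM forward pass added above. The samples dict and the match_head argument come directly from the code in this hunk; the score_image_text helper name, the no_grad wrapper, and the assumption that a loaded, preprocessed model and image batch are already available are illustrative only and not part of the diff.

    # Hypothetical usage sketch (not part of the diff): calling the vendored
    # Blip2ITM forward shown above on a preprocessed image batch.
    import torch

    def score_image_text(model, image_tensor, captions, head='itm'):
        """Return ITM logits (head='itm') or ITC similarities (head='itc')."""
        samples = {
            'image': image_tensor.to(model.device),  # preprocessed batch, shape (B, 3, H, W)
            'text_input': captions,                  # list of B caption strings
        }
        with torch.no_grad():
            # forward(samples, match_head=...) is the signature defined in the hunk above
            return model(samples, match_head=head)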