evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff shows the changes between publicly available package versions as published to their respective public registries; it is provided for informational purposes only.

Potentially problematic release.

Files changed (181)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
  24. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  25. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  26. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  27. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  28. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  29. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  30. evalscope/benchmarks/utils.py +7 -16
  31. evalscope/cli/start_app.py +1 -1
  32. evalscope/collections/evaluator.py +16 -4
  33. evalscope/config.py +7 -3
  34. evalscope/constants.py +11 -0
  35. evalscope/evaluator/evaluator.py +9 -3
  36. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  37. evalscope/metrics/__init__.py +49 -4
  38. evalscope/metrics/llm_judge.py +1 -1
  39. evalscope/metrics/named_metrics.py +13 -0
  40. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  41. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  42. evalscope/metrics/t2v_metrics/constants.py +12 -0
  43. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  44. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  53. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  60. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  61. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  62. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  139. evalscope/metrics/t2v_metrics/score.py +78 -0
  140. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  141. evalscope/models/__init__.py +50 -14
  142. evalscope/models/adapters/__init__.py +17 -0
  143. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  144. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  145. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  146. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  147. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  148. evalscope/models/adapters/t2i_adapter.py +76 -0
  149. evalscope/models/custom/__init__.py +2 -1
  150. evalscope/models/custom/dummy_model.py +11 -13
  151. evalscope/models/local_model.py +82 -33
  152. evalscope/models/model.py +2 -42
  153. evalscope/models/register.py +26 -0
  154. evalscope/perf/benchmark.py +4 -3
  155. evalscope/perf/main.py +4 -2
  156. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  157. evalscope/perf/utils/benchmark_util.py +2 -2
  158. evalscope/perf/utils/db_util.py +16 -8
  159. evalscope/report/__init__.py +1 -0
  160. evalscope/report/app.py +117 -67
  161. evalscope/report/app_arguments.py +11 -0
  162. evalscope/report/generator.py +1 -1
  163. evalscope/run.py +3 -3
  164. evalscope/third_party/thinkbench/eval.py +19 -7
  165. evalscope/utils/chat_service.py +2 -2
  166. evalscope/utils/import_utils.py +66 -0
  167. evalscope/utils/utils.py +12 -4
  168. evalscope/version.py +2 -2
  169. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
  170. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
  171. tests/aigc/__init__.py +1 -0
  172. tests/aigc/test_t2i.py +87 -0
  173. tests/cli/test_run.py +20 -7
  174. tests/perf/test_perf.py +6 -3
  175. evalscope/metrics/code_metric.py +0 -98
  176. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  177. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
  179. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
  180. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
  181. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py
@@ -0,0 +1,81 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import logging
+ from torch import nn
+ from typing import List
+
+
+ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str):
+     uninitialized_encoder_weights: List[str] = []
+     if decoder.__class__ != encoder.__class__:
+         logging.info(
+             f'{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized.'
+         )
+
+     def tie_encoder_to_decoder_recursively(
+         decoder_pointer: nn.Module,
+         encoder_pointer: nn.Module,
+         module_name: str,
+         uninitialized_encoder_weights: List[str],
+         skip_key: str,
+         depth=0,
+     ):
+         assert isinstance(decoder_pointer, nn.Module) and isinstance(
+             encoder_pointer, nn.Module), f'{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module'
+         if hasattr(decoder_pointer, 'weight') and skip_key not in module_name:
+             assert hasattr(encoder_pointer, 'weight')
+             encoder_pointer.weight = decoder_pointer.weight
+             if hasattr(decoder_pointer, 'bias'):
+                 assert hasattr(encoder_pointer, 'bias')
+                 encoder_pointer.bias = decoder_pointer.bias
+             # print(module_name + " is tied")
+             return
+
+         encoder_modules = encoder_pointer._modules
+         decoder_modules = decoder_pointer._modules
+         if len(decoder_modules) > 0:
+             assert (len(encoder_modules) >
+                     0), f'Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}'
+
+             all_encoder_weights = set([module_name + '/' + sub_name for sub_name in encoder_modules.keys()])
+             encoder_layer_pos = 0
+             for name, module in decoder_modules.items():
+                 if name.isdigit():
+                     encoder_name = str(int(name) + encoder_layer_pos)
+                     decoder_name = name
+                     if not isinstance(
+                             decoder_modules[decoder_name],
+                             type(encoder_modules[encoder_name]),
+                     ) and len(encoder_modules) != len(decoder_modules):
+                         # this can happen if the name corresponds to the position in a list module list of layers
+                         # in this case the decoder has added a cross-attention that the encoder does not have
+                         # thus skip this step and subtract one layer pos from encoder
+                         encoder_layer_pos -= 1
+                         continue
+                 elif name not in encoder_modules:
+                     continue
+                 elif depth > 500:
+                     raise ValueError(
+                         'Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model.'
+                     )
+                 else:
+                     decoder_name = encoder_name = name
+                 tie_encoder_to_decoder_recursively(
+                     decoder_modules[decoder_name],
+                     encoder_modules[encoder_name],
+                     module_name + '/' + name,
+                     uninitialized_encoder_weights,
+                     skip_key,
+                     depth=depth + 1,
+                 )
+                 all_encoder_weights.remove(module_name + '/' + encoder_name)
+
+             uninitialized_encoder_weights += list(all_encoder_weights)
+
+     # tie weights recursively
+     tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
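
The helper above ties each encoder parameter to the matching decoder parameter by recursing over `_modules`. Below is a minimal sketch (not part of the diff) of how it behaves on two toy modules; the import path mirrors the new `blip_models` package listed in the files-changed table and is an assumption, as are the toy modules and `skip_key` value.

```python
from torch import nn

# Hypothetical import path for the helper shown above; adjust to the actual package layout.
from evalscope.metrics.t2v_metrics.models.vqascore_models.lavis.models.blip_models import (
    tie_encoder_decoder_weights)

encoder = nn.Sequential(nn.Linear(8, 8))
decoder = nn.Sequential(nn.Linear(8, 8))

# Recursively point each encoder parameter at the matching decoder parameter,
# skipping any submodule whose qualified name contains skip_key.
tie_encoder_decoder_weights(encoder, decoder, base_model_prefix='', skip_key='temp')

assert encoder[0].weight is decoder[0].weight  # parameters are shared, not copied
assert encoder[0].bias is decoder[0].bias
```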
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py
@@ -0,0 +1,56 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import logging
+ import os
+ import torch
+ from modelscope import AutoTokenizer
+
+ from ...common.dist_utils import download_cached_file
+ from ...common.utils import is_url
+ from ...models.base_model import BaseModel
+ from ...models.vit import interpolate_pos_embed
+
+
+ class BlipBase(BaseModel):
+
+     @classmethod
+     def init_tokenizer(cls):
+         tokenizer = AutoTokenizer.from_pretrained('AI-ModelScope/bert-base-uncased')
+         tokenizer.add_special_tokens({'bos_token': '[DEC]'})
+         tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
+         tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
+         return tokenizer
+
+     def load_from_pretrained(self, url_or_filename):
+         if is_url(url_or_filename):
+             cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+             checkpoint = torch.load(cached_file, map_location='cpu')
+         elif os.path.isfile(url_or_filename):
+             checkpoint = torch.load(url_or_filename, map_location='cpu')
+         else:
+             raise RuntimeError('checkpoint url or path is invalid')
+
+         state_dict = checkpoint['model']
+
+         state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
+                                                                        self.visual_encoder)
+         if 'visual_encoder_m.pos_embed' in self.state_dict().keys():
+             state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
+                                                                              self.visual_encoder_m)
+
+         for key in self.state_dict().keys():
+             if key in state_dict.keys():
+                 if state_dict[key].shape != self.state_dict()[key].shape:
+                     del state_dict[key]
+
+         msg = self.load_state_dict(state_dict, strict=False)
+
+         logging.info('Missing keys {}'.format(msg.missing_keys))
+         logging.info('load checkpoint from %s' % url_or_filename)
+
+         return msg
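
`BlipBase.load_from_pretrained` drops checkpoint entries whose shapes no longer match the current model before calling `load_state_dict(strict=False)`, so mismatched weights are reported as missing instead of raising. A standalone sketch of that filtering pattern on a plain `nn.Module` (the module and checkpoint below are made up for illustration):

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
checkpoint_state = {
    'weight': torch.randn(3, 4),   # mismatched shape: will be dropped
    'bias': torch.zeros(2),        # matching shape: will be loaded
}

# Drop any checkpoint tensor whose shape disagrees with the current model.
for key in list(checkpoint_state.keys()):
    if key in model.state_dict() and checkpoint_state[key].shape != model.state_dict()[key].shape:
        del checkpoint_state[key]

msg = model.load_state_dict(checkpoint_state, strict=False)
print(msg.missing_keys)  # ['weight']
```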
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py
@@ -0,0 +1,212 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import torch
+
+ from ...common.registry import registry
+ from ..med import XBertLMHeadDecoder
+ from ..vit import VisionTransformerEncoder
+ from .blip import BlipBase
+ from .blip_outputs import BlipIntermediateOutput, BlipOutput
+
+
+ @registry.register_model('blip_caption')
+ class BlipCaption(BlipBase):
+     """
+     BLIP captioning model.
+
+     Supported model types:
+         - base_coco: fine-tuned BLIP base model on COCO caption dataset (Karparthy split).
+         - large_coco: fine-tuned BLIP large model on COCO caption dataset (Karparthy split).
+
+     Usage:
+         >>> from lavis.models import load_model
+         >>> model = load_model("blip_caption", "base_coco")
+         >>> model = load_model("blip_caption", "large_coco")
+     """
+
+     PRETRAINED_MODEL_CONFIG_DICT = {
+         'base_coco': 'configs/models/blip_caption_base_coco.yaml',
+         'large_coco': 'configs/models/blip_caption_large_coco.yaml',
+     }
+
+     def __init__(self, image_encoder, text_decoder, prompt=None, max_txt_len=40):
+         super().__init__()
+
+         self.tokenizer = self.init_tokenizer()
+
+         self.visual_encoder = image_encoder
+         self.text_decoder = text_decoder
+
+         self.prompt = prompt
+         self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
+
+         self.max_txt_len = max_txt_len
+
+     def forward_encoder(self, samples):
+         image_embeds = self.visual_encoder.forward_features(samples['image'])
+         return image_embeds
+
+     def forward_decoder(self, samples, image_embeds):
+         # prepare inputs for forwarding decoder
+         raw_text = samples['text_input']
+         text = self.tokenizer(
+             raw_text,
+             padding='longest',
+             truncation=True,
+             max_length=self.max_txt_len,
+             return_tensors='pt',
+         ).to(self.device)
+         text.input_ids[:, 0] = self.tokenizer.bos_token_id
+
+         # prepare targets for forwarding decoder
+         decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
+         decoder_targets[:, :self.prompt_length] = -100
+
+         # forward decoder
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+         decoder_output = self.text_decoder(
+             input_ids=text.input_ids,
+             attention_mask=text.attention_mask,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             labels=decoder_targets,
+             return_dict=True,
+         )
+
+         return decoder_output, decoder_targets
+
+     def forward(self, samples):
+         r"""
+         Args:
+             samples (dict): A dictionary containing the following keys:
+                 - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
+                 - text_input (list): A list of strings of length batch_size.
+         Returns:
+             output (BlipOutput): A BlipOutput object containing the following
+                 attributes:
+                 - loss (torch.Tensor): A scalar tensor containing the total loss. For BlipCaption, this is the same as the LM loss.
+                 - loss_lm (torch.Tensor): A scalar tensor containing the LM loss.
+                 - intermediate_outputs (BlipIntermediateOutput): A BlipIntermediateOutput object containing intermediate outputs.
+                   see :class:`lavis.models.blip_models.blip_outputs.BlipOutput` for more details.
+
+         Example:
+         ```python
+         >>> from PIL import Image
+         >>> from lavis.models import load_model_and_preprocess
+         >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_caption")
+         >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
+         >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
+         >>> text_input = ["a large statue of a person spraying water from a fountain"]
+         >>> samples = {"image": image, "text_input": text_input}
+         >>> output = model(samples)
+         >>> output.keys()
+         odict_keys(['intermediate_output', 'loss', 'loss_lm'])
+         >>> output.intermediate_output.image_embeds.shape
+         torch.Size([1, 577, 768])
+         >>> output.intermediate_output.decoder_labels.shape
+         torch.Size([1, 13])
+         ```"""
+
+         image_embeds = self.forward_encoder(samples)
+         decoder_output, decoder_targets = self.forward_decoder(samples, image_embeds)
+
+         # return decoder_out
+         return BlipOutput(
+             loss=decoder_output.loss,
+             loss_lm=decoder_output.loss,
+             intermediate_output=BlipIntermediateOutput(
+                 image_embeds=image_embeds,
+                 decoder_output=decoder_output,
+                 decoder_labels=decoder_targets,
+             ),
+         )
+
+     def generate(
+         self,
+         samples,
+         use_nucleus_sampling=False,
+         num_beams=3,
+         max_length=30,
+         min_length=10,
+         top_p=0.9,
+         repetition_penalty=1.0,
+         num_captions=1,
+     ):
+         """
+         Args:
+             samples (dict): A dictionary containing the following keys:
+                 - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
+             use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.
+             num_beams (int): Number of beams for beam search. 1 means no beam search.
+             max_length (int): The maximum length of the sequence to be generated.
+             min_length (int): The minimum length of the sequence to be generated.
+             top_p (float): The cumulative probability for nucleus sampling.
+             repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
+             num_captions (int): Number of captions to be generated for each image.
+         Returns:
+             captions (list): A list of strings of length batch_size * num_captions.
+
+         Example:
+         ```python
+         >>> from PIL import Image
+         >>> from lavis.models import load_model_and_preprocess
+         >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_caption")
+         >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
+         >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
+         >>> samples = {"image": image}
+         >>> captions = model.generate(samples)
+         >>> captions
+         ['a large statue of a person spraying water from a fountain']
+         >>> captions = model.generate(samples, use_nucleus_sampling=True, num_captions=3)
+         >>> captions # example output, results may vary due to randomness
+         ['singapore showing the view of some building',
+          'the singapore harbor in twilight, as the weather is going down',
+          'the famous singapore fountain at sunset']
+         """
+         # prepare inputs for decoder generation.
+         encoder_out = self.forward_encoder(samples)
+         image_embeds = torch.repeat_interleave(encoder_out, num_captions, 0)
+
+         prompt = [self.prompt] * image_embeds.size(0)
+         prompt = self.tokenizer(prompt, return_tensors='pt').to(self.device)
+         prompt.input_ids[:, 0] = self.tokenizer.bos_token_id
+         prompt.input_ids = prompt.input_ids[:, :-1]
+
+         # get decoded text
+         decoder_out = self.text_decoder.generate_from_encoder(
+             tokenized_prompt=prompt,
+             visual_embeds=image_embeds,
+             sep_token_id=self.tokenizer.sep_token_id,
+             pad_token_id=self.tokenizer.pad_token_id,
+             use_nucleus_sampling=use_nucleus_sampling,
+             num_beams=num_beams,
+             max_length=max_length,
+             min_length=min_length,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+         )
+
+         outputs = self.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)
+         captions = [output[len(self.prompt):] for output in outputs]
+
+         return captions
+
+     @classmethod
+     def from_config(cls, cfg):
+         # vision encoder
+         image_encoder = VisionTransformerEncoder.from_config(cfg)
+         # text encoder + multimodal decoder
+         text_decoder = XBertLMHeadDecoder.from_config(cfg)
+
+         prompt = cfg.get('prompt', None)
+         max_txt_len = cfg.get('max_txt_len', 40)
+
+         model = cls(image_encoder, text_decoder, prompt=prompt, max_txt_len=max_txt_len)
+         model.load_checkpoint_from_config(cfg)
+
+         return model
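
`forward_decoder` above builds its language-modeling targets by masking padding positions and the prompt prefix to `-100`, so those tokens are ignored by the loss. A small illustration of that masking with hypothetical token ids (not part of the diff):

```python
import torch

pad_token_id = 0
prompt_length = 2          # e.g. len(tokenizer(prompt).input_ids) - 1 in the class above
input_ids = torch.tensor([[101, 7, 8, 9, 0, 0]])   # hypothetical tokenized caption with padding

# Ignore padding and the prompt prefix when computing the LM loss.
decoder_targets = input_ids.masked_fill(input_ids == pad_token_id, -100)
decoder_targets[:, :prompt_length] = -100

print(decoder_targets)  # tensor([[-100, -100, 8, 9, -100, -100]])
```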
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py
@@ -0,0 +1,164 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from copy import deepcopy
+ from torch import nn
+
+ from ...common.registry import registry
+ from ..base_model import MomentumDistilationMixin
+ from ..med import XBertEncoder
+ from ..vit import VisionTransformerEncoder
+ from .blip import BlipBase
+ from .blip_outputs import BlipIntermediateOutput, BlipOutputWithLogits
+
+
+ @registry.register_model('blip_classification')
+ class BlipClassification(BlipBase, MomentumDistilationMixin):
+     PRETRAINED_MODEL_CONFIG_DICT = {
+         'base': 'configs/models/blip_classification_base.yaml',
+     }
+
+     def __init__(
+         self,
+         image_encoder,
+         text_encoder,
+         num_classes,
+         momentum=0.995,
+         alpha=0.4,
+         max_txt_len=40,
+         use_distill=True,
+     ):
+         super().__init__()
+
+         self.tokenizer = self.init_tokenizer()
+
+         self.use_distill = use_distill
+
+         self.visual_encoder = image_encoder
+         self.text_encoder = text_encoder
+
+         hidden_size = text_encoder.config.hidden_size
+         self.cls_head = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size),
+             nn.ReLU(),
+             nn.Linear(hidden_size, num_classes),
+         )
+
+         if self.use_distill:
+             self.visual_encoder_m = deepcopy(self.visual_encoder)
+             self.text_encoder_m = deepcopy(self.text_encoder)
+             self.cls_head_m = deepcopy(self.cls_head)
+
+             self.momentum = momentum
+             self.alpha = alpha
+
+             self.model_pairs = [
+                 [self.visual_encoder, self.visual_encoder_m],
+                 [self.text_encoder, self.text_encoder_m],
+                 [self.cls_head, self.cls_head_m],
+             ]
+
+             self.copy_params()
+
+         self.max_txt_len = max_txt_len
+
+     def _rampup_factor(self, epoch, iters, num_iters_per_epoch):
+         return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch)
+
+     def forward(self, samples, is_train=True):
+         sentences = samples['text_input']
+         sentences = self.tokenizer(
+             sentences,
+             padding='longest',
+             truncation=True,
+             max_length=self.max_txt_len,
+             return_tensors='pt',
+         ).to(self.device)
+         samples.update({'tokenized_text': sentences})
+
+         targets = samples['label']
+
+         image_embeds = self.visual_encoder.forward_features(samples['image'])
+         encoder_output = self.text_encoder.forward_automask(samples['tokenized_text'], image_embeds)
+
+         prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])
+
+         if is_train:
+             if self.use_distill:
+                 with torch.no_grad():
+                     self._momentum_update()
+
+                     image_embeds_m = self.visual_encoder_m(samples['image'])
+                     encoder_output_m = self.text_encoder_m.forward_automask(samples['tokenized_text'], image_embeds_m)
+
+                     prediction_m = self.cls_head_m(encoder_output_m.last_hidden_state[:, 0, :])
+
+                 alpha = self.alpha * self._rampup_factor(
+                     epoch=samples['epoch'],
+                     iters=samples['iters'],
+                     num_iters_per_epoch=samples['num_iters_per_epoch'],
+                 )
+
+                 loss = (1 - alpha) * F.cross_entropy(prediction, targets) - alpha * torch.sum(
+                     F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),
+                     dim=1,
+                 ).mean()
+             else:
+                 loss = F.cross_entropy(prediction, targets)
+
+             # return {"loss": loss}
+             return BlipOutputWithLogits(
+                 loss=loss,
+                 intermediate_output=BlipIntermediateOutput(
+                     image_embeds=image_embeds,
+                     image_embeds_m=image_embeds_m,
+                     encoder_output=encoder_output,
+                     encoder_output_m=encoder_output_m,
+                 ),
+                 logits=prediction,
+                 logits_m=prediction_m,
+             )
+
+         else:
+             return {'predictions': prediction, 'targets': targets}
+
+     def predict(self, samples):
+         output = self.forward(samples, is_train=False)
+         return output
+
+     @classmethod
+     def from_config(cls, cfg=None):
+         image_encoder = VisionTransformerEncoder.from_config(cfg)
+
+         # text encoder + multimodal encoder
+         text_encoder = XBertEncoder.from_config(cfg)
+         use_distill = cfg.get('use_distill', True)
+         momentum = cfg.get('momentum', 0.995)
+         num_classes = cfg.get('num_classes', -1)
+         alpha = cfg.get('alpha', 0.4)
+         max_txt_len = cfg.get('max_txt_len', 40)
+
+         assert num_classes > 1, 'Invalid number of classes provided, found {}'.format(num_classes)
+
+         model = cls(
+             image_encoder=image_encoder,
+             text_encoder=text_encoder,
+             use_distill=use_distill,
+             alpha=alpha,
+             num_classes=num_classes,
+             momentum=momentum,
+             max_txt_len=max_txt_len,
+         )
+
+         # load pre-trained weights
+         pretrain_path = cfg.get('pretrained', None)
+         if pretrain_path is not None:
+             msg = model.load_from_pretrained(url_or_filename=pretrain_path)
+
+         return model
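
`BlipClassification.forward` blends the hard-label cross-entropy with a soft cross-entropy against the momentum model's predictions, weighted by `alpha`. A minimal sketch of that training loss on random logits (shapes and values below are illustrative only, not from the diff):

```python
import torch
import torch.nn.functional as F

alpha = 0.4
prediction = torch.randn(8, 3)        # student logits (batch of 8, 3 classes)
prediction_m = torch.randn(8, 3)      # momentum (teacher) logits
targets = torch.randint(0, 3, (8,))   # hard labels

# (1 - alpha) * hard cross-entropy  -  alpha * soft cross-entropy against the teacher's distribution
loss = (1 - alpha) * F.cross_entropy(prediction, targets) - alpha * torch.sum(
    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),
    dim=1,
).mean()
print(loss.item())
```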