evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of evalscope has been flagged as potentially problematic.

Files changed (181)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
  24. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  25. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  26. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  27. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  28. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  29. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  30. evalscope/benchmarks/utils.py +7 -16
  31. evalscope/cli/start_app.py +1 -1
  32. evalscope/collections/evaluator.py +16 -4
  33. evalscope/config.py +7 -3
  34. evalscope/constants.py +11 -0
  35. evalscope/evaluator/evaluator.py +9 -3
  36. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  37. evalscope/metrics/__init__.py +49 -4
  38. evalscope/metrics/llm_judge.py +1 -1
  39. evalscope/metrics/named_metrics.py +13 -0
  40. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  41. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  42. evalscope/metrics/t2v_metrics/constants.py +12 -0
  43. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  44. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  53. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  60. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  61. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  62. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  139. evalscope/metrics/t2v_metrics/score.py +78 -0
  140. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  141. evalscope/models/__init__.py +50 -14
  142. evalscope/models/adapters/__init__.py +17 -0
  143. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  144. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  145. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  146. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  147. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  148. evalscope/models/adapters/t2i_adapter.py +76 -0
  149. evalscope/models/custom/__init__.py +2 -1
  150. evalscope/models/custom/dummy_model.py +11 -13
  151. evalscope/models/local_model.py +82 -33
  152. evalscope/models/model.py +2 -42
  153. evalscope/models/register.py +26 -0
  154. evalscope/perf/benchmark.py +4 -3
  155. evalscope/perf/main.py +4 -2
  156. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  157. evalscope/perf/utils/benchmark_util.py +2 -2
  158. evalscope/perf/utils/db_util.py +16 -8
  159. evalscope/report/__init__.py +1 -0
  160. evalscope/report/app.py +117 -67
  161. evalscope/report/app_arguments.py +11 -0
  162. evalscope/report/generator.py +1 -1
  163. evalscope/run.py +3 -3
  164. evalscope/third_party/thinkbench/eval.py +19 -7
  165. evalscope/utils/chat_service.py +2 -2
  166. evalscope/utils/import_utils.py +66 -0
  167. evalscope/utils/utils.py +12 -4
  168. evalscope/version.py +2 -2
  169. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
  170. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
  171. tests/aigc/__init__.py +1 -0
  172. tests/aigc/test_t2i.py +87 -0
  173. tests/cli/test_run.py +20 -7
  174. tests/perf/test_perf.py +6 -3
  175. evalscope/metrics/code_metric.py +0 -98
  176. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  177. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
  179. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
  180. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
  181. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
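
Entries 143–148 in the list above move the per-backend model adapters out of evalscope/models/ and into a new evalscope/models/adapters/ package (adding a t2i_adapter.py alongside them). A minimal compatibility sketch for downstream code that imported those modules by path, assuming only the module location changed between 0.14.0 and 0.15.1; whether the new adapters/__init__.py re-exports the old names is not verified here:

# Hypothetical compatibility shim; module paths are taken from the rename entries above.
try:
    # new layout in 0.15.x: evalscope/models/adapters/server_adapter.py
    from evalscope.models.adapters import server_adapter
except ImportError:
    # old layout in 0.14.0: evalscope/models/server_adapter.py
    from evalscope.models import server_adapter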
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
@@ -0,0 +1,755 @@
+ """
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+ import copy
+ import logging
+ import random
+ import string
+ import torch
+ import torch.nn as nn
+ from torch.cuda.amp import autocast as autocast
+ from transformers import T5TokenizerFast
+ from transformers.modeling_outputs import BaseModelOutput
+
+ from ...common.registry import registry
+ from .blip2 import Blip2Base, disabled_train
+ from .modeling_t5 import T5Config, T5ForConditionalGeneration
+
+
+ @registry.register_model('blip2_t5_instruct')
+ class Blip2T5Instruct(Blip2Base):
+     """
+     BLIP2 T5 model.
+     Supported model types:
+         - flant5xl
+         - flant5xxl
+     Usage:
+         >>> from lavis.models import load_model
+         >>> model = load_model("blip2_t5_instruct", "flant5xl")
+     """
+
+     PRETRAINED_MODEL_CONFIG_DICT = {
+         'flant5xl': 'configs/models/blip2/blip2_instruct_flant5xl.yaml',
+         'flant5xxl': 'configs/models/blip2/blip2_instruct_flant5xxl.yaml',
+     }
+
+     def __init__(
+         self,
+         vit_model='eva_clip_g',
+         img_size=224,
+         drop_path_rate=0,
+         use_grad_checkpoint=False,
+         vit_precision='fp16',
+         freeze_vit=True,
+         num_query_token=32,
+         t5_model='google/flan-t5-xl',
+         prompt='',
+         max_txt_len=128,
+         max_output_txt_len=256,
+         apply_lemmatizer=False,
+         num_few_shot_examples=0,
+         few_shot_prob=0,
+         qformer_text_input=True,
+     ):
+         """
+         apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.
+         """
+         super().__init__()
+
+         self.tokenizer = self.init_tokenizer(truncation_side='left')
+
+         self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size, drop_path_rate,
+                                                                        use_grad_checkpoint, vit_precision)
+         if freeze_vit:
+             for name, param in self.visual_encoder.named_parameters():
+                 param.requires_grad = False
+             self.visual_encoder = self.visual_encoder.eval()
+             self.visual_encoder.train = disabled_train
+             logging.info('freeze vision encoder')
+
+         self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, self.visual_encoder.num_features)
+
+         if not qformer_text_input:
+             self.Qformer.bert.embeddings.word_embeddings = None
+             self.Qformer.bert.embeddings.position_embeddings = None
+             for layer in self.Qformer.bert.encoder.layer:
+                 layer.output = None
+                 layer.intermediate = None
+         else:
+             self.Qformer.resize_token_embeddings(len(self.tokenizer))
+         self.Qformer.cls = None
+
+         self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='left')
+         self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='right')
+
+         t5_config = T5Config.from_pretrained(t5_model)
+         t5_config.dense_act_fn = 'gelu'
+         self.t5_model = T5ForConditionalGeneration.from_pretrained(t5_model, config=t5_config)
+
+         for name, param in self.t5_model.named_parameters():
+             param.requires_grad = False
+             param.data = param.data.bfloat16()
+
+         self.t5_proj = nn.Linear(self.Qformer.config.hidden_size, self.t5_model.config.hidden_size)
+
+         self.max_txt_len = max_txt_len
+         self.max_output_txt_len = max_output_txt_len
+         self.prompt = prompt
+
+         self._apply_lemmatizer = apply_lemmatizer
+         self._lemmatizer = None
+
+         self.num_few_shot_examples = num_few_shot_examples
+         self.few_shot_prob = few_shot_prob
+
+         self.qformer_text_input = qformer_text_input
+
+     def forward(self, samples):
+         # print('-----------------')
+         # print(samples["text_input"])
+         # print(samples["text_output"])
+         # print('-----------------')
+
+         image = samples['image']
+         with self.maybe_autocast():
+             image_embeds = self.ln_vision(self.visual_encoder(image))
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+         if self.qformer_text_input:
+             text_Qformer = self.tokenizer(
+                 samples['text_input'],
+                 padding='longest',
+                 truncation=True,
+                 max_length=self.max_txt_len,
+                 return_tensors='pt',
+             ).to(image.device)
+             query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+             Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+             query_output = self.Qformer.bert(
+                 text_Qformer.input_ids,
+                 attention_mask=Qformer_atts,
+                 query_embeds=query_tokens,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_atts,
+                 return_dict=True,
+             )
+         else:
+             query_output = self.Qformer.bert(
+                 query_embeds=query_tokens,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_atts,
+                 return_dict=True,
+             )
+
+         inputs_t5 = self.t5_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+         atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+         fs_embeds, fs_atts = None, None
+         if self.few_shot_prob > 0 and 'few_shot_samples' in samples.keys():
+             fs_embeds, fs_atts = self.prepare_few_shot_embeds(samples['few_shot_samples'])
+
+         with self.maybe_autocast(dtype=torch.bfloat16):
+             input_tokens = self.t5_tokenizer(
+                 samples['text_input'],
+                 padding='longest',
+                 truncation=True,
+                 max_length=self.max_txt_len,
+                 return_tensors='pt',
+             ).to(image.device)
+             output_tokens = self.t5_output_tokenizer(
+                 samples['text_output'],
+                 padding='longest',
+                 truncation=True,
+                 max_length=self.max_output_txt_len,
+                 return_tensors='pt',
+             ).to(image.device)
+
+             encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+             targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id,
+                                                           -100)
+
+             inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+             inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+             if fs_embeds is not None:
+                 inputs_embeds = torch.cat([fs_embeds, inputs_embeds], dim=1)
+                 encoder_atts = torch.cat([fs_atts, encoder_atts], dim=1)
+
+             outputs = self.t5_model(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=encoder_atts,
+                 decoder_attention_mask=output_tokens.attention_mask,
+                 return_dict=True,
+                 labels=targets,
+             )
+             loss = outputs.loss
+
+             return {'loss': loss}
+
+     def prepare_few_shot_embeds(self, samples):
+         this_n_fs = random.choices(
+             list(range(self.num_few_shot_examples + 1)),
+             weights=[1 - self.few_shot_prob]
+             + [self.few_shot_prob / self.num_few_shot_examples] * self.num_few_shot_examples)[0]
+
+         if this_n_fs == 0:
+             return None, None
+
+         images = []
+         text_input = []
+         for sample in samples:
+             for n in range(this_n_fs):
+                 images.append(sample['image'][n])
+                 text_input.append(sample['text_input'][n])
+         images = torch.stack(images, dim=0)
+
+         image = images
+
+         with self.maybe_autocast():
+             image_embeds = self.ln_vision(self.visual_encoder(image))
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+         if self.qformer_text_input:
+             text_Qformer = self.tokenizer(
+                 text_input,
+                 padding='longest',
+                 truncation=True,
+                 max_length=self.max_txt_len,
+                 return_tensors='pt',
+             ).to(image.device)
+             query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+             Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+             query_output = self.Qformer.bert(
+                 text_Qformer.input_ids,
+                 attention_mask=Qformer_atts,
+                 query_embeds=query_tokens,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_atts,
+                 return_dict=True,
+             )
+         else:
+             query_output = self.Qformer.bert(
+                 query_embeds=query_tokens,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_atts,
+                 return_dict=True,
+             )
+
+         inputs_t5 = self.t5_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+         atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+         with self.maybe_autocast(dtype=torch.bfloat16):
+             input_tokens = self.t5_tokenizer(
+                 text_input,
+                 padding='longest',
+                 truncation=True,
+                 max_length=self.max_txt_len,
+                 return_tensors='pt',
+             ).to(image.device)
+
+             encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+             inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+             inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+         if this_n_fs > 1:
+             encoder_atts = encoder_atts.reshape(encoder_atts.size(0) // this_n_fs, encoder_atts.size(1) * this_n_fs)
+             inputs_embeds = inputs_embeds.reshape(
+                 inputs_embeds.size(0) // this_n_fs,
+                 inputs_embeds.size(1) * this_n_fs, inputs_embeds.size(2))
+
+         return inputs_embeds, encoder_atts
+
+     @torch.no_grad()
+     def generate(
+         self,
+         samples,
+         use_nucleus_sampling=False,
+         num_beams=5,
+         max_length=256,
+         min_length=1,
+         top_p=0.9,
+         repetition_penalty=1.5,
+         length_penalty=1.0,
+         num_captions=1,
+         temperature=1,
+     ):
+         if 'prompt' in samples.keys():
+             prompt = samples['prompt']
+         else:
+             prompt = self.prompt
+
+         image = samples['image']
+
+         bs = image.size(0)
+
+         if isinstance(prompt, str):
+             prompt = [prompt] * bs
+         else:
+             assert len(prompt) == bs, 'The number of prompts must be equal to the batch size.'
+
+         # For TextCaps
+         if 'ocr_tokens' in samples.keys() and '{}' in prompt[0]:
+             prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
+
+         query_tokens = self.query_tokens.expand(bs, -1, -1)
+         if self.qformer_text_input:
+             # remove ocr tokens in q_former (for eval textvqa)
+             # qformer_prompt = prompt
+             # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
+
+             text_Qformer = self.tokenizer(
+                 prompt,
+                 padding='longest',
+                 truncation=True,
+                 max_length=self.max_txt_len,
+                 return_tensors='pt',
+             ).to(image.device)
+             query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+             Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+         # For video data
+         if image.dim() == 5:
+             inputs_t5, atts_t5 = [], []
+             for j in range(image.size(2)):
+                 this_frame = image[:, :, j, :, :]
+                 with self.maybe_autocast():
+                     frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
+                 frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+                 if self.qformer_text_input:
+                     frame_query_output = self.Qformer.bert(
+                         text_Qformer.input_ids,
+                         attention_mask=Qformer_atts,
+                         query_embeds=query_tokens,
+                         encoder_hidden_states=frame_embeds,
+                         encoder_attention_mask=frame_atts,
+                         return_dict=True,
+                     )
+                 else:
+                     frame_query_output = self.Qformer.bert(
+                         query_embeds=query_tokens,
+                         encoder_hidden_states=frame_embeds,
+                         encoder_attention_mask=frame_atts,
+                         return_dict=True,
+                     )
+
+                 frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:, :query_tokens.size(1), :])
+                 frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+                 inputs_t5.append(frame_inputs_t5)
+                 atts_t5.append(frame_atts_t5)
+             inputs_t5 = torch.cat(inputs_t5, dim=1)
+             atts_t5 = torch.cat(atts_t5, dim=1)
+         else:
+             with self.maybe_autocast():
+                 image_embeds = self.ln_vision(self.visual_encoder(image))
+             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+             if self.qformer_text_input:
+                 query_output = self.Qformer.bert(
+                     text_Qformer.input_ids,
+                     attention_mask=Qformer_atts,
+                     query_embeds=query_tokens,
+                     encoder_hidden_states=image_embeds,
+                     encoder_attention_mask=image_atts,
+                     return_dict=True,
+                 )
+             else:
+                 query_output = self.Qformer.bert(
+                     query_embeds=query_tokens,
+                     encoder_hidden_states=image_embeds,
+                     encoder_attention_mask=image_atts,
+                     return_dict=True,
+                 )
+
+             inputs_t5 = self.t5_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+             atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+         input_tokens = self.t5_tokenizer(prompt, padding='longest', return_tensors='pt').to(image.device)
+
+         encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+         with self.maybe_autocast(dtype=torch.bfloat16):
+             inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+             inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+             outputs = self.t5_model.generate(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=encoder_atts,
+                 do_sample=use_nucleus_sampling,
+                 top_p=top_p,
+                 temperature=temperature,
+                 num_beams=num_beams,
+                 max_new_tokens=max_length,
+                 min_length=min_length,
+                 repetition_penalty=repetition_penalty,
+                 length_penalty=length_penalty,
+                 num_return_sequences=num_captions,
+             )
+             output_text = self.t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+         return output_text
+
+     def predict_answers(self,
+                         samples,
+                         num_beams=5,
+                         inference_method='generate',
+                         max_len=10,
+                         min_len=1,
+                         num_ans_candidates=128,
+                         answer_list=None,
+                         prompt='',
+                         length_penalty=-1,
+                         **kwargs):
+         if isinstance(samples['text_input'], str):
+             samples['text_input'] = [samples['text_input']]
+
+         if prompt:
+             if prompt.count('{}') == 2:
+                 if 'ocr_tokens' in samples:
+                     text_input = [
+                         prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples['text_input'][i])
+                         for i in range(len(samples['text_input']))
+                     ]
+                 elif 'choices' in samples:
+                     text_input = []
+                     for i in range(len(samples['text_input'])):
+                         this_choices = [
+                             f'({string.ascii_lowercase[j]}) {ch}' for j, ch in enumerate(samples['choices'][i])
+                         ]
+                         this_choices = ' '.join(this_choices)
+                         text_input.append(prompt.format(samples['text_input'][i], this_choices))
+             else:
+                 text_input = [prompt.format(question) for question in samples['text_input']]
+         else:
+             text_input = samples['text_input']
+
+         samples['prompt'] = text_input
+
+         output_text = self.generate(
+             samples, num_beams=num_beams, max_length=max_len, min_length=min_len, length_penalty=length_penalty)
+
+         if self._apply_lemmatizer or ('apply_lemmatizer' in samples.keys() and samples['apply_lemmatizer']):
+             output_text = self._lemmatize(output_text)
+
+         return output_text
+
+     def predict_class(
+         self,
+         samples,
+         candidates,
+         n_segments=1,
+     ):
+         # If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
+         if type(candidates[0]) == list:
+             results = []
+
+             for i in range(samples['image'].size(0)):
+                 this_sample = {
+                     'image': samples['image'][i].unsqueeze(0),
+                     'prompt': samples['prompt'],
+                 }
+
+                 if 'text_input' in samples.keys():
+                     this_sample['text_input'] = [samples['text_input'][i]]
+
+                 if 'context' in samples.keys():
+                     this_sample['context'] = [samples['context'][i]]
+
+                 if 'history' in samples.keys():
+                     this_sample['history'] = [samples['history'][i]]
+
+                 if 'caption' in samples.keys():
+                     this_sample['caption'] = [samples['caption'][i]]
+
+                 this_result = self._predict_class(this_sample, candidates[i], n_segments)
+                 results.append(this_result)
+
+             try:
+                 results = torch.cat(results, dim=0)
+             except:
+                 results = [res.tolist()[0] for res in results]
+
+             return results
+
+         return self._predict_class(samples, candidates, n_segments)
+
+     def _predict_class(
+         self,
+         samples,
+         candidates,
+         n_segments=1,
+     ):
+         """
+         Args:
+             samples (dict): A dictionary containing the following keys:
+                 - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
+                 - prompt: the instruction
+             candidates:
+                 (list): A list of candidate class names;
+             n_segments:
+                 (int): Split the candidates into n_segments and predict one by one. This is useful when the number of candidates is too large.
+         Returns:
+             output_class: predicted class index
+         """
+
+         image = samples['image']
+         prompt = samples['prompt']
+
+         bs = image.size(0)
+
+         if isinstance(prompt, str):
+             prompt = [prompt] * bs
+         else:
+             assert len(prompt) == bs, 'The number of prompts must be equal to the batch size.'
+
+         if 'text_input' in samples.keys():
+             if type(samples['text_input'][0]) == list:
+                 prompt = [prompt[i].format(*samples['text_input'][i]) for i in range(len(prompt))]
+             else:
+                 prompt = [prompt[i].format(samples['text_input'][i]) for i in range(len(prompt))]
+
+         # scienceqa
+         if 'context' in samples.keys() and samples['context'] != '':
+             prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
+
+         # visual dialog
+         if 'history' in samples.keys() and samples['history'][0] != '':
+             prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))]
+
+         if 'caption' in samples.keys() and samples['caption'][0] != '':
+             prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))]
+
+         query_tokens = self.query_tokens.expand(bs, -1, -1)
+         if self.qformer_text_input:
+             text_Qformer = self.tokenizer(
+                 prompt, padding='longest', truncation=True, max_length=self.max_txt_len,
+                 return_tensors='pt').to(image.device)
+             query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+             Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+         if image.dim() == 5:
+             inputs_t5, atts_t5 = [], []
+             for j in range(image.size(2)):
+                 this_frame = image[:, :, j, :, :]
+                 with self.maybe_autocast():
+                     frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
+                 frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+                 if self.qformer_text_input:
+                     frame_query_output = self.Qformer.bert(
+                         text_Qformer.input_ids,
+                         attention_mask=Qformer_atts,
+                         query_embeds=query_tokens,
+                         encoder_hidden_states=frame_embeds,
+                         encoder_attention_mask=frame_atts,
+                         return_dict=True,
+                     )
+                 else:
+                     frame_query_output = self.Qformer.bert(
+                         query_embeds=query_tokens,
+                         encoder_hidden_states=frame_embeds,
+                         encoder_attention_mask=frame_atts,
+                         return_dict=True,
+                     )
+
+                 frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:, :query_tokens.size(1), :])
+                 frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+                 inputs_t5.append(frame_inputs_t5)
+                 atts_t5.append(frame_atts_t5)
+             inputs_t5 = torch.cat(inputs_t5, dim=1)
+             atts_t5 = torch.cat(atts_t5, dim=1)
+         else:
+             with self.maybe_autocast():
+                 image_embeds = self.ln_vision(self.visual_encoder(image))
+             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+             if self.qformer_text_input:
+                 query_output = self.Qformer.bert(
+                     text_Qformer.input_ids,
+                     attention_mask=Qformer_atts,
+                     query_embeds=query_tokens,
+                     encoder_hidden_states=image_embeds,
+                     encoder_attention_mask=image_atts,
+                     return_dict=True,
+                 )
+             else:
+                 query_output = self.Qformer.bert(
+                     query_embeds=query_tokens,
+                     encoder_hidden_states=image_embeds,
+                     encoder_attention_mask=image_atts,
+                     return_dict=True,
+                 )
+
+             inputs_t5 = self.t5_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+             atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+         input_tokens = self.t5_tokenizer(prompt, padding='longest', return_tensors='pt').to(image.device)
+         output_tokens = self.t5_tokenizer(candidates, padding='longest', return_tensors='pt').to(image.device)
+
+         encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+         n_cands = len(candidates)
+
+         with self.maybe_autocast(dtype=torch.bfloat16):
+             inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+             inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+             encoder_outputs = self.t5_model.encoder(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=encoder_atts,
+             )
+
+             all_losses = []
+             for n in range(n_segments):
+                 seg_len = n_cands // n_segments
+                 if n == (n_segments - 1):
+                     seg_len = n_cands - seg_len * (n_segments - 1)
+
+                 # this_encoder_outputs = copy.deepcopy(encoder_outputs)
+                 this_encoder_outputs = BaseModelOutput(last_hidden_state=encoder_outputs[0].clone(), )
+
+                 this_encoder_outputs['last_hidden_state'] = this_encoder_outputs[0].repeat_interleave(seg_len, dim=0)
+                 this_encoder_atts = encoder_atts.repeat_interleave(seg_len, dim=0)
+
+                 start_i = n * (n_cands // n_segments)
+                 end_i = start_i + seg_len
+                 this_output_tokens_ids = output_tokens.input_ids[start_i:end_i].repeat(bs, 1)
+                 this_output_tokens_atts = output_tokens.attention_mask[start_i:end_i].repeat(bs, 1)
+
+                 this_targets = this_output_tokens_ids.masked_fill(
+                     this_output_tokens_ids == self.t5_tokenizer.pad_token_id, -100)
+
+                 outputs = self.t5_model(
+                     encoder_outputs=this_encoder_outputs,
+                     attention_mask=this_encoder_atts,
+                     decoder_attention_mask=this_output_tokens_atts,
+                     return_dict=True,
+                     labels=this_targets,
+                     reduction='none',
+                 )
+                 loss = outputs.loss
+
+                 loss = loss.reshape(bs, seg_len)
+                 # output_class_ranks = torch.argsort(loss, dim=-1)
+                 all_losses.append(loss)
+
+             all_losses = torch.cat(all_losses, dim=-1)
+             output_class_ranks = torch.argsort(all_losses, dim=-1)
+
+             # encoder_outputs['last_hidden_state'] = encoder_outputs[0].repeat_interleave(n_cands, dim=0)
+             # encoder_atts = encoder_atts.repeat_interleave(n_cands, dim=0)
+             # output_tokens.input_ids = output_tokens.input_ids.repeat(bs, 1)
+             # output_tokens.attention_mask = output_tokens.attention_mask.repeat(bs, 1)
+
+             # # compute the LM loss for each candidate (sum logprob across all tokens) and select the highest
+             # targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100)
+
+             # outputs = self.t5_model(
+             #     encoder_outputs=encoder_outputs,
+             #     attention_mask=encoder_atts,
+             #     decoder_attention_mask=output_tokens.attention_mask,
+             #     return_dict=True,
+             #     labels=targets,
+             #     reduction="none",
+             # )
+             # loss = outputs.loss
+
+             # loss = loss.reshape(bs, n_cands)
+             # output_class_ranks = torch.argsort(loss, dim=-1) # (bs, num_candidates)
+
+         return output_class_ranks
+
+     def _lemmatize(self, answers):
+
+         def apply(answer):
+             doc = self.lemmatizer(answer)
+
+             words = []
+             for token in doc:
+                 if token.pos_ in ['NOUN', 'VERB']:
+                     words.append(token.lemma_)
+                 else:
+                     words.append(token.text)
+             answer = ' '.join(words)
+
+             return answer
+
+         return [apply(answer) for answer in answers]
+
+     @property
+     def lemmatizer(self):
+         if self._lemmatizer is None:
+             try:
+                 import spacy
+
+                 self._lemmatizer = spacy.load('en_core_web_sm')
+             except ImportError:
+                 logging.error("""
+                     Please install spacy and en_core_web_sm model to apply lemmatization.
+                     python -m spacy download en_core_web_sm
+                     OR
+                     import spacy.cli
+                     spacy.cli.download("en_core_web_sm")
+                     """)
+                 exit(1)
+
+         return self._lemmatizer
+
+     @classmethod
+     def from_config(cls, cfg):
+         vit_model = cfg.get('vit_model', 'eva_clip_g')
+         img_size = cfg.get('image_size')
+         num_query_token = cfg.get('num_query_token')
+         t5_model = cfg.get('t5_model')
+
+         drop_path_rate = cfg.get('drop_path_rate', 0)
+         use_grad_checkpoint = cfg.get('use_grad_checkpoint', False)
+         vit_precision = cfg.get('vit_precision', 'fp16')
+         freeze_vit = cfg.get('freeze_vit', True)
+
+         prompt = cfg.get('prompt', '')
+         max_txt_len = cfg.get('max_txt_len', 128)
+         max_output_txt_len = cfg.get('max_output_txt_len', 256)
+
+         apply_lemmatizer = cfg.get('apply_lemmatizer', False)
+
+         num_few_shot_examples = cfg.get('num_few_shot_examples', 0)
+         few_shot_prob = cfg.get('few_shot_prob', 0.0)
+
+         qformer_text_input = cfg.get('qformer_text_input', True)
+
+         model = cls(
+             vit_model=vit_model,
+             img_size=img_size,
+             drop_path_rate=drop_path_rate,
+             use_grad_checkpoint=use_grad_checkpoint,
+             vit_precision=vit_precision,
+             freeze_vit=freeze_vit,
+             num_query_token=num_query_token,
+             t5_model=t5_model,
+             prompt=prompt,
+             max_txt_len=max_txt_len,
+             max_output_txt_len=max_output_txt_len,
+             apply_lemmatizer=apply_lemmatizer,
+             num_few_shot_examples=num_few_shot_examples,
+             few_shot_prob=few_shot_prob,
+             qformer_text_input=qformer_text_input,
+         )
+
+         # if qformer_text_input:
+         #     # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
+         #     model.load_from_pretrained(
+         #         url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
+         #     )
+
+         model.load_checkpoint_from_config(cfg)
+
+         return model
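
The vendored class above keeps the LAVIS interface described in its docstring: it is registered as 'blip2_t5_instruct', and generate() consumes a samples dict with an 'image' batch and an optional 'prompt'. Below is a minimal usage sketch following that docstring, under loudly stated assumptions: that the vendored lavis models/__init__.py (item 106 in the file list) still exposes LAVIS's load_model helper, that pretrained weights are downloadable, and that the zero tensor merely illustrates the expected input shape rather than a meaningful image.

import torch

# Assumption: the vendored LAVIS package re-exports load_model, mirroring upstream LAVIS
# and the Usage example in the class docstring above.
from evalscope.metrics.t2v_metrics.models.vqascore_models.lavis.models import load_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model('blip2_t5_instruct', 'flant5xl')  # downloads EVA-CLIP and Flan-T5 weights
model = model.to(device).eval()

# generate() reads samples['image'] of shape (B, 3, H, W) (img_size defaults to 224) and
# samples['prompt'] as a single string or one prompt per batch element.
dummy_image = torch.zeros(1, 3, 224, 224, device=device)  # placeholder input, not a real image
print(model.generate({'image': dummy_image, 'prompt': 'Describe this image.'}, max_length=32))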