evalscope 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Note: this release of evalscope has been flagged as potentially problematic.

Files changed (214)
  1. evalscope/arguments.py +2 -1
  2. evalscope/backend/rag_eval/__init__.py +1 -1
  3. evalscope/backend/rag_eval/backend_manager.py +21 -5
  4. evalscope/backend/rag_eval/cmteb/arguments.py +10 -0
  5. evalscope/backend/rag_eval/ragas/arguments.py +0 -1
  6. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +7 -2
  7. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +0 -5
  8. evalscope/backend/rag_eval/utils/embedding.py +49 -3
  9. evalscope/backend/rag_eval/utils/llm.py +4 -4
  10. evalscope/backend/vlm_eval_kit/backend_manager.py +4 -2
  11. evalscope/benchmarks/__init__.py +2 -2
  12. evalscope/benchmarks/aigc/__init__.py +0 -0
  13. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  14. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  15. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  16. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  17. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  18. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  19. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  20. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  21. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  22. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  23. evalscope/benchmarks/arc/arc_adapter.py +2 -2
  24. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  25. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  26. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  27. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  28. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  29. evalscope/benchmarks/data_adapter.py +21 -10
  30. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  31. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  32. evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
  33. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +1 -1
  34. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  35. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +5 -4
  36. evalscope/benchmarks/live_code_bench/testing_util.py +369 -550
  37. evalscope/benchmarks/maritime_bench/__init__.py +0 -0
  38. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +79 -0
  39. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  40. evalscope/benchmarks/mmlu/mmlu_adapter.py +8 -8
  41. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +1 -1
  42. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +1 -1
  43. evalscope/benchmarks/musr/musr_adapter.py +1 -1
  44. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  45. evalscope/benchmarks/utils.py +7 -16
  46. evalscope/cli/start_app.py +1 -1
  47. evalscope/collections/evaluator.py +20 -6
  48. evalscope/config.py +8 -4
  49. evalscope/constants.py +11 -0
  50. evalscope/evaluator/evaluator.py +2 -2
  51. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  52. evalscope/metrics/__init__.py +49 -4
  53. evalscope/metrics/llm_judge.py +1 -1
  54. evalscope/metrics/named_metrics.py +13 -0
  55. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  56. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  57. evalscope/metrics/t2v_metrics/constants.py +12 -0
  58. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  59. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  60. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  61. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  62. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  63. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  64. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  65. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  66. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  67. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  68. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  69. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  70. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  71. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  72. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  73. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  74. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  75. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  76. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  77. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  139. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  140. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  141. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  142. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  143. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  144. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  145. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  146. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  147. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  148. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  149. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  150. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  151. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  152. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  153. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  154. evalscope/metrics/t2v_metrics/score.py +78 -0
  155. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  156. evalscope/models/__init__.py +50 -14
  157. evalscope/models/adapters/__init__.py +17 -0
  158. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  159. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  160. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  161. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  162. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  163. evalscope/models/adapters/t2i_adapter.py +76 -0
  164. evalscope/models/custom/__init__.py +2 -1
  165. evalscope/models/custom/dummy_model.py +11 -13
  166. evalscope/models/local_model.py +82 -33
  167. evalscope/models/model.py +2 -42
  168. evalscope/models/register.py +26 -0
  169. evalscope/perf/arguments.py +24 -5
  170. evalscope/perf/benchmark.py +28 -42
  171. evalscope/perf/http_client.py +2 -3
  172. evalscope/perf/plugin/api/custom_api.py +1 -1
  173. evalscope/perf/plugin/api/openai_api.py +2 -2
  174. evalscope/perf/plugin/datasets/custom.py +4 -1
  175. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  176. evalscope/perf/plugin/datasets/line_by_line.py +4 -1
  177. evalscope/perf/plugin/datasets/longalpaca.py +4 -1
  178. evalscope/perf/plugin/datasets/openqa.py +4 -1
  179. evalscope/perf/plugin/datasets/random_dataset.py +13 -6
  180. evalscope/perf/utils/benchmark_util.py +14 -8
  181. evalscope/perf/utils/db_util.py +9 -3
  182. evalscope/perf/utils/log_utils.py +41 -0
  183. evalscope/report/__init__.py +1 -0
  184. evalscope/report/app.py +128 -78
  185. evalscope/report/app_arguments.py +11 -0
  186. evalscope/report/generator.py +1 -1
  187. evalscope/run.py +10 -3
  188. evalscope/summarizer.py +2 -1
  189. evalscope/third_party/thinkbench/eval.py +19 -7
  190. evalscope/utils/chat_service.py +2 -2
  191. evalscope/utils/import_utils.py +66 -0
  192. evalscope/utils/utils.py +48 -29
  193. evalscope/version.py +2 -2
  194. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/METADATA +37 -15
  195. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/RECORD +209 -96
  196. tests/aigc/__init__.py +1 -0
  197. tests/aigc/test_t2i.py +87 -0
  198. tests/cli/test_all.py +4 -4
  199. tests/cli/test_collection.py +2 -1
  200. tests/cli/test_run.py +19 -12
  201. tests/perf/test_perf.py +3 -3
  202. tests/rag/test_clip_benchmark.py +0 -1
  203. tests/rag/test_mteb.py +37 -8
  204. tests/rag/test_ragas.py +29 -26
  205. tests/vlm/test_vlmeval.py +37 -1
  206. evalscope/backend/vlm_eval_kit/custom_dataset.py +0 -46
  207. evalscope/benchmarks/live_code_bench/execute_utils.py +0 -267
  208. evalscope/metrics/code_metric.py +0 -98
  209. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  210. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  211. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/LICENSE +0 -0
  212. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/WHEEL +0 -0
  213. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/entry_points.txt +0 -0
  214. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py
@@ -0,0 +1,202 @@
+"""
+Copyright (c) 2022, salesforce.com, inc.
+All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
+For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import torch
+import torch.nn.functional as F
+import warnings
+from torch import nn
+
+from ...common.registry import registry
+from ..med import XBertEncoder
+from ..vit import VisionTransformerEncoder
+from .blip import BlipBase
+from .blip_outputs import BlipOutputFeatures
+
+
+@registry.register_model('blip_feature_extractor')
+class BlipFeatureExtractor(BlipBase):
+    """
+    Class for BLIP feature extractor.
+
+    Supported model types:
+        - base: BLIP base model with pre-trained weights from capfilt by BLIP large model.
+
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip_feature_extractor", "base")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        'base': 'configs/models/blip_feature_extractor_base.yaml',
+        # "large": "configs/models/blip_feature_extractor_large.yaml",
+    }
+
+    def __init__(self, image_encoder, text_encoder, embed_dim, max_txt_len=40):
+        super().__init__()
+
+        self.tokenizer = self.init_tokenizer()
+
+        self.visual_encoder = image_encoder
+        self.text_encoder = text_encoder
+
+        # creating projection layers for ITC
+        text_width = text_encoder.config.hidden_size
+        vision_width = image_encoder.vision_width
+
+        self.vision_proj = nn.Linear(vision_width, embed_dim)
+        self.text_proj = nn.Linear(text_width, embed_dim)
+
+        self.max_txt_len = max_txt_len
+
+        self.temp = nn.Parameter(0.07 * torch.ones([]))
+
+    @torch.no_grad()
+    def extract_features(self, samples, mode='multimodal'):
+        """
+        Extract features for multimodal or unimodal samples.
+
+        Args:
+            samples (dict): A dictionary of samples, containing the following keys:
+                - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image.
+                  Raw images should be preprocessed before being passed to feature extractor.
+                - text_input (list): A list of strings containing the text, length B.
+            mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image".
+                If "multimodal", return image features and multimodal features;
+                if "text", return text features;
+                if "image", return image features.
+                Default: "multimodal".
+
+        Returns:
+            BlipOutputFeatures: A BlipOutputFeatures object containing the features.
+            See lavis/models/blip_models/blip_outputs.py for more details.
+
+        Examples:
+        ```python
+            >>> from PIL import Image
+            >>> from lavis.models import load_model_and_preprocess
+            >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
+            >>> caption = "a large fountain spewing water into the air"
+            >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_feature_extractor", is_eval=True)
+            >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
+            >>> text_input = txt_processors["eval"](caption)
+
+            >>> sample = {"image": image, "text_input": [text_input]}
+
+            >>> features_multimodal = model.extract_features(sample)
+            >>> features_multimodal.keys()
+            odict_keys(['image_embeds', 'multimodal_embeds'])
+            >>> features_multimodal.image_embeds.shape
+            torch.Size([1, 197, 768])
+            >>> features_multimodal.multimodal_embeds.shape
+            torch.Size([1, 12, 768])
+
+            >>> features_text = model.extract_features(sample, mode="text")
+            >>> features_text.keys()
+            odict_keys(['text_embeds', 'text_features'])
+            >>> features_text.text_embeds.shape
+            torch.Size([1, 12, 768])
+            >>> features_text.text_features.shape
+            torch.Size([1, 12, 256])
+
+            >>> features_image = model.extract_features(sample, mode="image")
+            >>> features_image.keys()
+            odict_keys(['image_embeds', 'image_features'])
+            >>> features_image.image_embeds.shape
+            torch.Size([1, 197, 768])
+            >>> features_image.image_features.shape
+            torch.Size([1, 197, 256])
+        ```
+        """
+        image = samples.get('image')
+        caption = samples.get('text_input')
+
+        # assert mode is one of "image", "text", "multimodal"
+        assert mode in [
+            'image',
+            'text',
+            'multimodal',
+        ], "mode must be one of 'image', 'text', 'multimodal'"
+
+        # initalize output
+        image_embeds, text_embeds, multimodal_embeds = None, None, None
+        image_features, text_features = None, None
+
+        if mode == 'image':
+            assert (image is not None), "Image is not provided for mode 'image' or 'multimodal'"
+            # return image features
+            image_embeds = self.visual_encoder.forward_features(image)
+
+            image_features = self.vision_proj(image_embeds)
+            image_features = F.normalize(image_features, dim=-1)
+
+        elif mode == 'text':
+            assert (caption is not None), "text input is None for mode 'text' or 'multimodal'"
+
+            text = self.tokenizer(caption, return_tensors='pt', padding=True).to(self.device)
+
+            # return text features
+            text_output = self.text_encoder(
+                text.input_ids,
+                attention_mask=text.attention_mask,
+                return_dict=True,
+                mode='text',
+            )
+            text_embeds = text_output.last_hidden_state
+
+            text_features = self.text_proj(text_embeds)
+            text_features = F.normalize(text_features, dim=-1)
+
+        elif mode == 'multimodal':
+            # return multimodel features
+            image_embeds = self.visual_encoder.forward_features(image)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+
+            text = self.tokenizer(caption, return_tensors='pt', padding=True).to(self.device)
+            text.input_ids[:, 0] = self.tokenizer.enc_token_id
+
+            output = self.text_encoder(
+                text.input_ids,
+                attention_mask=text.attention_mask,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+            multimodal_embeds = output.last_hidden_state
+
+        return BlipOutputFeatures(
+            image_embeds=image_embeds,
+            image_embeds_proj=image_features,
+            text_embeds=text_embeds,
+            text_embeds_proj=text_features,
+            multimodal_embeds=multimodal_embeds,
+        )
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        # set from_pretrained=True to load weights for 'bert-base-uncased'
+        image_encoder = VisionTransformerEncoder.from_config(cfg)
+        text_encoder = XBertEncoder.from_config(cfg)
+
+        embed_dim = cfg.get('embed_dim', 256)
+        max_txt_len = cfg.get('max_txt_len', 30)
+
+        model = cls(
+            image_encoder=image_encoder,
+            text_encoder=text_encoder,
+            embed_dim=embed_dim,
+            max_txt_len=max_txt_len,
+        )
+
+        # load pre-trained weights
+        pretrain_path = cfg.get('pretrained', None)
+        if pretrain_path is not None:
+            msg = model.load_from_pretrained(url_or_filename=pretrain_path)
+        else:
+            warnings.warn('No pretrained weights are loaded.')
+
+        return model
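
The diff itself does not show how evalscope drives this class, but the API added above is self-describing: `extract_features` returns a `BlipOutputFeatures` whose `*_embeds_proj` fields are already L2-normalized, with the [CLS] vector at index 0. Below is a minimal sketch of ITC-style image-text similarity built only from those methods; the `model` and `sample` objects are assumed to be prepared as in the docstring example and are not part of this release.

```python
# Sketch only: image-text similarity from BlipFeatureExtractor's unimodal features.
# `model` is a loaded BlipFeatureExtractor; `sample` = {"image": ..., "text_input": [...]}
# prepared as in the docstring example above (both are assumptions, not code from the diff).
def itc_similarity(model, sample):
    img_feats = model.extract_features(sample, mode="image")
    txt_feats = model.extract_features(sample, mode="text")
    # Compare the [CLS] positions of the normalized projections.
    img = img_feats.image_embeds_proj[:, 0, :]   # (B, embed_dim), already L2-normalized
    txt = txt_feats.text_embeds_proj[:, 0, :]    # (B, embed_dim), already L2-normalized
    return img @ txt.t()                         # (B, B) cosine-similarity matrix
```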
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py
@@ -0,0 +1,185 @@
+"""
+Copyright (c) 2022, salesforce.com, inc.
+All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
+For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...common.registry import registry
+from ..med import XBertEncoder
+from ..vit import VisionTransformerEncoder
+from .blip import BlipBase
+
+
+@registry.register_model('blip_image_text_matching')
+class BlipITM(BlipBase):
+    """
+    BLIP Image-Text Matching (ITM) model.
+
+    Supported model types:
+        - base: fine-tuned BLIP retrieval weights on COCO dataset (Karpathy split).
+        - large: fine-tuned BLIP retrieval weights on COCO dataset (Karpathy split).
+
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip_image_text_matching", "base")
+        >>> model = load_model("blip_image_text_matching", "large")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        'base': 'configs/models/blip_itm_base.yaml',
+        'large': 'configs/models/blip_itm_large.yaml',
+    }
+
+    def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=35):
+        super().__init__()
+
+        self.tokenizer = self.init_tokenizer()
+
+        self.text_encoder = text_encoder
+
+        self.visual_encoder = image_encoder
+
+        self.max_txt_len = max_txt_len
+
+        # creating projection layers for ITC
+        text_width = text_encoder.config.hidden_size
+        vision_width = image_encoder.vision_width
+
+        self.vision_proj = nn.Linear(vision_width, embed_dim)
+        self.text_proj = nn.Linear(text_width, embed_dim)
+
+        self.itm_head = nn.Linear(text_width, 2)
+
+    def forward(self, samples, match_head='itm'):
+        image = samples['image']
+        caption = samples['text_input']
+
+        image_embeds = self.visual_encoder.forward_features(image)
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        text = self.tokenizer(
+            caption,
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+            return_tensors='pt',
+        ).to(image.device)
+        if match_head == 'itm':
+            encoder_input_ids = text.input_ids.clone()
+            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id # extra code
+            output = self.text_encoder(
+                encoder_input_ids,
+                attention_mask=text.attention_mask,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
+            return itm_output
+
+        elif match_head == 'itc':
+            text_output = self.text_encoder(
+                text.input_ids,
+                attention_mask=text.attention_mask,
+                return_dict=True,
+                mode='text',
+            )
+            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
+            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+            sim = image_feat @ text_feat.t()
+            return sim
+
+    def itm_rank(self, image_embeds, image_atts, encoder_input_ids, match_head='itm'):
+        # breakpoint()
+        encoder_input_ids = encoder_input_ids.clone()
+        encoder_input_ids = encoder_input_ids[:, 3:]
+        text_attention_mask = (encoder_input_ids != self.tokenizer.pad_token_id).long()
+
+        if match_head == 'itm':
+            # encoder_input_ids = encoder_input_ids.clone()
+            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
+            output = self.text_encoder(
+                encoder_input_ids,
+                attention_mask=text_attention_mask,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+            # print(output.last_hidden_state.shape)
+            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
+            itm_output = F.softmax(itm_output, dim=1)[:, 1]
+            return itm_output #, mask, token_length
+
+        elif match_head == 'itc':
+            encoder_input_ids[:, 0] = self.tokenizer.cls_token_id
+            text_output = self.text_encoder(
+                encoder_input_ids, attention_mask=text_attention_mask, return_dict=True, mode='text')
+            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
+            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+            sim = image_feat @ text_feat.t()
+            return sim
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        image_encoder = VisionTransformerEncoder.from_config(cfg)
+        text_encoder = XBertEncoder.from_config(cfg)
+
+        embed_dim = cfg.get('embed_dim', 256)
+        max_txt_len = cfg.get('max_txt_len', 35)
+
+        model = cls(
+            image_encoder=image_encoder,
+            text_encoder=text_encoder,
+            embed_dim=embed_dim,
+            max_txt_len=max_txt_len,
+        )
+
+        model.load_checkpoint_from_config(cfg)
+
+        return model
+
+
+def compute_gradcam(model, visual_input, text_input, tokenized_text, block_num=6):
+    model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.save_attention = True
+
+    output = model({'image': visual_input, 'text_input': text_input}, match_head='itm')
+    loss = output[:, 1].sum()
+
+    model.zero_grad()
+    loss.backward()
+    with torch.no_grad():
+        mask = tokenized_text.attention_mask.view(tokenized_text.attention_mask.size(0), 1, -1, 1,
+                                                  1)  # (bsz,1,token_len, 1,1)
+        token_length = tokenized_text.attention_mask.sum(dim=-1) - 2
+        token_length = token_length.cpu()
+        # grads and cams [bsz, num_head, seq_len, image_patch]
+        grads = model.text_encoder.base_model.base_model.encoder.layer[
+            block_num].crossattention.self.get_attn_gradients()
+        cams = model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attention_map()

+
+        # assume using vit with 576 num image patch
+        cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
+        grads = (grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24) * mask)
+
+        gradcams = cams * grads
+        gradcam_list = []
+
+        for ind in range(visual_input.size(0)):
+            token_length_ = token_length[ind]
+            gradcam = gradcams[ind].mean(0).cpu().detach()
+            # [enc token gradcam, average gradcam across token, gradcam for individual token]
+            gradcam = torch.cat((
+                gradcam[0:1, :],
+                gradcam[1:token_length_ + 1, :].sum(dim=0, keepdim=True) / token_length_,
+                gradcam[1:, :],
+            ))
+            gradcam_list.append(gradcam)
+
+    return gradcam_list, output
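
The release ships no usage example for `BlipITM`; the hedged sketch below only exercises the `forward` signature visible in this hunk, and mirrors the softmax-over-logits step that `itm_rank` itself applies. Model and preprocessor loading are assumed, not taken from the diff.

```python
import torch.nn.functional as F

# Sketch only: scoring image-caption pairs with an already loaded BlipITM model.
def blip_itm_scores(model, image, captions):
    sample = {"image": image, "text_input": captions}
    itm_logits = model(sample, match_head="itm")    # (B, 2) logits from the binary ITM head
    itm_prob = F.softmax(itm_logits, dim=1)[:, 1]   # probability that each pair matches
    itc_sim = model(sample, match_head="itc")       # (B, B) image-text cosine similarities
    return itm_prob, itc_sim
```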
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py
@@ -0,0 +1,178 @@
+"""
+Copyright (c) 2022, salesforce.com, inc.
+All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
+For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import BertConfig
+
+from ...common.dist_utils import download_cached_file
+from ...common.registry import registry
+from ...common.utils import get_abs_path, is_url
+from ..base_model import MomentumDistilationMixin
+from ..vit import VisionTransformerEncoder, interpolate_pos_embed
+from .blip import BlipBase
+from .blip_outputs import BlipIntermediateOutput, BlipOutput
+from .nlvr_encoder import BertModel
+
+
+@registry.register_model('blip_nlvr')
+class BlipNLVR(BlipBase, MomentumDistilationMixin):
+    """
+    Class for BLIP NLVR model.
+
+    Supported model types:
+        - base: model with pre-trained BLIP weights, used as initialization for fine-tuning.
+        - nlvr: finetuned model on NLVR2 dataset.
+
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip_nlvr", "nlvr")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        'nlvr': 'configs/models/blip_nlvr.yaml',
+    }
+
+    def __init__(self, image_encoder, text_encoder, num_classes):
+        super().__init__()
+
+        self.tokenizer = self.init_tokenizer()
+        self.visual_encoder = image_encoder
+        self.text_encoder = text_encoder
+
+        hidden_size = text_encoder.config.hidden_size
+        self.cls_head = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size),
+            nn.ReLU(),
+            nn.Linear(hidden_size, num_classes),
+        )
+
+    def forward(self, samples, is_train=True):
+        """
+        Forward function for training and evaluation.
+
+        Args:
+            samples (dict): a dict of input samples, which contains the following keys:
+                - image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384.
+                - image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384.
+                - text_input (list): list of strings, each string is a natural language sentence.
+                - label (torch.LongTensor): ground truth label with shape (batch_size,).
+            is_train (bool): whether the model is in training mode.
+                If True, the model will return the loss;
+                If False, the model will return the prediction.
+
+        Examples:
+            >>> import torch
+            >>> from lavis.models import load_model
+            >>> model = load_model("blip_nlvr", "nlvr")
+            >>> samples = {
+            ...     "image0": torch.randn(2, 3, 384, 384),
+            ...     "image1": torch.randn(2, 3, 384, 384),
+            ...     "text_input": ["there is a ferret in tall grass", "there are lips in one of the images"],
+            ...     "label": torch.tensor([0, 1]),
+            ... }
+            >>> output = model(samples)
+            >>> output.keys()
+            odict_keys(['intermediate_output', 'loss'])
+        """
+        text = samples['text_input']
+        text = self.tokenizer(text, padding='longest', return_tensors='pt').to(self.device)
+        text.input_ids[:, 0] = self.tokenizer.enc_token_id
+
+        targets = samples['label']
+
+        image0 = samples['image0']
+        image1 = samples['image1']
+        images = torch.cat([image0, image1], dim=0)
+
+        image_embeds = self.visual_encoder.forward_features(images)
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))
+
+        encoder_output = self.text_encoder(
+            text.input_ids,
+            attention_mask=text.attention_mask,
+            encoder_hidden_states=[image0_embeds, image1_embeds],
+            encoder_attention_mask=[
+                image_atts[:image0_embeds.size(0)],
+                image_atts[image0_embeds.size(0):],
+            ],
+            return_dict=True,
+        )
+
+        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])
+
+        if is_train:
+            loss = F.cross_entropy(prediction, targets)
+            # return {"loss": loss}
+            return BlipOutput(
+                loss=loss,
+                intermediate_output=BlipIntermediateOutput(
+                    image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0),
+                    encoder_output=encoder_output,
+                ),
+            )
+        else:
+            return {'predictions': prediction, 'targets': targets}
+
+    def predict(self, samples):
+        output = self.forward(samples, is_train=False)
+        return output
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        image_encoder = VisionTransformerEncoder.from_config(cfg)
+
+        # text encoder + multimodal encoder
+        bert_config = BertConfig.from_json_file(get_abs_path(cfg['med_config_path']))
+        text_encoder = BertModel(config=bert_config, add_pooling_layer=False)
+
+        num_classes = cfg.get('num_classes', 3)
+
+        assert num_classes > 1, 'Invalid number of classes provided, found {}'.format(num_classes)
+
+        model = cls(
+            image_encoder=image_encoder,
+            text_encoder=text_encoder,
+            num_classes=num_classes,
+        )
+
+        model.load_checkpoint_from_config(cfg)
+
+        return model
+
+    def load_from_pretrained(self, url_or_filename):
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+            checkpoint = torch.load(cached_file, map_location='cpu')
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location='cpu')
+        else:
+            raise RuntimeError('checkpoint url or path is invalid')
+        state_dict = checkpoint['model']
+
+        state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
+                                                                       self.visual_encoder)
+
+        for key in list(state_dict.keys()):
+            if 'crossattention.self.' in key:
+                new_key0 = key.replace('self', 'self0')
+                new_key1 = key.replace('self', 'self1')
+                state_dict[new_key0] = state_dict[key]
+                state_dict[new_key1] = state_dict[key]
+            elif 'crossattention.output.dense.' in key:
+                new_key0 = key.replace('dense', 'dense0')
+                new_key1 = key.replace('dense', 'dense1')
+                state_dict[new_key0] = state_dict[key]
+                state_dict[new_key1] = state_dict[key]
+
+        msg = self.load_state_dict(state_dict, strict=False)
+        print('load checkpoint from %s' % url_or_filename)
+        print(f'missing keys {msg.missing_keys}')
+        return msg
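
For completeness, a small hedged sketch of evaluation-mode use of the `predict` method defined above. The layout of `samples` follows the `forward` docstring; how batches are actually built is not shown in this diff and is assumed here.

```python
# Sketch only: NLVR2-style accuracy from BlipNLVR.predict output.
def nlvr_accuracy(model, samples):
    out = model.predict(samples)              # {'predictions': logits, 'targets': labels}
    pred = out["predictions"].argmax(dim=1)   # predicted class per image pair
    return (pred == out["targets"]).float().mean().item()
```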
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py
@@ -0,0 +1,112 @@
+"""
+Copyright (c) 2022, salesforce.com, inc.
+All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
+For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import torch
+from dataclasses import dataclass
+from transformers.modeling_outputs import (BaseModelOutputWithPoolingAndCrossAttentions,
+                                           CausalLMOutputWithCrossAttentions, ModelOutput)
+from typing import Optional
+
+
+@dataclass
+class BlipSimilarity(ModelOutput):
+    sim_i2t: torch.FloatTensor = None
+    sim_t2i: torch.FloatTensor = None
+
+    sim_i2t_m: Optional[torch.FloatTensor] = None
+    sim_t2i_m: Optional[torch.FloatTensor] = None
+
+    sim_i2t_targets: Optional[torch.FloatTensor] = None
+    sim_t2i_targets: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class BlipIntermediateOutput(ModelOutput):
+    """
+    Data class for intermediate outputs of BLIP models.
+
+    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
+    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
+
+    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
+    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
+
+    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
+    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
+
+    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
+    decoder_labels (torch.LongTensor): labels for the captioning loss.
+
+    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
+    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
+
+    """
+
+    # uni-modal features
+    image_embeds: torch.FloatTensor = None
+    text_embeds: Optional[torch.FloatTensor] = None
+
+    image_embeds_m: Optional[torch.FloatTensor] = None
+    text_embeds_m: Optional[torch.FloatTensor] = None
+
+    # intermediate outputs of multimodal encoder
+    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
+    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
+
+    itm_logits: Optional[torch.FloatTensor] = None
+    itm_labels: Optional[torch.LongTensor] = None
+
+    # intermediate outputs of multimodal decoder
+    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
+    decoder_labels: Optional[torch.LongTensor] = None
+
+
+@dataclass
+class BlipOutput(ModelOutput):
+    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
+    sims: Optional[BlipSimilarity] = None
+
+    intermediate_output: BlipIntermediateOutput = None
+
+    loss: Optional[torch.FloatTensor] = None
+
+    loss_itc: Optional[torch.FloatTensor] = None
+
+    loss_itm: Optional[torch.FloatTensor] = None
+
+    loss_lm: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class BlipOutputWithLogits(BlipOutput):
+    logits: torch.FloatTensor = None
+    logits_m: torch.FloatTensor = None
+
+
+@dataclass
+class BlipOutputFeatures(ModelOutput):
+    """
+    Data class of features from BlipFeatureExtractor.
+
+    Args:
+        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
+        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
+        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
+        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
+
+    The first embedding or feature is for the [CLS] token.
+
+    Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
+    """
+
+    image_embeds: Optional[torch.FloatTensor] = None
+    image_embeds_proj: Optional[torch.FloatTensor] = None
+
+    text_embeds: Optional[torch.FloatTensor] = None
+    text_embeds_proj: Optional[torch.FloatTensor] = None
+
+    multimodal_embeds: Optional[torch.FloatTensor] = None
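
These dataclasses all derive from `transformers.ModelOutput`, which is what produces the `odict_keys(...)` behaviour in the docstring examples earlier in this diff: fields left as `None` are dropped from `.keys()`, while populated fields are readable both as attributes and by key. A minimal sketch follows; the import path is inferred from the file listing above and the tensor shapes are illustrative assumptions.

```python
import torch
from evalscope.metrics.t2v_metrics.models.vqascore_models.lavis.models.blip_models.blip_outputs import \
    BlipOutputFeatures

feats = BlipOutputFeatures(
    image_embeds=torch.randn(1, 197, 768),       # ViT patch embeddings, [CLS] first
    image_embeds_proj=torch.randn(1, 197, 256),  # normalized projection of the above
)
print(feats.keys())                   # odict_keys(['image_embeds', 'image_embeds_proj'])
print(feats["image_embeds"].shape)    # torch.Size([1, 197, 768])
print(feats.image_embeds_proj.shape)  # torch.Size([1, 197, 256])
```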