evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic. Click here for more details.
- evalscope/arguments.py +2 -1
- evalscope/benchmarks/__init__.py +2 -2
- evalscope/benchmarks/aigc/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/base.py +56 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
- evalscope/benchmarks/aime/aime24_adapter.py +1 -1
- evalscope/benchmarks/aime/aime25_adapter.py +4 -4
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
- evalscope/benchmarks/arc/arc_adapter.py +1 -1
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
- evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
- evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
- evalscope/benchmarks/data_adapter.py +16 -9
- evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
- evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
- evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
- evalscope/benchmarks/utils.py +7 -16
- evalscope/cli/start_app.py +1 -1
- evalscope/collections/evaluator.py +16 -4
- evalscope/config.py +7 -3
- evalscope/constants.py +11 -0
- evalscope/evaluator/evaluator.py +9 -3
- evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
- evalscope/metrics/__init__.py +49 -4
- evalscope/metrics/llm_judge.py +1 -1
- evalscope/metrics/named_metrics.py +13 -0
- evalscope/metrics/t2v_metrics/__init__.py +66 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +50 -14
- evalscope/models/adapters/__init__.py +17 -0
- evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
- evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
- evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
- evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
- evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
- evalscope/models/adapters/t2i_adapter.py +76 -0
- evalscope/models/custom/__init__.py +2 -1
- evalscope/models/custom/dummy_model.py +11 -13
- evalscope/models/local_model.py +82 -33
- evalscope/models/model.py +2 -42
- evalscope/models/register.py +26 -0
- evalscope/perf/benchmark.py +4 -3
- evalscope/perf/main.py +4 -2
- evalscope/perf/plugin/datasets/flickr8k.py +2 -1
- evalscope/perf/utils/benchmark_util.py +2 -2
- evalscope/perf/utils/db_util.py +16 -8
- evalscope/report/__init__.py +1 -0
- evalscope/report/app.py +117 -67
- evalscope/report/app_arguments.py +11 -0
- evalscope/report/generator.py +1 -1
- evalscope/run.py +3 -3
- evalscope/third_party/thinkbench/eval.py +19 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/import_utils.py +66 -0
- evalscope/utils/utils.py +12 -4
- evalscope/version.py +2 -2
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
- tests/aigc/__init__.py +1 -0
- tests/aigc/test_t2i.py +87 -0
- tests/cli/test_run.py +20 -7
- tests/perf/test_perf.py +6 -3
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
import warnings
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from ...common.registry import registry
|
|
14
|
+
from ..med import XBertEncoder
|
|
15
|
+
from ..vit import VisionTransformerEncoder
|
|
16
|
+
from .blip import BlipBase
|
|
17
|
+
from .blip_outputs import BlipOutputFeatures
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@registry.register_model('blip_feature_extractor')
class BlipFeatureExtractor(BlipBase):
    """
    Class for BLIP feature extractor.

    Supported model types:
        - base: BLIP base model with pre-trained weights from capfilt by BLIP large model.

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_feature_extractor", "base")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        'base': 'configs/models/blip_feature_extractor_base.yaml',
        # "large": "configs/models/blip_feature_extractor_large.yaml",
    }

    def __init__(self, image_encoder, text_encoder, embed_dim, max_txt_len=40):
        """
        Build the feature extractor from already-constructed encoders.

        Args:
            image_encoder: vision backbone; must expose ``vision_width`` and ``forward_features``.
            text_encoder: BERT-style text encoder; its ``config.hidden_size`` sizes the text projection.
            embed_dim (int): output size of the image and text projection layers.
            max_txt_len (int): stored on the instance. NOTE(review): ``extract_features`` does not
                use it; captions are tokenized without truncation there.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        # Projection layers for ITC (image-text contrastive) features.
        text_width = text_encoder.config.hidden_size
        vision_width = image_encoder.vision_width

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.max_txt_len = max_txt_len

        # Learnable temperature initialised to 0.07; not used by extract_features.
        self.temp = nn.Parameter(0.07 * torch.ones([]))

    @torch.no_grad()
    def extract_features(self, samples, mode='multimodal'):
        """
        Extract features for multimodal or unimodal samples.

        Args:
            samples (dict): A dictionary of samples, containing the following keys:
                - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image.
                    Raw images should be preprocessed before being passed to feature extractor.
                - text_input (list): A list of strings containing the text, length B.
            mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image".
                If "multimodal", return image features and multimodal features;
                if "text", return text features;
                if "image", return image features.
                Default: "multimodal".

        Returns:
            BlipOutputFeatures: A BlipOutputFeatures object containing the features.
                See lavis/models/blip_models/blip_outputs.py for more details.
                Projected, L2-normalised features are passed as ``image_embeds_proj`` and
                ``text_embeds_proj``. Fields not produced by the selected mode are ``None``.

        Raises:
            AssertionError: if ``mode`` is invalid, or if an input required by ``mode`` is missing.

        NOTE(review): the example output below is taken from upstream LAVIS docs. It lists keys such as
        'text_features'/'image_features', while this method passes ``text_embeds_proj``/``image_embeds_proj``.
        Confirm the key names against BlipOutputFeatures.

        Examples:
        ```python
            >>> from PIL import Image
            >>> from lavis.models import load_model_and_preprocess
            >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
            >>> caption = "a large fountain spewing water into the air"
            >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_feature_extractor", is_eval=True)
            >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
            >>> text_input = txt_processors["eval"](caption)

            >>> sample = {"image": image, "text_input": [text_input]}

            >>> features_multimodal = model.extract_features(sample)
            >>> features_multimodal.keys()
            odict_keys(['image_embeds', 'multimodal_embeds'])
            >>> features_multimodal.image_embeds.shape
            torch.Size([1, 197, 768])
            >>> features_multimodal.multimodal_embeds.shape
            torch.Size([1, 12, 768])

            >>> features_text = model.extract_features(sample, mode="text")
            >>> features_text.keys()
            odict_keys(['text_embeds', 'text_features'])
            >>> features_text.text_embeds.shape
            torch.Size([1, 12, 768])
            >>> features_text.text_features.shape
            torch.Size([1, 12, 256])

            >>> features_image = model.extract_features(sample, mode="image")
            >>> features_image.keys()
            odict_keys(['image_embeds', 'image_features'])
            >>> features_image.image_embeds.shape
            torch.Size([1, 197, 768])
            >>> features_image.image_features.shape
            torch.Size([1, 197, 256])
        ```
        """
        image = samples.get('image')
        caption = samples.get('text_input')

        # assert mode is one of "image", "text", "multimodal"
        assert mode in [
            'image',
            'text',
            'multimodal',
        ], "mode must be one of 'image', 'text', 'multimodal'"

        # initialize output
        image_embeds, text_embeds, multimodal_embeds = None, None, None
        image_features, text_features = None, None

        if mode == 'image':
            assert (image is not None), "Image is not provided for mode 'image' or 'multimodal'"
            # return image features
            image_embeds = self.visual_encoder.forward_features(image)

            image_features = self.vision_proj(image_embeds)
            image_features = F.normalize(image_features, dim=-1)

        elif mode == 'text':
            assert (caption is not None), "text input is None for mode 'text' or 'multimodal'"

            text = self.tokenizer(caption, return_tensors='pt', padding=True).to(self.device)

            # return text features
            text_output = self.text_encoder(
                text.input_ids,
                attention_mask=text.attention_mask,
                return_dict=True,
                mode='text',
            )
            text_embeds = text_output.last_hidden_state

            text_features = self.text_proj(text_embeds)
            text_features = F.normalize(text_features, dim=-1)

        elif mode == 'multimodal':
            # Multimodal needs both inputs. Fail early with the same messages as the unimodal branches,
            # instead of letting the encoder/tokenizer fail on None.
            assert (image is not None), "Image is not provided for mode 'image' or 'multimodal'"
            assert (caption is not None), "text input is None for mode 'text' or 'multimodal'"

            # return multimodal features
            image_embeds = self.visual_encoder.forward_features(image)
            # Attend to every image token: a mask of ones over all but the last (hidden) dimension.
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)

            text = self.tokenizer(caption, return_tensors='pt', padding=True).to(self.device)
            # Image-grounded encoding: overwrite the first token with the tokenizer's enc_token_id.
            text.input_ids[:, 0] = self.tokenizer.enc_token_id

            output = self.text_encoder(
                text.input_ids,
                attention_mask=text.attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            multimodal_embeds = output.last_hidden_state

        return BlipOutputFeatures(
            image_embeds=image_embeds,
            image_embeds_proj=image_features,
            text_embeds=text_embeds,
            text_embeds_proj=text_features,
            multimodal_embeds=multimodal_embeds,
        )

    @classmethod
    def from_config(cls, cfg=None):
        """
        Build a feature extractor from a model config.

        Reads ``embed_dim`` (default 256), ``max_txt_len`` (default 30) and the optional ``pretrained``
        checkpoint path/URL from ``cfg``. If no checkpoint is given, a warning is emitted and the model
        keeps its freshly initialised weights.
        """
        # NOTE(review): upstream comment says setting from_pretrained=True loads 'bert-base-uncased'
        # weights for the text encoder; confirm in XBertEncoder.from_config.
        image_encoder = VisionTransformerEncoder.from_config(cfg)
        text_encoder = XBertEncoder.from_config(cfg)

        embed_dim = cfg.get('embed_dim', 256)
        max_txt_len = cfg.get('max_txt_len', 30)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

        # load pre-trained weights
        pretrain_path = cfg.get('pretrained', None)
        if pretrain_path is not None:
            model.load_from_pretrained(url_or_filename=pretrain_path)
        else:
            warnings.warn('No pretrained weights are loaded.')

        return model
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from ...common.registry import registry
|
|
13
|
+
from ..med import XBertEncoder
|
|
14
|
+
from ..vit import VisionTransformerEncoder
|
|
15
|
+
from .blip import BlipBase
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@registry.register_model('blip_image_text_matching')
class BlipITM(BlipBase):
    """
    BLIP Image-Text Matching (ITM) model.

    Supported model types:
        - base: fine-tuned BLIP retrieval weights on COCO dataset (Karpathy split).
        - large: fine-tuned BLIP retrieval weights on COCO dataset (Karpathy split).

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_image_text_matching", "base")
        >>> model = load_model("blip_image_text_matching", "large")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        'base': 'configs/models/blip_itm_base.yaml',
        'large': 'configs/models/blip_itm_large.yaml',
    }

    def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=35):
        """
        Build the ITM model from already-constructed encoders.

        Args:
            image_encoder: vision backbone; must expose ``vision_width`` and ``forward_features``.
            text_encoder: BERT-style text encoder; its ``config.hidden_size`` sizes the ITC text
                projection and the 2-way ITM head.
            embed_dim (int): dimension of the shared ITC embedding space.
            max_txt_len (int): maximum token length used when ``forward`` tokenizes captions.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.text_encoder = text_encoder

        self.visual_encoder = image_encoder

        self.max_txt_len = max_txt_len

        # creating projection layers for ITC
        text_width = text_encoder.config.hidden_size
        vision_width = image_encoder.vision_width

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        # Binary ITM classifier applied to the first-token hidden state.
        self.itm_head = nn.Linear(text_width, 2)

    def forward(self, samples, match_head='itm'):
        """
        Score images against captions with either the ITM head or ITC similarity.

        Args:
            samples (dict): must contain 'image' (tensor fed to the visual encoder) and
                'text_input' (caption(s) accepted by the tokenizer).
            match_head (str): 'itm' or 'itc'.

        Returns:
            torch.Tensor:
                - 'itm': raw ITM logits of shape (B, 2). They are computed from the first token of the
                  image-conditioned text encoding.
                - 'itc': similarity matrix of shape (num_images, num_texts). It is the dot product of
                  the projected, L2-normalised first-token image and text features.

        Raises:
            ValueError: if ``match_head`` is neither 'itm' nor 'itc'.
        """
        if match_head not in ('itm', 'itc'):
            raise ValueError(f"match_head must be 'itm' or 'itc', got {match_head!r}")

        image = samples['image']
        caption = samples['text_input']

        image_embeds = self.visual_encoder.forward_features(image)
        # Attend to every image token.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        text = self.tokenizer(
            caption,
            padding='longest',
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors='pt',
        ).to(image.device)

        if match_head == 'itm':
            # Image-grounded encoding: copy the ids, then overwrite the first token with
            # the tokenizer's enc_token_id (leaves `text.input_ids` untouched).
            encoder_input_ids = text.input_ids.clone()
            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
            output = self.text_encoder(
                encoder_input_ids,
                attention_mask=text.attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            return itm_output

        # match_head == 'itc': unimodal text encoding, contrastive similarity.
        text_output = self.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
            mode='text',
        )
        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

        sim = image_feat @ text_feat.t()
        return sim

    def itm_rank(self, image_embeds, image_atts, encoder_input_ids, match_head='itm'):
        """
        Score pre-encoded images against pre-tokenized text.

        Unlike ``forward``, this does not run the visual encoder or the tokenizer. The caller supplies
        image embeddings, their attention mask and token ids. The caller's ``encoder_input_ids`` tensor
        is not modified.

        NOTE(review): the first three token positions of ``encoder_input_ids`` are dropped before
        scoring. Presumably the caller prepends a fixed three-token prefix; TODO confirm against the
        caller. The attention mask is derived from ``pad_token_id`` after that slice.

        Args:
            image_embeds (torch.Tensor): visual-encoder output; token 0 is used for ITC.
            image_atts (torch.Tensor): attention mask over ``image_embeds`` (used for 'itm').
            encoder_input_ids (torch.Tensor): token ids of shape (B, L).
            match_head (str): 'itm' or 'itc'.

        Returns:
            torch.Tensor:
                - 'itm': softmax probability of ITM class index 1 for each pair, shape (B,).
                - 'itc': similarity matrix between first-token image and text features.

        Raises:
            ValueError: if ``match_head`` is neither 'itm' nor 'itc'.
        """
        if match_head not in ('itm', 'itc'):
            raise ValueError(f"match_head must be 'itm' or 'itc', got {match_head!r}")

        # Copy first so the in-place first-token overwrite below never touches the caller's tensor.
        encoder_input_ids = encoder_input_ids.clone()[:, 3:]
        # The mask is computed before the first token is overwritten, from the sliced ids.
        text_attention_mask = (encoder_input_ids != self.tokenizer.pad_token_id).long()

        if match_head == 'itm':
            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
            output = self.text_encoder(
                encoder_input_ids,
                attention_mask=text_attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            itm_output = F.softmax(itm_output, dim=1)[:, 1]
            return itm_output

        # match_head == 'itc'
        encoder_input_ids[:, 0] = self.tokenizer.cls_token_id
        text_output = self.text_encoder(
            encoder_input_ids, attention_mask=text_attention_mask, return_dict=True, mode='text')
        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

        sim = image_feat @ text_feat.t()
        return sim

    @classmethod
    def from_config(cls, cfg=None):
        """
        Build an ITM model from a model config.

        Reads ``embed_dim`` (default 256) and ``max_txt_len`` (default 35) from ``cfg``, then loads the
        checkpoint referenced by the config via ``load_checkpoint_from_config``.
        """
        image_encoder = VisionTransformerEncoder.from_config(cfg)
        text_encoder = XBertEncoder.from_config(cfg)

        embed_dim = cfg.get('embed_dim', 256)
        max_txt_len = cfg.get('max_txt_len', 35)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

        model.load_checkpoint_from_config(cfg)

        return model
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def compute_gradcam(model, visual_input, text_input, tokenized_text, block_num=6):
    """
    Compute Grad-CAM maps over image patches from one cross-attention layer of a BLIP ITM model.

    The model is run with ``match_head='itm'``. The sum of ITM logit column 1 is back-propagated.
    The cross-attention map of layer ``block_num`` is then weighted by its positively-clamped
    gradient and averaged over heads.

    Args:
        model: BLIP ITM model (e.g. ``BlipITM``). It is called as
            ``model({'image': ..., 'text_input': ...}, match_head='itm')``. Its text encoder must expose
            ``base_model.base_model.encoder.layer[block_num].crossattention.self`` with a
            ``save_attention`` flag, ``get_attention_map()`` and ``get_attn_gradients()``.
        visual_input (torch.Tensor): batch of preprocessed images; dim 0 is the batch size.
        text_input: caption(s) passed to the model as ``samples['text_input']``.
        tokenized_text: tokenizer output for the same captions. Only ``attention_mask`` (B, L) is used.
        block_num (int): index of the text-encoder layer whose cross-attention is visualised.

    Returns:
        tuple: ``(gradcam_list, output)``.
            - ``output`` is the model's ITM logits.
            - ``gradcam_list`` holds one CPU tensor per sample, of shape (L + 1, S, S):
                - row 0 is the map for token position 0;
                - row 1 is the mean over the caption tokens (attention-mask sum minus 2, i.e.
                  excluding the first and last positions; all zeros for an empty caption);
                - rows 2.. are the maps for token positions 1..L-1.
            The patch grid side S is derived from the attention map (S * S patches after dropping
            the first key position).

    Side effects:
        Sets ``save_attention = True`` on the chosen layer and does not reset it. Also calls
        ``model.zero_grad()`` followed by ``loss.backward()``, which leaves fresh gradients on the
        model's parameters.
    """
    attn_module = model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self
    attn_module.save_attention = True

    output = model({'image': visual_input, 'text_input': text_input}, match_head='itm')
    loss = output[:, 1].sum()

    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        batch_size = visual_input.size(0)
        attention_mask = tokenized_text.attention_mask
        mask = attention_mask.view(attention_mask.size(0), 1, -1, 1, 1)  # (bsz, 1, token_len, 1, 1)
        # Number of caption tokens, excluding the first and last positions.
        token_length = (attention_mask.sum(dim=-1) - 2).cpu()

        # grads and cams: [bsz, num_head, seq_len, image_patch]
        grads = attn_module.get_attn_gradients()
        cams = attn_module.get_attention_map()

        # Drop the first key position (presumably the ViT [CLS] token) and fold the remaining patches
        # into a square grid. Head count and grid side come from the tensors themselves, so the
        # original 12-head / 576-patch (24x24) setup is just one special case.
        # A non-square patch count still fails in reshape, as before.
        num_heads = cams.size(1)
        num_patches = cams.size(-1) - 1
        side = int(round(num_patches**0.5))
        cams = cams[:, :, :, 1:].reshape(batch_size, num_heads, -1, side, side) * mask
        grads = grads[:, :, :, 1:].clamp(0).reshape(batch_size, grads.size(1), -1, side, side) * mask

        gradcams = cams * grads
        gradcam_list = []

        for ind in range(batch_size):
            n_tokens = int(token_length[ind])
            gradcam = gradcams[ind].mean(0).cpu().detach()
            # Average over the caption tokens. Divide by at least 1 so an empty caption yields a zero
            # map instead of NaN (0 / 0).
            token_avg = gradcam[1:n_tokens + 1, :].sum(dim=0, keepdim=True) / max(n_tokens, 1)
            # [enc token gradcam, average gradcam across tokens, gradcam for individual tokens]
            gradcam = torch.cat((
                gradcam[0:1, :],
                token_avg,
                gradcam[1:, :],
            ))
            gradcam_list.append(gradcam)

    return gradcam_list, output
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from torch import nn
|
|
12
|
+
from transformers import BertConfig
|
|
13
|
+
|
|
14
|
+
from ...common.dist_utils import download_cached_file
|
|
15
|
+
from ...common.registry import registry
|
|
16
|
+
from ...common.utils import get_abs_path, is_url
|
|
17
|
+
from ..base_model import MomentumDistilationMixin
|
|
18
|
+
from ..vit import VisionTransformerEncoder, interpolate_pos_embed
|
|
19
|
+
from .blip import BlipBase
|
|
20
|
+
from .blip_outputs import BlipIntermediateOutput, BlipOutput
|
|
21
|
+
from .nlvr_encoder import BertModel
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@registry.register_model('blip_nlvr')
class BlipNLVR(BlipBase, MomentumDistilationMixin):
    """
    Class for BLIP NLVR model.

    Supported model types:
        - nlvr: model fine-tuned on the NLVR2 dataset.

    Only the ``nlvr`` type is registered in ``PRETRAINED_MODEL_CONFIG_DICT``.

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_nlvr", "nlvr")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        'nlvr': 'configs/models/blip_nlvr.yaml',
    }

    def __init__(self, image_encoder, text_encoder, num_classes):
        """
        Args:
            image_encoder: vision backbone; ``forward_features`` is called on it in ``forward``.
            text_encoder: image-grounded text encoder that accepts a list of two image
                embeddings as ``encoder_hidden_states``; its ``config.hidden_size`` sizes
                the classification head.
            num_classes (int): number of outputs of the classification head.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        hidden_size = text_encoder.config.hidden_size
        # Two-layer MLP applied to the first-token representation of the encoder output.
        self.cls_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, samples, is_train=True):
        """
        Forward function for training and evaluation.

        Args:
            samples (dict): a dict of input samples, which contains the following keys:
                - image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384.
                - image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384.
                - text_input (list): list of strings, each string is a natural language sentence.
                - label (torch.LongTensor): ground truth label with shape (batch_size,).
            is_train (bool): whether the model is in training mode.
                If True, the model returns a ``BlipOutput`` carrying the cross-entropy loss;
                if False, it returns ``{'predictions': logits, 'targets': labels}`` where
                ``logits`` are the raw outputs of the classification head.

        Examples:
            >>> import torch
            >>> from lavis.models import load_model
            >>> model = load_model("blip_nlvr", "nlvr")
            >>> samples = {
            ...     "image0": torch.randn(2, 3, 384, 384),
            ...     "image1": torch.randn(2, 3, 384, 384),
            ...     "text_input": ["there is a ferret in tall grass", "there are lips in one of the images"],
            ...     "label": torch.tensor([0, 1]),
            ... }
            >>> output = model(samples)
            >>> output.keys()
            odict_keys(['intermediate_output', 'loss'])
        """
        text = self.tokenizer(samples['text_input'], padding='longest', return_tensors='pt').to(self.device)
        # Overwrite the first token with the tokenizer's encoder token.
        text.input_ids[:, 0] = self.tokenizer.enc_token_id

        targets = samples['label']

        image0 = samples['image0']
        image1 = samples['image1']
        # Encode both images of every pair in a single visual-encoder pass.
        images = torch.cat([image0, image1], dim=0)

        image_embeds = self.visual_encoder.forward_features(images)
        # Allocate the all-ones attention mask directly on the target device.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=self.device)

        # Split at the concatenation point, i.e. the batch size of image0, rather than
        # relying on the label tensor having the same length.
        num_pairs = image0.size(0)
        image0_embeds, image1_embeds = torch.split(image_embeds, num_pairs)

        encoder_output = self.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=[image0_embeds, image1_embeds],
            encoder_attention_mask=[
                image_atts[:num_pairs],
                image_atts[num_pairs:],
            ],
            return_dict=True,
        )

        # Classify from the first-token hidden state.
        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])

        if not is_train:
            return {'predictions': prediction, 'targets': targets}

        loss = F.cross_entropy(prediction, targets)
        return BlipOutput(
            loss=loss,
            intermediate_output=BlipIntermediateOutput(
                # Shape (2, batch_size, ...): image0 embeddings first, then image1.
                image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0),
                encoder_output=encoder_output,
            ),
        )

    def predict(self, samples):
        """Run ``forward`` in evaluation mode and return its prediction dict."""
        return self.forward(samples, is_train=False)

    @classmethod
    def from_config(cls, cfg=None):
        """
        Build the model from a config and load its checkpoint.

        Reads ``cfg['med_config_path']`` (BERT config JSON for the text encoder) and the
        optional ``num_classes`` (default 3), then calls ``load_checkpoint_from_config``.

        Raises:
            AssertionError: if ``num_classes`` is not greater than 1.
        """
        image_encoder = VisionTransformerEncoder.from_config(cfg)

        # text encoder + multimodal encoder
        bert_config = BertConfig.from_json_file(get_abs_path(cfg['med_config_path']))
        text_encoder = BertModel(config=bert_config, add_pooling_layer=False)

        num_classes = cfg.get('num_classes', 3)

        assert num_classes > 1, 'Invalid number of classes provided, found {}'.format(num_classes)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            num_classes=num_classes,
        )

        model.load_checkpoint_from_config(cfg)

        return model

    def load_from_pretrained(self, url_or_filename):
        """
        Load a BLIP checkpoint and adapt it to the two-branch NLVR cross-attention.

        Args:
            url_or_filename (str): URL (downloaded and cached) or local checkpoint path.
                The checkpoint must contain its weights under the ``'model'`` key.

        Returns:
            The result of ``load_state_dict(..., strict=False)`` (missing/unexpected keys).

        Raises:
            RuntimeError: if ``url_or_filename`` is neither a URL nor an existing file.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        if is_url(url_or_filename):
            cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
            checkpoint = torch.load(cached_file, map_location='cpu')
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location='cpu')
        else:
            raise RuntimeError('checkpoint url or path is invalid')
        state_dict = checkpoint['model']

        state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
                                                                       self.visual_encoder)

        # The NLVR encoder has two cross-attention branches (suffixes 0 and 1, one per image).
        # Duplicate each single-branch cross-attention weight into both branches, rewriting
        # only the matched path segment so unrelated 'self'/'dense' substrings are untouched.
        for key in list(state_dict.keys()):
            if 'crossattention.self.' in key:
                new_key0 = key.replace('crossattention.self.', 'crossattention.self0.')
                new_key1 = key.replace('crossattention.self.', 'crossattention.self1.')
                state_dict[new_key0] = state_dict[key]
                state_dict[new_key1] = state_dict[key]
            elif 'crossattention.output.dense.' in key:
                new_key0 = key.replace('crossattention.output.dense.', 'crossattention.output.dense0.')
                new_key1 = key.replace('crossattention.output.dense.', 'crossattention.output.dense1.')
                state_dict[new_key0] = state_dict[key]
                state_dict[new_key1] = state_dict[key]

        # strict=False: the original single-branch keys remain and are simply ignored.
        msg = self.load_state_dict(state_dict, strict=False)
        print('load checkpoint from %s' % url_or_filename)
        print(f'missing keys {msg.missing_keys}')
        return msg
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from transformers.modeling_outputs import (BaseModelOutputWithPoolingAndCrossAttentions,
|
|
11
|
+
CausalLMOutputWithCrossAttentions, ModelOutput)
|
|
12
|
+
from typing import Optional
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class BlipSimilarity(ModelOutput):
    """
    Image-text similarity scores produced by BLIP models.

    sim_i2t / sim_t2i: image-to-text and text-to-image similarities.
    sim_i2t_m / sim_t2i_m: the same similarities from the momentum encoders, if computed.
    sim_i2t_targets / sim_t2i_targets: soft targets for the similarities, if computed.

    Every field defaults to None, so all are annotated Optional.
    """

    sim_i2t: Optional[torch.FloatTensor] = None
    sim_t2i: Optional[torch.FloatTensor] = None

    sim_i2t_m: Optional[torch.FloatTensor] = None
    sim_t2i_m: Optional[torch.FloatTensor] = None

    sim_i2t_targets: Optional[torch.FloatTensor] = None
    sim_t2i_targets: Optional[torch.FloatTensor] = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class BlipIntermediateOutput(ModelOutput):
    """
    Data class for intermediate outputs of BLIP models.

    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
        BlipNLVR instead stores its two image embeddings stacked along a new leading dim.
    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).

    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).

    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.

    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
    decoder_labels (torch.LongTensor): labels for the captioning loss.

    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,).

    All fields default to None; models fill in only what they compute.
    """

    # uni-modal features
    image_embeds: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None

    image_embeds_m: Optional[torch.FloatTensor] = None
    text_embeds_m: Optional[torch.FloatTensor] = None

    # intermediate outputs of multimodal encoder
    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None

    itm_logits: Optional[torch.FloatTensor] = None
    itm_labels: Optional[torch.LongTensor] = None

    # intermediate outputs of multimodal decoder
    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
    decoder_labels: Optional[torch.LongTensor] = None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class BlipOutput(ModelOutput):
    """
    Top-level output of BLIP models.

    sims: similarity scores; optional because some fine-tuned models do not compute them.
    intermediate_output: intermediate tensors / sub-model outputs.
    loss: total loss.
    loss_itc / loss_itm / loss_lm: individual loss terms, when the model computes them.
    """

    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
    sims: Optional[BlipSimilarity] = None

    intermediate_output: Optional[BlipIntermediateOutput] = None

    loss: Optional[torch.FloatTensor] = None

    loss_itc: Optional[torch.FloatTensor] = None

    loss_itm: Optional[torch.FloatTensor] = None

    loss_lm: Optional[torch.FloatTensor] = None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
class BlipOutputWithLogits(BlipOutput):
    """
    ``BlipOutput`` extended with classifier logits.

    logits: model logits; logits_m: logits from the momentum model, if computed.
    Both default to None, so both are annotated Optional.
    """

    logits: Optional[torch.FloatTensor] = None
    logits_m: Optional[torch.FloatTensor] = None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class BlipOutputFeatures(ModelOutput):
    """
    Data class of features from BlipFeatureExtractor.

    Args:
        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
        image_embeds_proj: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
        text_embeds_proj: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
        multimodal_embeds: (torch.FloatTensor) embeddings from the multimodal (image-grounded text)
            encoder, optional. Presumably of shape (batch_size, sequence_length+1, embed_dim);
            confirm against the producer.

    The first embedding or feature is for the [CLS] token.

    The ``*_proj`` features are obtained by projecting the corresponding embedding into a
    normalized low-dimensional space.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    image_embeds_proj: Optional[torch.FloatTensor] = None

    text_embeds: Optional[torch.FloatTensor] = None
    text_embeds_proj: Optional[torch.FloatTensor] = None

    multimodal_embeds: Optional[torch.FloatTensor] = None
|