evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff compares the publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
Note: this release of evalscope has been flagged as potentially problematic.
- evalscope/arguments.py +2 -1
- evalscope/benchmarks/__init__.py +2 -2
- evalscope/benchmarks/aigc/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/base.py +56 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
- evalscope/benchmarks/aime/aime24_adapter.py +1 -1
- evalscope/benchmarks/aime/aime25_adapter.py +4 -4
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
- evalscope/benchmarks/arc/arc_adapter.py +1 -1
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
- evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
- evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
- evalscope/benchmarks/data_adapter.py +16 -9
- evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
- evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
- evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
- evalscope/benchmarks/utils.py +7 -16
- evalscope/cli/start_app.py +1 -1
- evalscope/collections/evaluator.py +16 -4
- evalscope/config.py +7 -3
- evalscope/constants.py +11 -0
- evalscope/evaluator/evaluator.py +9 -3
- evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
- evalscope/metrics/__init__.py +49 -4
- evalscope/metrics/llm_judge.py +1 -1
- evalscope/metrics/named_metrics.py +13 -0
- evalscope/metrics/t2v_metrics/__init__.py +66 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +50 -14
- evalscope/models/adapters/__init__.py +17 -0
- evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
- evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
- evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
- evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
- evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
- evalscope/models/adapters/t2i_adapter.py +76 -0
- evalscope/models/custom/__init__.py +2 -1
- evalscope/models/custom/dummy_model.py +11 -13
- evalscope/models/local_model.py +82 -33
- evalscope/models/model.py +2 -42
- evalscope/models/register.py +26 -0
- evalscope/perf/benchmark.py +4 -3
- evalscope/perf/main.py +4 -2
- evalscope/perf/plugin/datasets/flickr8k.py +2 -1
- evalscope/perf/utils/benchmark_util.py +2 -2
- evalscope/perf/utils/db_util.py +16 -8
- evalscope/report/__init__.py +1 -0
- evalscope/report/app.py +117 -67
- evalscope/report/app_arguments.py +11 -0
- evalscope/report/generator.py +1 -1
- evalscope/run.py +3 -3
- evalscope/third_party/thinkbench/eval.py +19 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/import_utils.py +66 -0
- evalscope/utils/utils.py +12 -4
- evalscope/version.py +2 -2
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
- tests/aigc/__init__.py +1 -0
- tests/aigc/test_t2i.py +87 -0
- tests/cli/test_run.py +20 -7
- tests/perf/test_perf.py +6 -3
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py
@@ -0,0 +1,85 @@
+import torch
+from modelscope import AutoProcessor
+from transformers import CLIPConfig
+from typing import List
+
+from ...constants import CACHE_DIR
+from ..model import ScoreModel
+
+MPS_MODELS = ['mps']
+
+
+class MPSModel(ScoreModel):
+    'A wrapper for MPS Score models'
+
+    def __init__(self, model_name='mps', device='cuda', cache_dir=CACHE_DIR):
+        assert model_name in MPS_MODELS
+        super().__init__(model_name=model_name, device=device, cache_dir=cache_dir)
+
+    def load_model(self):
+        """Load the model, tokenizer, image transform
+        """
+        from ..utils import download_file
+        from .build_mps_model.clip_model import CLIPModel
+
+        assert self.model_name == 'mps'
+
+        processor_name_or_path = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
+        self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
+
+        config = download_file('AI-ModelScope/MPS', file_name='config.json', cache_dir=self.cache_dir)
+        model_pretrained_path = download_file(
+            'AI-ModelScope/MPS', file_name='MPS_overall_state_dict.pt', cache_dir=self.cache_dir)  # modelscope model
+        model_weight = torch.load(model_pretrained_path, weights_only=True, map_location='cpu')
+
+        self.model = CLIPModel(config=CLIPConfig.from_json_file(config))
+        self.model.load_state_dict(model_weight, strict=False)
+        self.model.eval().to(self.device)
+
+    def load_images(self, image: List[str]) -> torch.Tensor:
+        """Load the image(s), and return a tensor (no preprocessing!!) put on self.device
+        """
+        image = [self.image_loader(x) for x in image]
+        image = self.processor(images=image, return_tensors='pt')['pixel_values']
+        return image
+
+    def process_text(self, text: List[str]) -> dict:
+        """Process the text(s), and return a tensor (after preprocessing) put on self.device
+        """
+        text_inputs = self.processor(
+            text=text,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt',
+        ).input_ids
+        return text_inputs
+
+    @torch.no_grad()
+    def forward(self, images: List[str], texts: List[str], condition=None) -> torch.Tensor:
+        """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
+        """
+        assert len(images) == len(texts)
+        image_input = self.load_images(images).to(self.device)
+        text_input = self.process_text(texts).to(self.device)
+        if condition is None:
+            condition = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things.'
+        condition_batch = self.process_text(condition).repeat(text_input.shape[0], 1).to(self.device)
+
+        # embed
+        text_f, text_features = self.model.model.get_text_features(text_input)
+
+        image_f = self.model.model.get_image_features(image_input.half())
+        condition_f, _ = self.model.model.get_text_features(condition_batch)
+
+        sim_text_condition = torch.einsum('b i d, b j d -> b j i', text_f, condition_f)
+        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
+        sim_text_condition = sim_text_condition / sim_text_condition.max()
+        mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
+        mask = mask.repeat(1, image_f.shape[1], 1)
+        image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
+
+        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+        image_score = self.model.logit_scale.exp() * text_features @ image_features.T
+
+        return image_score[0]
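MPSModel scores how well an image matches its prompt under the MPS preference model, with an optional condition prompt that defaults to a long list of visual attributes. A minimal usage sketch follows; it assumes a CUDA device, ModelScope access for the weight download, and that the base ScoreModel constructor triggers load_model — the image path and prompt are placeholders.

from evalscope.metrics.t2v_metrics.models.clipscore_models.mps_model import MPSModel

# Hypothetical inputs: one generated image and the prompt it was generated from.
model = MPSModel(model_name='mps', device='cuda')
score = model.forward(images=['generated.png'], texts=['a cat wearing a red hat'])
print(float(score))  # higher MPS score means closer agreement between image and prompt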
evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py
@@ -0,0 +1,62 @@
+import torch
+from abc import abstractmethod
+from modelscope import AutoModel, AutoProcessor
+from typing import List
+
+from ...constants import CACHE_DIR
+from ..model import ScoreModel
+
+PICKSCORE_MODELS = ['pickscore-v1']
+
+
+class PickScoreModel(ScoreModel):
+    'A wrapper for PickScore models'
+
+    def __init__(self, model_name='pickscore-v1', device='cuda', cache_dir=CACHE_DIR):
+        assert model_name in PICKSCORE_MODELS
+        super().__init__(model_name=model_name, device=device, cache_dir=cache_dir)
+
+    def load_model(self):
+        """Load the model, tokenizer, image transform
+        """
+        assert self.model_name == 'pickscore-v1'
+        processor_name_or_path = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
+        # model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"
+        model_pretrained_name_or_path = 'AI-ModelScope/PickScore_v1'  # modelscope model
+
+        self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
+        self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
+
+    def load_images(self, image: List[str]) -> torch.Tensor:
+        """Load the image(s), and return a tensor (no preprocessing!!) put on self.device
+        """
+        image = [self.image_loader(x) for x in image]
+        image = self.processor(
+            images=image, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device)
+        # image = torch.stack(image, dim=0).to(self.device)
+        return image
+
+    @torch.no_grad()
+    def forward(self, images: List[str], texts: List[str]) -> torch.Tensor:
+        """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
+        """
+        assert len(images) == len(texts)
+        image = self.load_images(images)
+        text_inputs = self.processor(
+            text=texts,
+            padding=True,
+            truncation=True,
+            max_length=77,
+            return_tensors='pt',
+        ).to(self.device)
+
+        # embed
+        image_embs = self.model.get_image_features(**image)
+        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
+
+        text_embs = self.model.get_text_features(**text_inputs)
+        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
+
+        # score
+        scores = (image_embs * text_embs).sum(dim=-1)
+        return scores
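Because PickScoreModel.forward returns one cosine-similarity score per (image, text) pair, ranking several candidate images for a single prompt just means repeating the prompt. A hedged sketch, with placeholder file names and the same assumption that the base ScoreModel loads the weights during construction:

from evalscope.metrics.t2v_metrics.models.clipscore_models.pickscore_model import PickScoreModel

model = PickScoreModel(device='cuda')
candidates = ['sample_0.png', 'sample_1.png', 'sample_2.png']    # placeholder image paths
prompt = 'a watercolor painting of a lighthouse at dusk'
scores = model.forward(images=candidates, texts=[prompt] * len(candidates))
best = candidates[int(scores.argmax())]                          # candidate preferred by PickScore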
evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py
@@ -0,0 +1,26 @@
+from ...constants import CACHE_DIR
+from .blip2_itm_model import BLIP2_ITM_MODELS, BLIP2ITMScoreModel
+from .fga_blip2_model import FGA_BLIP2_MODELS, FGA_BLIP2ScoreModel
+from .image_reward_model import IMAGE_REWARD_MODELS, ImageRewardScoreModel
+
+ALL_ITM_MODELS = [
+    BLIP2_ITM_MODELS,
+    IMAGE_REWARD_MODELS,
+    FGA_BLIP2_MODELS,
+]
+
+
+def list_all_itmscore_models():
+    return [model for models in ALL_ITM_MODELS for model in models]
+
+
+def get_itmscore_model(model_name, device='cuda', cache_dir=CACHE_DIR):
+    assert model_name in list_all_itmscore_models()
+    if model_name in BLIP2_ITM_MODELS:
+        return BLIP2ITMScoreModel(model_name, device=device, cache_dir=cache_dir)
+    elif model_name in IMAGE_REWARD_MODELS:
+        return ImageRewardScoreModel(model_name, device=device, cache_dir=cache_dir)
+    elif model_name in FGA_BLIP2_MODELS:
+        return FGA_BLIP2ScoreModel(model_name, device=device, cache_dir=cache_dir)
+    else:
+        raise NotImplementedError()
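This factory is the entry point for the ITM-style scorers: list_all_itmscore_models flattens the per-family registries, and get_itmscore_model dispatches on which registry the name belongs to. A short usage sketch (the image path and prompt are placeholders, and a GPU plus ModelScope downloads are assumed):

from evalscope.metrics.t2v_metrics.models.itmscore_models import get_itmscore_model, list_all_itmscore_models

print(list_all_itmscore_models())   # e.g. ['blip2-itm', 'blip2-itm-vitL', 'blip2-itm-coco', 'image-reward-v1', 'fga_blip2']
model = get_itmscore_model('image-reward-v1', device='cuda')
scores = model.forward(images=['generated.png'], texts=['a red bicycle leaning against a brick wall'])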
evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py
@@ -0,0 +1,84 @@
+import os
+import torch
+from torchvision import transforms
+from typing import List
+
+from ...constants import CACHE_DIR
+from ..model import ScoreModel
+from ..vqascore_models.lavis.models import load_model
+
+BLIP2_ITM_MODELS = {
+    'blip2-itm': {
+        'variant': 'pretrain'
+    },
+    'blip2-itm-vitL': {
+        'variant': 'pretrain_vitL'
+    },
+    'blip2-itm-coco': {
+        'variant': 'coco'
+    },
+}
+
+
+class BLIP2ITMScoreModel(ScoreModel):
+    'A wrapper for BLIP-2 ITMScore models'
+
+    def __init__(self, model_name='blip2-itm', device='cuda', cache_dir=CACHE_DIR):
+        assert model_name in BLIP2_ITM_MODELS, f'Model name must be one of {BLIP2_ITM_MODELS.keys()}'
+        os.environ['TORCH_HOME'] = cache_dir
+        super().__init__(model_name=model_name, device=device, cache_dir=cache_dir)
+
+    def load_model(self):
+        """Load the model, tokenizer, image transform
+        """
+        self.variant = BLIP2_ITM_MODELS[self.model_name]['variant']
+        self.model = load_model('blip2', self.variant, is_eval=True, device=self.device)
+        if self.variant == 'coco':
+            size = 364
+        else:
+            size = 224
+        self.image_preprocess = transforms.Compose([
+            transforms.Resize((size, size), interpolation=transforms.functional.InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        ])
+
+    def load_images(self, image: List[str]) -> torch.Tensor:
+        """Load the image(s), and return a tensor (after preprocessing) put on self.device
+        """
+        image = [self.image_loader(x) for x in image]
+        image = [self.image_preprocess(image) for image in image]
+        assert all(x.shape == image[0].shape for x in image)
+        image = torch.stack(image, dim=0).to(self.device)
+        return image
+
+    @torch.no_grad()
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
+    def forward(self, images: List[str], texts: List[str]) -> torch.Tensor:
+        """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
+        """
+        assert len(images) == len(texts), 'Number of images and texts must match'
+
+        images = self.load_images(images)
+        image_feat = self.model.ln_vision(self.model.visual_encoder(images))
+        image_att = torch.ones(image_feat.size()[:-1], dtype=torch.long).to(self.device)
+        query_token = self.model.query_tokens.expand(image_feat.shape[0], -1, -1)
+        query_att = torch.ones(query_token.size()[:-1], dtype=torch.long).to(query_token.device)
+
+        text_input = self.model.tokenizer(
+            texts, padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device)
+
+        attention_mask_all = torch.cat([query_att, text_input.attention_mask], dim=1)
+        output_itm = self.model.Qformer.bert(
+            text_input.input_ids,
+            query_embeds=query_token,
+            attention_mask=attention_mask_all,
+            encoder_hidden_states=image_feat,
+            encoder_attention_mask=image_att,
+            return_dict=True,
+        )
+        vl_embeddings = output_itm.last_hidden_state[:, :query_token.size(1), :]
+        vl_output = self.model.itm_head(vl_embeddings)
+        itm_logits = vl_output.mean(dim=1)
+        itm_prob = torch.nn.functional.softmax(itm_logits, dim=-1)[:, 1]
+        return itm_prob
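The forward pass feeds the Q-Former's ITM head and keeps only the softmax probability of the "match" class, so the returned scores live in [0, 1]. A hedged sketch of calling the wrapper directly (placeholder path; assumes CUDA and that the LAVIS checkpoint for the chosen variant can be downloaded):

from evalscope.metrics.t2v_metrics.models.itmscore_models.blip2_itm_model import BLIP2ITMScoreModel

model = BLIP2ITMScoreModel(model_name='blip2-itm-coco', device='cuda')   # the 'coco' variant resizes images to 364 px
prob = model.forward(images=['out.png'], texts=['two dogs playing in the snow'])
print(prob.item())   # probability that the image and the text match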
evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py
@@ -0,0 +1,97 @@
+import os
+import torch
+from modelscope import AutoTokenizer
+from typing import List, Union
+
+from ...constants import CACHE_DIR
+from ..model import ScoreModel
+from ..vqascore_models.lavis.models import load_model_and_preprocess
+
+FGA_BLIP2_MODELS = {
+    'fga_blip2': {
+        'variant': 'coco'
+    },
+}
+
+
+def get_index(list1, list2):
+    len_list1 = len(list1)
+    len_list2 = len(list2)
+    for i in range(len_list2 - len_list1 + 1):
+        if list2[i:i + len_list1] == list1:
+            return i
+    return 0
+
+
+class FGA_BLIP2ScoreModel(ScoreModel):
+    'A wrapper for FGA BLIP-2 ITMScore models'
+
+    def __init__(self, model_name='fga_blip2', device='cuda', cache_dir=CACHE_DIR):
+        assert model_name in FGA_BLIP2_MODELS, f'Model name must be one of {FGA_BLIP2_MODELS.keys()}'
+        os.environ['TORCH_HOME'] = cache_dir
+        super().__init__(model_name=model_name, device=device, cache_dir=cache_dir)
+
+    def load_model(self):
+        """Load the model, tokenizer, image transform
+        """
+        from ..utils import download_file
+
+        # load tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained('AI-ModelScope/bert-base-uncased', truncation_side='right')
+        self.tokenizer.add_special_tokens({'bos_token': '[DEC]'})
+        # load model
+        self.variant = FGA_BLIP2_MODELS[self.model_name]['variant']
+        self.model, self.vis_processors, self.text_processors = load_model_and_preprocess(
+            'fga_blip2', self.variant, is_eval=True, device=self.device)
+        # load pretrained weights
+        model_weight_path = download_file(
+            'AI-ModelScope/FGA-BLIP2', file_name='fga_blip2.pth', cache_dir=self.cache_dir)
+        self.model.load_checkpoint(model_weight_path)
+        self.model.eval()
+
+    def load_images(self, image):
+        pass
+
+    @torch.no_grad()
+    @torch.autocast(device_type='cuda', dtype=torch.float16)
+    def forward(self, images: List[str], texts: List[Union[str, dict]]) -> torch.Tensor:
+        """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
+        """
+        assert len(images) == len(texts), 'Number of images and texts must match'
+
+        result_list = []
+        for image_path, text in zip(images, texts):
+            if isinstance(text, str):
+                elements = []  # elements scores
+                prompt = text
+            else:
+                elements = text['tags']
+                prompt = text['prompt']
+
+            image = self.image_loader(image_path)
+            image = self.vis_processors['eval'](image).to(self.device)
+            prompt = self.text_processors['eval'](prompt)
+            prompt_ids = self.tokenizer(prompt).input_ids
+
+            alignment_score, scores = self.model.element_score(image.unsqueeze(0), [prompt])
+
+            elements_score = dict()
+            for element in elements:
+                element_ = element.rpartition('(')[0]
+                element_ids = self.tokenizer(element_).input_ids[1:-1]
+
+                idx = get_index(element_ids, prompt_ids)
+                if idx:
+                    mask = [0] * len(prompt_ids)
+                    mask[idx:idx + len(element_ids)] = [1] * len(element_ids)
+
+                    mask = torch.tensor(mask).to(self.device)
+                    elements_score[element] = (scores * mask).sum() / mask.sum()
+                else:
+                    elements_score[element] = torch.tensor(0.0).to(self.device)
+            if elements_score:
+                result_list.append({'overall_score': alignment_score, **elements_score})
+            else:
+                result_list.append(alignment_score)
+
+        return result_list
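FGA_BLIP2ScoreModel accepts either a plain prompt string or a dict with 'prompt' and 'tags'; in the dict case it returns an overall alignment score plus a per-tag score computed by masking the tag's token span inside the prompt. A hedged sketch of the dict form — the image path is a placeholder, and the parenthesized tag suffix mirrors the EvalMuse-style tags that element.rpartition('(') expects, but the exact tag strings are an assumption:

from evalscope.metrics.t2v_metrics.models.itmscore_models.fga_blip2_model import FGA_BLIP2ScoreModel

model = FGA_BLIP2ScoreModel(device='cuda')
result = model.forward(
    images=['out.png'],                                   # placeholder path to a generated image
    texts=[{'prompt': 'a red car parked next to a blue bicycle',
            'tags': ['red car (attribute)', 'blue bicycle (attribute)']}])   # assumed tag format
print(result[0]['overall_score'])          # fine-grained overall alignment score
print(result[0]['red car (attribute)'])    # score for the masked 'red car' span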
evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py
@@ -0,0 +1,171 @@
+# adapted from https://github.com/THUDM/ImageReward/tree/main/ImageReward
+
+import os
+import torch
+import torch.nn as nn
+from PIL import Image
+from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
+
+from .blip_pretrain import BLIP_Pretrain
+
+try:
+    from torchvision.transforms import InterpolationMode
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+def _convert_image_to_rgb(image):
+    return image.convert('RGB')
+
+
+def _transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
+class MLP(nn.Module):
+
+    def __init__(self, input_size):
+        super().__init__()
+        self.input_size = input_size
+
+        self.layers = nn.Sequential(
+            nn.Linear(self.input_size, 1024),
+            #nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            #nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            #nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(64, 16),
+            #nn.ReLU(),
+            nn.Linear(16, 1))
+
+        # initial MLP param
+        for name, param in self.layers.named_parameters():
+            if 'weight' in name:
+                nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
+            if 'bias' in name:
+                nn.init.constant_(param, val=0)
+
+    def forward(self, input):
+        return self.layers(input)
+
+
+class ImageReward(nn.Module):
+
+    def __init__(self, med_config, device='cpu'):
+        super().__init__()
+        self.device = device
+
+        self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
+        self.preprocess = _transform(224)
+        self.mlp = MLP(768)
+
+        self.mean = 0.16717362830052426
+        self.std = 1.0333394966054072
+
+    def score_gard(self, prompt_ids, prompt_attention_mask, image):
+
+        image_embeds = self.blip.visual_encoder(image)
+        # text encode cross attention with image
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+        text_output = self.blip.text_encoder(
+            prompt_ids,
+            attention_mask=prompt_attention_mask,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+
+        txt_features = text_output.last_hidden_state[:, 0, :]  # (feature_dim)
+        rewards = self.mlp(txt_features)
+        rewards = (rewards - self.mean) / self.std
+
+        return rewards
+
+    def score(self, prompt, image):
+
+        if (type(image).__name__ == 'list'):
+            _, rewards = self.inference_rank(prompt, image)
+            return rewards
+
+        # text encode
+        text_input = self.blip.tokenizer(
+            prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device)
+
+        # image encode
+        if isinstance(image, Image.Image):
+            pil_image = image
+        elif isinstance(image, str) and os.path.isfile(image):
+            pil_image = Image.open(image)
+        else:
+            raise TypeError(
+                r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')
+
+        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+        image_embeds = self.blip.visual_encoder(image)
+
+        # text encode cross attention with image
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+        text_output = self.blip.text_encoder(
+            text_input.input_ids,
+            attention_mask=text_input.attention_mask,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+
+        txt_features = text_output.last_hidden_state[:, 0, :].float()  # (feature_dim)
+        rewards = self.mlp(txt_features)
+        rewards = (rewards - self.mean) / self.std
+
+        return rewards.detach().cpu().numpy().item()
+
+    def inference_rank(self, prompt, generations_list):
+
+        text_input = self.blip.tokenizer(
+            prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device)
+
+        txt_set = []
+        for generation in generations_list:
+            # image encode
+            if isinstance(generation, Image.Image):
+                pil_image = generation
+            elif isinstance(generation, str):
+                if os.path.isfile(generation):
+                    pil_image = Image.open(generation)
+            else:
+                raise TypeError(
+                    r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')
+            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+            image_embeds = self.blip.visual_encoder(image)
+
+            # text encode cross attention with image
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+            text_output = self.blip.text_encoder(
+                text_input.input_ids,
+                attention_mask=text_input.attention_mask,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+            txt_set.append(text_output.last_hidden_state[:, 0, :])
+
+        txt_features = torch.cat(txt_set, 0).float()  # [image_num, feature_dim]
+        rewards = self.mlp(txt_features)  # [image_num, 1]
+        rewards = (rewards - self.mean) / self.std
+        rewards = torch.squeeze(rewards)
+        _, rank = torch.sort(rewards, dim=0, descending=True)
+        _, indices = torch.sort(rank, dim=0)
+        indices = indices + 1
+
+        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
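The double torch.sort at the end of inference_rank converts rewards into 1-based ranks: the first sort orders images from best to worst, and the second inverts that permutation back into a per-image rank position. The same trick on dummy rewards, for illustration only:

import torch

rewards = torch.tensor([0.3, 1.2, -0.5])                 # dummy rewards for three generated images
_, rank = torch.sort(rewards, dim=0, descending=True)    # image indices from best to worst: [1, 0, 2]
_, indices = torch.sort(rank, dim=0)                     # rank position of each image (0-based): [1, 0, 2]
print((indices + 1).tolist())                            # 1-based ranks per image: [2, 1, 3]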
evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py (file without changes)
evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py
@@ -0,0 +1,80 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+'''
+
+from modelscope import AutoTokenizer
+from torch import nn
+
+from ...vqascore_models.lavis.models.med import BertConfig, BertModel
+from ...vqascore_models.lavis.models.vit import VisionTransformer
+
+
+def init_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained('AI-ModelScope/bert-base-uncased')
+    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
+    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
+    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
+    return tokenizer
+
+
+def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
+
+    assert vit in ['base', 'large'], 'vit parameter must be base or large'
+    if vit == 'base':
+        vision_width = 768
+        visual_encoder = VisionTransformer(
+            img_size=image_size,
+            patch_size=16,
+            embed_dim=vision_width,
+            depth=12,
+            num_heads=12,
+            use_grad_checkpointing=use_grad_checkpointing,
+            ckpt_layer=ckpt_layer,
+            drop_path_rate=0 or drop_path_rate)
+    elif vit == 'large':
+        vision_width = 1024
+        visual_encoder = VisionTransformer(
+            img_size=image_size,
+            patch_size=16,
+            embed_dim=vision_width,
+            depth=24,
+            num_heads=16,
+            use_grad_checkpointing=use_grad_checkpointing,
+            ckpt_layer=ckpt_layer,
+            drop_path_rate=0.1 or drop_path_rate)
+    return visual_encoder, vision_width
+
+
+class BLIP_Pretrain(nn.Module):
+
+    def __init__(
+        self,
+        med_config='med_config.json',
+        image_size=224,
+        vit='base',
+        vit_grad_ckpt=False,
+        vit_ckpt_layer=0,
+        embed_dim=256,
+        queue_size=57600,
+        momentum=0.995,
+    ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """
+        super().__init__()
+
+        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
+
+        self.tokenizer = init_tokenizer()
+        encoder_config = BertConfig.from_json_file(med_config)
+        encoder_config.encoder_width = vision_width
+        encoder_config.add_type_embeddings = False
+        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
+
+        text_width = self.text_encoder.config.hidden_size
+
+        self.vision_proj = nn.Linear(vision_width, embed_dim)
+        self.text_proj = nn.Linear(text_width, embed_dim)
evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py
@@ -0,0 +1,73 @@
+import torch
+from typing import List
+
+from ...constants import CACHE_DIR
+from ..model import ScoreModel
+
+IMAGE_REWARD_MODELS = {
+    'image-reward-v1': {
+        'variant': 'ImageReward-v1.0'
+    },
+}
+
+
+class ImageRewardScoreModel(ScoreModel):
+    'A wrapper for ImageReward ITMScore (finetuned on human preference) models'
+
+    def __init__(self, model_name='image-reward-v1', device='cuda', cache_dir=CACHE_DIR):
+        assert model_name in IMAGE_REWARD_MODELS, f'Model name must be one of {IMAGE_REWARD_MODELS.keys()}'
+
+        super().__init__(model_name=model_name, device=device, cache_dir=cache_dir)
+
+    def load_model(self):
+        """Load the model, tokenizer, image transform
+        """
+        from ..utils import download_file
+        from .image_reward.ImageReward import ImageReward
+
+        self.variant = IMAGE_REWARD_MODELS[self.model_name]['variant']
+
+        self.model_path = download_file('ZhipuAI/ImageReward', file_name='ImageReward.pt', cache_dir=self.cache_dir)
+        self.med_config = download_file('ZhipuAI/ImageReward', file_name='med_config.json', cache_dir=self.cache_dir)
+
+        state_dict = torch.load(self.model_path, map_location='cpu')
+        self.model = ImageReward(device=self.device, med_config=self.med_config).to(self.device)
+        msg = self.model.load_state_dict(state_dict, strict=False)
+        self.model.eval()
+
+    def load_images(self, image: List[str]) -> torch.Tensor:
+        """Load the image(s), and return a tensor (after preprocessing) put on self.device
+        """
+        image = [self.image_loader(x) for x in image]
+        image = [self.model.preprocess(image) for image in image]
+        assert all(x.shape == image[0].shape for x in image)
+        image = torch.stack(image, dim=0).to(self.device)
+        return image
+
+    @torch.no_grad()
+    def forward(self, images: List[str], texts: List[str]) -> torch.Tensor:
+        """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
+        """
+        assert len(images) == len(texts), 'Number of images and texts must match'
+        rewards = torch.zeros(len(texts), dtype=torch.float32).to(self.device)
+        images = self.load_images(images)
+        for index in range(len(texts)):
+            text_input = self.model.blip.tokenizer(
+                texts[index], padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device)
+            image_embeds = self.model.blip.visual_encoder(images[index].unsqueeze(0))
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+            text_output = self.model.blip.text_encoder(
+                text_input.input_ids,
+                attention_mask=text_input.attention_mask,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+            txt_features = text_output.last_hidden_state[:, 0, :].float()  # (feature_dim)
+            reward_score = self.model.mlp(txt_features)
+            reward_score = (reward_score - self.model.mean) / self.model.std
+
+            rewards[index] = reward_score
+
+        return rewards