evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic. Click here for more details.
- evalscope/arguments.py +2 -1
- evalscope/benchmarks/__init__.py +2 -2
- evalscope/benchmarks/aigc/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/base.py +56 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
- evalscope/benchmarks/aime/aime24_adapter.py +1 -1
- evalscope/benchmarks/aime/aime25_adapter.py +4 -4
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
- evalscope/benchmarks/arc/arc_adapter.py +1 -1
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
- evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
- evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
- evalscope/benchmarks/data_adapter.py +16 -9
- evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
- evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
- evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
- evalscope/benchmarks/utils.py +7 -16
- evalscope/cli/start_app.py +1 -1
- evalscope/collections/evaluator.py +16 -4
- evalscope/config.py +7 -3
- evalscope/constants.py +11 -0
- evalscope/evaluator/evaluator.py +9 -3
- evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
- evalscope/metrics/__init__.py +49 -4
- evalscope/metrics/llm_judge.py +1 -1
- evalscope/metrics/named_metrics.py +13 -0
- evalscope/metrics/t2v_metrics/__init__.py +66 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +50 -14
- evalscope/models/adapters/__init__.py +17 -0
- evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
- evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
- evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
- evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
- evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
- evalscope/models/adapters/t2i_adapter.py +76 -0
- evalscope/models/custom/__init__.py +2 -1
- evalscope/models/custom/dummy_model.py +11 -13
- evalscope/models/local_model.py +82 -33
- evalscope/models/model.py +2 -42
- evalscope/models/register.py +26 -0
- evalscope/perf/benchmark.py +4 -3
- evalscope/perf/main.py +4 -2
- evalscope/perf/plugin/datasets/flickr8k.py +2 -1
- evalscope/perf/utils/benchmark_util.py +2 -2
- evalscope/perf/utils/db_util.py +16 -8
- evalscope/report/__init__.py +1 -0
- evalscope/report/app.py +117 -67
- evalscope/report/app_arguments.py +11 -0
- evalscope/report/generator.py +1 -1
- evalscope/run.py +3 -3
- evalscope/third_party/thinkbench/eval.py +19 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/import_utils.py +66 -0
- evalscope/utils/utils.py +12 -4
- evalscope/version.py +2 -2
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
- tests/aigc/__init__.py +1 -0
- tests/aigc/test_t2i.py +87 -0
- tests/cli/test_run.py +20 -7
- tests/perf/test_perf.py +6 -3
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import os
|
|
3
|
+
import torch
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from PIL import Image
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
from ..constants import CACHE_DIR
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def image_loader(image_path):
    """Open ``image_path`` and return it as an RGB ``PIL.Image``.

    Paths whose text after the last ``'.'`` is ``npy`` are read with
    ``numpy.load``; the last axis is re-indexed as ``[2, 1, 0]`` before the
    array is wrapped as an RGB image (presumably BGR -> RGB for an HxWx3
    array -- TODO confirm the stored layout). Any other path is opened with
    PIL and converted to RGB.
    """
    extension = image_path.split('.')[-1]
    if extension != 'npy':
        return Image.open(image_path).convert('RGB')
    pixels = np.load(image_path)
    return Image.fromarray(pixels[:, :, [2, 1, 0]], 'RGB')
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ScoreModel(ABC):
    """Abstract base class for models that score (image, text) pairs.

    Subclasses implement ``load_model``, ``load_images`` and ``forward``.
    ``load_model`` is called at the end of ``__init__``, after ``model_name``,
    ``device``, ``cache_dir`` and ``image_loader`` have been set on the
    instance, so subclasses may rely on those attributes while loading.
    """

    def __init__(self, model_name='clip-flant5-xxl', device='cuda', cache_dir=CACHE_DIR):
        """Store configuration, make sure ``cache_dir`` exists, then load the model.

        Args:
            model_name: Identifier of the model to load.
            device: Device string stored on ``self.device``.
            cache_dir: Directory used for downloaded weights; created if missing.
        """
        self.model_name = model_name
        self.device = device
        self.cache_dir = cache_dir
        # exist_ok=True avoids the check-then-create race of exists() + makedirs()
        # when several workers initialise against the same cache directory.
        os.makedirs(self.cache_dir, exist_ok=True)
        self.image_loader = image_loader
        self.load_model()

    @abstractmethod
    def load_model(self):
        """Load the model, tokenizer, and any other resources the subclass needs."""
        pass

    @abstractmethod
    def load_images(self, image: List[str]) -> torch.Tensor:
        """Load the image(s) at the given paths and return a preprocessed tensor on ``self.device``."""
        pass

    @abstractmethod
    def forward(self, images: List[str], texts: List[str], **kwargs) -> torch.Tensor:
        """Return a tensor of n scores for n (image, text) pairs."""
        pass
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from modelscope import snapshot_download
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def download_open_clip_model(model_name, tag, cache_dir):
    """Download an open_clip checkpoint's safetensors weights from ModelScope.

    The ModelScope repo id is taken from the ``hf_hub`` entry of the open_clip
    pretrained config for ``(model_name, tag)``. Only the file
    ``open_clip_model.safetensors`` is fetched.

    Args:
        model_name: open_clip architecture name.
        tag: open_clip pretrained tag.
        cache_dir: ModelScope cache directory.

    Returns:
        Local path of ``open_clip_model.safetensors`` inside the downloaded snapshot.

    Raises:
        ValueError: if the pretrained config has no ``hf_hub`` entry.
    """
    import open_clip

    # get pretrained config
    pretrained_cfg = open_clip.get_pretrained_cfg(model_name, tag)
    hf_hub = pretrained_cfg.get('hf_hub') if pretrained_cfg else None
    if not hf_hub:
        # Without this guard a missing entry surfaced as "'NoneType' object has no attribute 'strip'".
        raise ValueError(f'No `hf_hub` entry in the open_clip pretrained config for '
                         f'model {model_name!r} with tag {tag!r}')
    model_hub = hf_hub.strip('/')

    # load model from modelscope
    model_weight_name = 'open_clip_model.safetensors'
    local_path = snapshot_download(model_id=model_hub, cache_dir=cache_dir, allow_patterns=model_weight_name)
    return os.path.join(local_path, model_weight_name)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def download_file(model_id, file_name=None, cache_dir=None):
    """Fetch ``model_id`` from ModelScope and return a local path.

    Args:
        model_id: ModelScope repository id.
        file_name: Optional file/pattern passed as ``allow_patterns``; when
            ``None`` the whole snapshot is requested.
        cache_dir: ModelScope cache directory.

    Returns:
        The snapshot directory when ``file_name`` is ``None``, otherwise the
        path of ``file_name`` inside that directory.
    """
    snapshot_dir = snapshot_download(model_id=model_id, cache_dir=cache_dir, allow_patterns=file_name)
    return snapshot_dir if file_name is None else os.path.join(snapshot_dir, file_name)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from ...constants import CACHE_DIR
|
|
2
|
+
from .clip_t5_model import CLIP_T5_MODELS, CLIPT5Model
|
|
3
|
+
from .gpt4v_model import GPT4V_MODELS, GPT4VModel
|
|
4
|
+
|
|
5
|
+
# One collection of supported model names per VQAScore backend, in lookup order.
ALL_VQA_MODELS = [CLIP_T5_MODELS, GPT4V_MODELS]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def list_all_vqascore_models():
    """Return a flat list of every model name across ``ALL_VQA_MODELS``, in order."""
    names = []
    for backend_models in ALL_VQA_MODELS:
        names.extend(backend_models)
    return names
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_vqascore_model(model_name, device='cuda', cache_dir=CACHE_DIR, **kwargs):
    """Instantiate the VQAScore backend whose name collection contains ``model_name``.

    Args:
        model_name: Name listed by ``list_all_vqascore_models()``.
        device: Forwarded to the backend constructor.
        cache_dir: Forwarded to the backend constructor.
        **kwargs: Extra backend-specific constructor arguments.

    Raises:
        AssertionError: if ``model_name`` is unknown (when assertions are enabled).
        NotImplementedError: if no backend handles ``model_name``.
    """
    assert model_name in list_all_vqascore_models()
    backends = (
        (CLIP_T5_MODELS, CLIPT5Model),
        (GPT4V_MODELS, GPT4VModel),
    )
    for supported_names, model_cls in backends:
        if model_name in supported_names:
            return model_cls(model_name, device=device, cache_dir=cache_dir, **kwargs)
    raise NotImplementedError()
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .language_model.clip_t5 import CLIPT5Config, CLIPT5ForConditionalGeneration, ModelArguments
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
# Copyright 2023 Zhiqiu Lin
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from transformers import AutoConfig, AutoModelForSeq2SeqLM, T5Config, T5ForConditionalGeneration
|
|
18
|
+
from transformers.modeling_outputs import Seq2SeqLMOutput
|
|
19
|
+
from typing import List, Optional, Tuple, Union
|
|
20
|
+
|
|
21
|
+
from ..multimodal_encoder.builder import build_vision_tower
|
|
22
|
+
from ..multimodal_projector.builder import build_vision_projector
|
|
23
|
+
|
|
24
|
+
# Sentinel id placed in input_ids wherever image features should be spliced in
# (see CLIPT5ForConditionalGeneration.prepare_inputs_labels_for_multimodal).
IMAGE_TOKEN_INDEX = -200
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class ModelArguments:
    """Vision-module options read by ``initialize_vision_modules`` and the vision-tower builder."""

    tune_mm_mlp_adapter: bool = False
    # Vision tower name or path.
    vision_tower: Optional[str] = 'openai/clip-vit-large-patch14-336'
    # Hidden-state index taken from the vision tower; -2 is the second-to-last layer (llava1.5 default).
    mm_vision_select_layer: Optional[int] = -2
    # Optional checkpoint file holding pretrained projector weights.
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_projector_type: Optional[str] = 'mlp2x_gelu'
    mm_vision_select_feature: Optional[str] = 'patch'
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class CLIPT5Config(T5Config):
    """``T5Config`` subclass registered under the ``clip_t5`` model type."""

    model_type = 'clip_t5'
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
    """T5 / Flan-T5 encoder-decoder whose encoder input can contain image features.

    Every ``IMAGE_TOKEN_INDEX`` placeholder in ``input_ids`` is replaced by the
    projected features of the corresponding image (vision tower followed by
    ``mm_projector``) before the embeddings are fed to the T5 encoder.
    """
    # This class supports both T5 and FlanT5
    config_class = CLIPT5Config

    def __init__(self, config):
        super(CLIPT5ForConditionalGeneration, self).__init__(config)
        # Token embedding of the encoder, used to embed the text segments around image placeholders.
        self.embed_tokens = self.encoder.embed_tokens
        if hasattr(config, 'mm_vision_tower'):
            self.vision_tower = build_vision_tower(config, delay_load=False)
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        """Return the vision tower (or None), unwrapping the one-element list used when FSDP is enabled."""
        vision_tower = getattr(self, 'vision_tower', None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def get_model(self):
        return self  # for compatibility with LlavaMetaForCausalLM

    def prepare_inputs_labels_for_multimodal(self, input_ids, attention_mask, decoder_attention_mask, past_key_values,
                                             labels, images):
        """Splice image features into the token embeddings of ``input_ids``.

        ``images`` is either a single tensor or a list / 5-D tensor of per-sample
        image stacks; in the latter case each sample's image features are
        flattened into one sequence. Every ``IMAGE_TOKEN_INDEX`` in a sample
        consumes the next entry of the image features.

        When samples end up with different lengths, embeddings are right-padded
        with zeros to the longest sample and ``attention_mask`` is extended to match.

        Returns:
            ``(None, attention_mask, decoder_attention_mask, past_key_values, inputs_embeds, labels)``;
            ``input_ids`` is returned as ``None`` because ``inputs_embeds`` replaces it.

        Raises:
            NotImplementedError: if there is no vision tower, no images, a
                single-token input, or a sample without an image placeholder.
        """
        # The labels are now separated from the input_ids.
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            raise NotImplementedError()

        if isinstance(images, list) or images.ndim == 5:
            # Encode all images in one batch, then regroup and flatten per sample.
            concat_images = torch.cat(list(images), dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        cur_image_idx = 0
        for cur_input_ids in input_ids:
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                raise NotImplementedError()
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                cur_new_input_embeds.append(self.embed_tokens(cur_input_ids[:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                cur_image_idx += 1
                # Continue scanning after the placeholder that was just replaced.
                cur_input_ids = cur_input_ids[image_token_start + 1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                cur_new_input_embeds.append(self.embed_tokens(cur_input_ids))
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            _input_embeds_lengths = []
            for cur_new_embed in new_input_embeds:
                _input_embeds_lengths.append(cur_new_embed.shape[0])
                cur_new_embed = torch.cat((cur_new_embed,
                                           torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
                                                       dtype=cur_new_embed.dtype,
                                                       device=cur_new_embed.device)),
                                          dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, _input_embeds_length in zip(attention_mask, _input_embeds_lengths):
                    # NOTE(review): the extra positions created by image features are marked
                    # as attended on the LEFT of the original mask, not at the image
                    # positions themselves; confirm this matches the intended padding side.
                    new_attn_mask_pad_left = torch.full((_input_embeds_length - input_ids.shape[1], ),
                                                        True,
                                                        dtype=attention_mask.dtype,
                                                        device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((new_input_embeds.shape[1] - _input_embeds_length, ),
                                                         False,
                                                         dtype=attention_mask.dtype,
                                                         device=attention_mask.device)
                    cur_new_attention_mask = torch.cat(
                        (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_input_embeds.shape[:2]
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full(
                    (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, decoder_attention_mask, past_key_values, new_input_embeds, labels

    def encode_images(self, images):
        """Run ``images`` through the vision tower and then the multimodal projector."""
        image_features = self.get_vision_tower()(images)
        image_features = self.mm_projector(image_features)
        return image_features

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Attach or load the vision tower and projector described by ``model_args``.

        The vision settings are recorded on ``self.config``. When ``fsdp`` is
        non-empty, a newly built tower is stored wrapped in a one-element list.
        If ``model_args.pretrain_mm_mlp_adapter`` is set, projector weights are
        loaded from that file (keys containing ``'mm_projector.'``).
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower
        self.config.pretrain_mm_mlp_adapter = pretrain_mm_mlp_adapter

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            if not vision_tower.is_loaded:
                vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'mlp2x_gelu')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)

        if pretrain_mm_mlp_adapter is not None:
            # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
            # torch version's weights_only default); only load trusted adapter files.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                # Filter on the dotted prefix that is split on; filtering on the bare
                # keyword raised IndexError for keys that only contain the keyword.
                prefix = keyword + '.'
                return {k.split(prefix)[1]: v for k, v in weights.items() if prefix in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """T5 forward pass; builds ``inputs_embeds`` from ``input_ids`` and ``images`` when not given."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            _, attention_mask, decoder_attention_mask, past_key_values, inputs_embeds, labels = \
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, attention_mask, decoder_attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = super(CLIPT5ForConditionalGeneration, self).forward(
            input_ids=None,  # will be None if inputs_embeds is not None
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        return outputs

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Generate from token ids containing image placeholders.

        ``inputs`` / ``attention_mask`` / ``images`` are converted to
        ``inputs_embeds`` via ``prepare_inputs_labels_for_multimodal``. Any other
        keyword arguments (e.g. ``max_new_tokens``, ``num_beams``) are forwarded
        to ``T5ForConditionalGeneration.generate``.
        """
        assert images is not None, 'images must be provided'
        assert inputs is not None, 'inputs must be provided'
        assert attention_mask is not None, 'attention_mask must be provided'
        _, attention_mask, _, _, inputs_embeds, _ = \
            self.prepare_inputs_labels_for_multimodal(inputs, attention_mask, None, None, None, images)
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = super(CLIPT5ForConditionalGeneration, self).generate(
            input_ids=None,  # will be None if inputs_embeds is not None
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            # Previously dropped silently, so generation options passed by callers were ignored.
            **kwargs,
        )
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build the per-step model inputs used by ``generate``."""
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids}

        model_inputs.update({
            'decoder_input_ids': input_ids,
            'past_key_values': past_key_values,
            'encoder_outputs': encoder_outputs,
            'attention_mask': attention_mask,
            'head_mask': head_mask,
            'decoder_head_mask': decoder_head_mask,
            'decoder_attention_mask': decoder_attention_mask,
            'cross_attn_head_mask': cross_attn_head_mask,
            'use_cache': use_cache,
        })
        return model_inputs
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
# Register the custom config/model pair with the transformers Auto* factories so the
# 'clip_t5' model type maps to CLIPT5Config and CLIPT5ForConditionalGeneration.
AutoConfig.register('clip_t5', CLIPT5Config)
AutoModelForSeq2SeqLM.register(CLIPT5Config, CLIPT5ForConditionalGeneration)
|
evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from .clip_encoder import CLIPVisionTower
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the CLIP vision encoder named by a model config.

    The tower name is read from ``vision_tower_cfg.mm_vision_tower`` or, when
    that attribute is absent, from ``vision_tower_cfg.vision_tower``.

    Args:
        vision_tower_cfg: config object carrying the tower name. It is also
            passed to ``CLIPVisionTower`` as ``args``.
        **kwargs: forwarded to ``CLIPVisionTower`` (e.g. ``delay_load``).

    Returns:
        A ``CLIPVisionTower`` when the name is an existing local path or starts
        with ``'openai'`` / ``'laion'``.

    Raises:
        ValueError: if no tower name is configured, or the name is not one of
            the supported forms above.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    if vision_tower is None:
        # Without this guard, os.path.exists(None) raises an opaque TypeError.
        raise ValueError('No vision tower configured: expected `mm_vision_tower` or `vision_tower` on the config')

    # Any existing path (absolute or relative) is accepted, as are hub ids under these orgs.
    is_local_path = os.path.exists(vision_tower)
    if is_local_path or vision_tower.startswith(('openai', 'laion')):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder returning the hidden states of one selected layer.

    Args:
        vision_tower: hub id or local path of the CLIP vision checkpoint. Ids in
            the ``openai/`` organisation are fetched from ``openai-mirror/``
            (see ``_mirror_model_id``).
        args: config object. ``args.mm_vision_select_layer`` (required) indexes
            ``hidden_states``. The optional ``args.mm_vision_select_feature``
            ('patch' by default) selects which tokens are kept.
        delay_load: when True, only the ``CLIPVisionConfig`` is read now, and
            ``load_model()`` must be called before ``forward``.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            # NOTE(review): unlike load_model(), this reads the config from the
            # un-mirrored name and bypasses download_file -- confirm intended.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    @staticmethod
    def _mirror_model_id(name):
        """Map ``'openai/<model>'`` to ``'openai-mirror/<model>'``; return other names unchanged.

        Only the leading organisation prefix is rewritten. Local paths, or ids
        that merely contain the substring 'openai' elsewhere, are left intact.
        """
        prefix = 'openai/'
        if name.startswith(prefix):
            return 'openai-mirror/' + name[len(prefix):]
        return name

    def load_model(self):
        """Fetch the checkpoint, load the image processor and vision model, and freeze the model."""
        from .....utils import download_file
        model_path = download_file(self._mirror_model_id(self.vision_tower_name))
        self.image_processor = CLIPImageProcessor.from_pretrained(model_path)
        self.vision_tower = CLIPVisionModel.from_pretrained(model_path)
        # Used as a fixed feature extractor: weights never receive gradients.
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Return ``hidden_states[self.select_layer]``, trimmed according to ``select_feature``.

        In 'patch' mode the first token along dim 1 is dropped (presumably the
        CLS token). 'cls_patch' keeps every token.

        Raises:
            ValueError: for any other ``select_feature`` value.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            return image_features[:, 1:]
        if self.select_feature == 'cls_patch':
            return image_features
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batch tensor, or a list of per-image tensors.

        Inputs are moved to the tower's device and dtype for the forward pass.
        Outputs are cast back to the input dtype. A list input yields a list of
        per-image features, each with a leading batch dim of 1.
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        """A (1, hidden_size) zero tensor on the tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Requires the model to be loaded (self.vision_tower is set in load_model).
        return self.vision_tower.dtype

    @property
    def device(self):
        # Requires the model to be loaded (self.vision_tower is set in load_model).
        return self.vision_tower.device

    @property
    def config(self):
        """The loaded model's config, or the config-only object when loading was delayed."""
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Number of patches per image: ``(image_size // patch_size) ** 2``."""
        return (self.config.image_size // self.config.patch_size)**2
|
evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class IdentityMap(nn.Module):
    """Parameter-free projector that hands its input back unchanged."""

    def __init__(self):
        super(IdentityMap, self).__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` itself; extra positional/keyword arguments are accepted and ignored."""
        del args, kwargs
        return x

    @property
    def config(self):
        """Minimal config dict identifying this projector type."""
        return dict(mm_projector_type='identity')
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear MLP of constant width.

    NOTE(review): the skip connection adds the *normalized* input, i.e.
    ``norm(x) + mlp(norm(x))``, not the raw ``x`` -- kept as-is.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        update = self.proj(normed)
        return normed + update
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the module projecting ``config.mm_hidden_size`` features to ``config.hidden_size``.

    The projector kind comes from ``config.mm_projector_type`` (default ``'linear'``):

    * ``'linear'`` -- a single ``nn.Linear``;
    * ``'mlp<N>x_gelu'`` -- ``nn.Sequential`` of N Linear layers with GELU in between;
    * ``'identity'`` -- ``IdentityMap``.

    ``delay_load`` and ``**kwargs`` are accepted but not used here.

    Raises:
        ValueError: for any other projector type.
    """
    kind = getattr(config, 'mm_projector_type', 'linear')

    if kind == 'identity':
        return IdentityMap()

    if kind == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    found = re.match(r'^mlp(\d+)x_gelu$', kind)
    if found is None:
        raise ValueError(f'Unknown projector type: {kind}')

    depth = int(found.group(1))
    layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
    for _ in range(depth - 1):
        layers += [nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)]
    return nn.Sequential(*layers)
|