evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in the supported public registry; it is provided for informational purposes only.
Potentially problematic release.
This version of evalscope might be problematic.
- evalscope/arguments.py +2 -1
- evalscope/benchmarks/__init__.py +2 -2
- evalscope/benchmarks/aigc/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/base.py +56 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
- evalscope/benchmarks/aime/aime24_adapter.py +1 -1
- evalscope/benchmarks/aime/aime25_adapter.py +4 -4
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
- evalscope/benchmarks/arc/arc_adapter.py +1 -1
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
- evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
- evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
- evalscope/benchmarks/data_adapter.py +16 -9
- evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
- evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
- evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
- evalscope/benchmarks/utils.py +7 -16
- evalscope/cli/start_app.py +1 -1
- evalscope/collections/evaluator.py +16 -4
- evalscope/config.py +7 -3
- evalscope/constants.py +11 -0
- evalscope/evaluator/evaluator.py +9 -3
- evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
- evalscope/metrics/__init__.py +49 -4
- evalscope/metrics/llm_judge.py +1 -1
- evalscope/metrics/named_metrics.py +13 -0
- evalscope/metrics/t2v_metrics/__init__.py +66 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +50 -14
- evalscope/models/adapters/__init__.py +17 -0
- evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
- evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
- evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
- evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
- evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
- evalscope/models/adapters/t2i_adapter.py +76 -0
- evalscope/models/custom/__init__.py +2 -1
- evalscope/models/custom/dummy_model.py +11 -13
- evalscope/models/local_model.py +82 -33
- evalscope/models/model.py +2 -42
- evalscope/models/register.py +26 -0
- evalscope/perf/benchmark.py +4 -3
- evalscope/perf/main.py +4 -2
- evalscope/perf/plugin/datasets/flickr8k.py +2 -1
- evalscope/perf/utils/benchmark_util.py +2 -2
- evalscope/perf/utils/db_util.py +16 -8
- evalscope/report/__init__.py +1 -0
- evalscope/report/app.py +117 -67
- evalscope/report/app_arguments.py +11 -0
- evalscope/report/generator.py +1 -1
- evalscope/run.py +3 -3
- evalscope/third_party/thinkbench/eval.py +19 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/import_utils.py +66 -0
- evalscope/utils/utils.py +12 -4
- evalscope/version.py +2 -2
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
- tests/aigc/__init__.py +1 -0
- tests/aigc/test_t2i.py +87 -0
- tests/cli/test_run.py +20 -7
- tests/perf/test_perf.py +6 -3
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py
@@ -0,0 +1,208 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import torch
+from omegaconf import OmegaConf
+
+from ..common.registry import registry
+from ..processors.base_processor import BaseProcessor
+from .base_model import BaseModel
+from .blip2_models.blip2 import Blip2Base
+from .blip2_models.blip2_image_text_matching import Blip2ITM
+from .blip2_models.blip2_qformer import Blip2Qformer
+from .blip2_models.blip2_t5 import Blip2T5
+from .blip2_models.blip2_t5_instruct import Blip2T5Instruct
+from .blip2_models.fga_blip2 import FGA_Blip2
+from .blip_models.blip import BlipBase
+from .blip_models.blip_caption import BlipCaption
+from .blip_models.blip_classification import BlipClassification
+from .blip_models.blip_feature_extractor import BlipFeatureExtractor
+from .blip_models.blip_image_text_matching import BlipITM
+from .blip_models.blip_nlvr import BlipNLVR
+from .blip_models.blip_pretrain import BlipPretrain
+from .blip_models.blip_vqa import BlipVQA
+from .med import XBertLMHeadDecoder
+from .vit import VisionTransformerEncoder
+
+__all__ = [
+    'load_model',
+    'BaseModel',
+    'BlipBase',
+    'BlipFeatureExtractor',
+    'BlipCaption',
+    'BlipClassification',
+    'BlipITM',
+    'BlipNLVR',
+    'BlipPretrain',
+    'BlipVQA',
+    'Blip2Qformer',
+    'Blip2Base',
+    'Blip2ITM',
+    'Blip2T5',
+    'Blip2T5Instruct',
+    'FGA_Blip2',
+]
+
+
+def load_model(name, model_type, is_eval=False, device='cpu', checkpoint=None):
+    """
+    Load supported models.
+
+    To list all available models and types in registry:
+    >>> from import model_zoo
+    >>> print(model_zoo)
+
+    Args:
+        name (str): name of the model.
+        model_type (str): type of the model.
+        is_eval (bool): whether the model is in eval mode. Default: False.
+        device (str): device to use. Default: "cpu".
+        checkpoint (str): path or to checkpoint. Default: None.
+            Note that expecting the checkpoint to have the same keys in state_dict as the model.
+
+    Returns:
+        model (torch.nn.Module): model.
+    """
+
+    model = registry.get_model_class(name).from_pretrained(model_type=model_type)
+
+    if checkpoint is not None:
+        model.load_checkpoint(checkpoint)
+
+    if is_eval:
+        model.eval()
+
+    if device == 'cpu':
+        model = model.float()
+
+    return model.to(device)
+
+
+def load_preprocess(config):
+    """
+    Load preprocessor configs and construct preprocessors.
+
+    If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
+
+    Args:
+        config (dict): preprocessor configs.
+
+    Returns:
+        vis_processors (dict): preprocessors for visual inputs.
+        txt_processors (dict): preprocessors for text inputs.
+
+        Key is "train" or "eval" for processors used in training and evaluation respectively.
+    """
+
+    def _build_proc_from_cfg(cfg):
+        return (registry.get_processor_class(cfg.name).from_config(cfg) if cfg is not None else BaseProcessor())
+
+    vis_processors = dict()
+    txt_processors = dict()
+
+    vis_proc_cfg = config.get('vis_processor')
+    txt_proc_cfg = config.get('text_processor')
+
+    if vis_proc_cfg is not None:
+        vis_train_cfg = vis_proc_cfg.get('train')
+        vis_eval_cfg = vis_proc_cfg.get('eval')
+    else:
+        vis_train_cfg = None
+        vis_eval_cfg = None
+
+    vis_processors['train'] = _build_proc_from_cfg(vis_train_cfg)
+    vis_processors['eval'] = _build_proc_from_cfg(vis_eval_cfg)
+
+    if txt_proc_cfg is not None:
+        txt_train_cfg = txt_proc_cfg.get('train')
+        txt_eval_cfg = txt_proc_cfg.get('eval')
+    else:
+        txt_train_cfg = None
+        txt_eval_cfg = None
+
+    txt_processors['train'] = _build_proc_from_cfg(txt_train_cfg)
+    txt_processors['eval'] = _build_proc_from_cfg(txt_eval_cfg)
+
+    return vis_processors, txt_processors
+
+
+def load_model_and_preprocess(name, model_type, is_eval=False, device='cpu'):
+    """
+    Load model and its related preprocessors.
+
+    List all available models and types in registry:
+    >>> from import model_zoo
+    >>> print(model_zoo)
+
+    Args:
+        name (str): name of the model.
+        model_type (str): type of the model.
+        is_eval (bool): whether the model is in eval mode. Default: False.
+        device (str): device to use. Default: "cpu".
+
+    Returns:
+        model (torch.nn.Module): model.
+        vis_processors (dict): preprocessors for visual inputs.
+        txt_processors (dict): preprocessors for text inputs.
+    """
+    model_cls = registry.get_model_class(name)
+
+    # load model
+    model = model_cls.from_pretrained(model_type=model_type)
+
+    if is_eval:
+        model.eval()
+
+    # load preprocess
+    cfg = OmegaConf.load(model_cls.default_config_path(model_type))
+    if cfg is not None:
+        preprocess_cfg = cfg.preprocess
+
+        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
+    else:
+        vis_processors, txt_processors = None, None
+        logging.info(f"""No default preprocess for model {name} ({model_type}).
+            This can happen if the model is not finetuned on downstream datasets,
+            or it is not intended for direct use without finetuning.
+            """)
+
+    if device == 'cpu' or device == torch.device('cpu'):
+        model = model.float()
+
+    return model.to(device), vis_processors, txt_processors
+
+
+class ModelZoo:
+    """
+    A utility class to create string representation of available model architectures and types.
+
+    >>> from import model_zoo
+    >>> # list all available models
+    >>> print(model_zoo)
+    >>> # show total number of models
+    >>> print(len(model_zoo))
+    """
+
+    def __init__(self) -> None:
+        self.model_zoo = {
+            k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
+            for k, v in registry.mapping['model_name_mapping'].items()
+        }
+
+    def __str__(self) -> str:
+        return ('=' * 50 + '\n' + f"{'Architectures':<30} {'Types'}\n" + '=' * 50 + '\n'
+                + '\n'.join([f"{name:<30} {', '.join(types)}" for name, types in self.model_zoo.items()]))
+
+    def __iter__(self):
+        return iter(self.model_zoo.items())
+
+    def __len__(self):
+        return sum([len(v) for v in self.model_zoo.values()])
+
+
+model_zoo = ModelZoo()
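The new `lavis/models/__init__.py` above is a thin loader API over the LAVIS registry. The sketch below shows how it might be exercised once the wheel is installed; the import path mirrors the file location in this diff, while the registry name and `model_type` ('blip2_image_text_matching' / 'pretrain') are illustrative assumptions that must correspond to entries actually registered by the vendored models and the bundled BLIP-2 config YAMLs.

```python
# Hedged usage sketch (not part of the release): import path taken from this diff;
# the registry name and model_type are assumptions, not guaranteed identifiers.
import torch

from evalscope.metrics.t2v_metrics.models.vqascore_models.lavis.models import (
    load_model_and_preprocess,
    model_zoo,
)

print(model_zoo)       # table of registered architectures and their config types
print(len(model_zoo))  # total number of pretrained configurations

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Resolves the class via the registry, builds it from its default YAML config, and
# returns the train/eval processors defined alongside that config.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name='blip2_image_text_matching',  # assumed registry name
    model_type='pretrain',             # assumed config type
    is_eval=True,
    device=device,
)
# show_n_params() comes from the vendored BaseModel (next hunk), assuming the model subclasses it.
print(type(model).__name__, model.show_n_params())
```

One design point worth noting: `load_model_and_preprocess` ties the checkpoint choice to the preprocessing pipeline declared in the same YAML, so the image and text transforms always match the weights being loaded.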
--- /dev/null
+++ b/evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py
@@ -0,0 +1,231 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import numpy as np
+import os
+import torch
+import torch.nn as nn
+from omegaconf import OmegaConf
+
+from ..common.dist_utils import download_cached_file, is_dist_avail_and_initialized
+from ..common.utils import get_abs_path, is_url
+
+
+class BaseModel(nn.Module):
+    """Base class for models."""
+
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def device(self):
+        return list(self.parameters())[0].device
+
+    def load_checkpoint(self, url_or_filename):
+        """
+        Load from a finetuned checkpoint.
+
+        This should expect no mismatch in the model keys and the checkpoint keys.
+        """
+
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+            checkpoint = torch.load(cached_file, map_location='cpu')
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location='cpu')
+        else:
+            raise RuntimeError('checkpoint url or path is invalid')
+
+        if 'model' in checkpoint.keys():
+            state_dict = checkpoint['model']
+        else:
+            state_dict = checkpoint
+
+        msg = self.load_state_dict(state_dict, strict=False)
+
+        # logging.info('Missing keys {}'.format(msg.missing_keys))
+        logging.info('load checkpoint from %s' % url_or_filename)
+
+        return msg
+
+    @classmethod
+    def from_pretrained(cls, model_type):
+        """
+        Build a pretrained model from default configuration file, specified by model_type.
+
+        Args:
+            - model_type (str): model type, specifying architecture and checkpoints.
+
+        Returns:
+            - model (nn.Module): pretrained or finetuned model, depending on the configuration.
+        """
+        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
+        model = cls.from_config(model_cfg)
+
+        return model
+
+    @classmethod
+    def default_config_path(cls, model_type):
+        assert (model_type in cls.PRETRAINED_MODEL_CONFIG_DICT), 'Unknown model type {}'.format(model_type)
+        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
+
+    def load_checkpoint_from_config(self, cfg, **kwargs):
+        """
+        Load checkpoint as specified in the config file.
+
+        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
+        When loading the pretrained model, each task-specific architecture may define their
+        own load_from_pretrained() method.
+        """
+        load_finetuned = cfg.get('load_finetuned', True)
+        if load_finetuned:
+            finetune_path = cfg.get('finetuned', None)
+            assert (finetune_path is not None), 'Found load_finetuned is True, but finetune_path is None.'
+            self.load_checkpoint(url_or_filename=finetune_path)
+        else:
+            # load pre-trained weights
+            pretrain_path = cfg.get('pretrained', None)
+            assert 'Found load_finetuned is False, but pretrain_path is None.'
+            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
+
+    def before_evaluation(self, **kwargs):
+        pass
+
+    def show_n_params(self, return_str=True):
+        tot = 0
+        for p in self.parameters():
+            w = 1
+            for x in p.shape:
+                w *= x
+            tot += w
+        if return_str:
+            if tot >= 1e6:
+                return '{:.1f}M'.format(tot / 1e6)
+            else:
+                return '{:.1f}K'.format(tot / 1e3)
+        else:
+            return tot
+
+
+class BaseEncoder(nn.Module):
+    """
+    Base class for primitive encoders, such as ViT, TimeSformer, etc.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward_features(self, samples, **kwargs):
+        raise NotImplementedError
+
+    @property
+    def device(self):
+        return list(self.parameters())[0].device
+
+
+class SharedQueueMixin:
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
+        # gather keys before updating queue
+        image_feats = concat_all_gather(image_feat)
+        text_feats = concat_all_gather(text_feat)
+
+        batch_size = image_feats.shape[0]
+
+        ptr = int(self.queue_ptr)
+        assert self.queue_size % batch_size == 0  # for simplicity
+
+        # replace the keys at ptr (dequeue and enqueue)
+        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
+        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
+
+        if idxs is not None:
+            idxs = concat_all_gather(idxs)
+            self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
+
+        ptr = (ptr + batch_size) % self.queue_size  # move pointer
+        self.queue_ptr[0] = ptr
+
+
+class MomentumDistilationMixin:
+
+    @torch.no_grad()
+    def copy_params(self):
+        for model_pair in self.model_pairs:
+            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+                param_m.data.copy_(param.data)  # initialize
+                param_m.requires_grad = False  # not update by gradient
+
+    @torch.no_grad()
+    def _momentum_update(self):
+        for model_pair in self.model_pairs:
+            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+                param_m.data = param_m.data * self.momentum + param.data * (1.0 - self.momentum)
+
+
+class GatherLayer(torch.autograd.Function):
+    """
+    Gather tensors from all workers with support for backward propagation:
+    This implementation does not cut the gradients as torch.distributed.all_gather does.
+    """
+
+    @staticmethod
+    def forward(ctx, x):
+        output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
+        torch.distributed.all_gather(output, x)
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx, *grads):
+        all_gradients = torch.stack(grads)
+        torch.distributed.all_reduce(all_gradients)
+        return all_gradients[torch.distributed.get_rank()]
+
+
+def all_gather_with_grad(tensors):
+    """
+    Performs all_gather operation on the provided tensors.
+    Graph remains connected for backward grad computation.
+    """
+    # Queue the gathered tensors
+    world_size = torch.distributed.get_world_size()
+    # There is no need for reduction in the single-proc case
+    if world_size == 1:
+        return tensors
+
+    # tensor_all = GatherLayer.apply(tensors)
+    tensor_all = GatherLayer.apply(tensors)
+
+    return torch.cat(tensor_all, dim=0)
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    # if use distributed training
+    if not is_dist_avail_and_initialized():
+        return tensor
+
+    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
+
+
+def tile(x, dim, n_tile):
+    init_dim = x.size(dim)
+    repeat_idx = [1] * x.dim()
+    repeat_idx[dim] = n_tile
+    x = x.repeat(*(repeat_idx))
+    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
+    return torch.index_select(x, dim, order_index.to(x.device))