evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic. Click here for more details.
- evalscope/arguments.py +2 -1
- evalscope/benchmarks/__init__.py +2 -2
- evalscope/benchmarks/aigc/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/base.py +56 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
- evalscope/benchmarks/aime/aime24_adapter.py +1 -1
- evalscope/benchmarks/aime/aime25_adapter.py +4 -4
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
- evalscope/benchmarks/arc/arc_adapter.py +1 -1
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
- evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
- evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
- evalscope/benchmarks/data_adapter.py +16 -9
- evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
- evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
- evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
- evalscope/benchmarks/utils.py +7 -16
- evalscope/cli/start_app.py +1 -1
- evalscope/collections/evaluator.py +16 -4
- evalscope/config.py +7 -3
- evalscope/constants.py +11 -0
- evalscope/evaluator/evaluator.py +9 -3
- evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
- evalscope/metrics/__init__.py +49 -4
- evalscope/metrics/llm_judge.py +1 -1
- evalscope/metrics/named_metrics.py +13 -0
- evalscope/metrics/t2v_metrics/__init__.py +66 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +50 -14
- evalscope/models/adapters/__init__.py +17 -0
- evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
- evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
- evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
- evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
- evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
- evalscope/models/adapters/t2i_adapter.py +76 -0
- evalscope/models/custom/__init__.py +2 -1
- evalscope/models/custom/dummy_model.py +11 -13
- evalscope/models/local_model.py +82 -33
- evalscope/models/model.py +2 -42
- evalscope/models/register.py +26 -0
- evalscope/perf/benchmark.py +4 -3
- evalscope/perf/main.py +4 -2
- evalscope/perf/plugin/datasets/flickr8k.py +2 -1
- evalscope/perf/utils/benchmark_util.py +2 -2
- evalscope/perf/utils/db_util.py +16 -8
- evalscope/report/__init__.py +1 -0
- evalscope/report/app.py +117 -67
- evalscope/report/app_arguments.py +11 -0
- evalscope/report/generator.py +1 -1
- evalscope/run.py +3 -3
- evalscope/third_party/thinkbench/eval.py +19 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/import_utils.py +66 -0
- evalscope/utils/utils.py +12 -4
- evalscope/version.py +2 -2
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
- tests/aigc/__init__.py +1 -0
- tests/aigc/test_t2i.py +87 -0
- tests/cli/test_run.py +20 -7
- tests/perf/test_perf.py +6 -3
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import datetime
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
import torch
|
|
12
|
+
import torch.distributed as dist
|
|
13
|
+
from collections import defaultdict, deque
|
|
14
|
+
|
|
15
|
+
from . import dist_utils
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    The last ``window_size`` values are kept in a deque and feed the windowed
    statistics (``median``, ``avg``, ``max``, ``value``); ``total`` and
    ``count`` accumulate over the whole series and feed ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        """
        Args:
            window_size: Maximum number of recent values kept in the window.
            fmt: ``str.format`` template used by ``__str__``. It may reference
                ``median``, ``avg``, ``global_avg``, ``max`` and ``value``.
                Defaults to ``'{median:.4f} ({global_avg:.4f})'``.
        """
        if fmt is None:
            fmt = '{median:.4f} ({global_avg:.4f})'
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the global totals.

        Args:
            value: The new observation.
            n: Weight of the observation in ``count`` / ``total``.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        All-reduce ``count`` and ``total`` across the process group.

        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        # The tensor used to be pinned to 'cuda', which raises on hosts without
        # a GPU (e.g. a CPU-only process group). Fall back to the CPU there;
        # behaviour is unchanged whenever CUDA is available.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the values currently in the window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over the whole series (``total / count``).

        Raises:
            ZeroDivisionError: if no value has been recorded yet.
        """
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        """Render the statistics with the ``fmt`` template given at construction."""
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print periodic progress lines."""

    def __init__(self, delimiter='\t'):
        """
        Args:
            delimiter: String placed between the fields of a printed line.
        """
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one new value into each named meter.

        Tensor values are converted with ``.item()``; every value must then be
        a ``float`` or ``int``. Meters are created on first use.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes, e.g. ``logger.loss``.

        Raises:
            AttributeError: if ``attr`` is neither a meter nor an instance attribute.
        """
        # Go through __dict__ directly: evaluating ``self.meters`` here would
        # re-enter __getattr__ without end whenever 'meters' has not been set
        # yet (instance made via __new__, copy or unpickling).
        state = self.__dict__
        meters = state.get('meters', ())
        if attr in meters:
            return meters[attr]
        if attr in state:
            return state[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        """Return ``'name: <meter>'`` for every meter, joined by the delimiter."""
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append('{}: {}'.format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        """Return ``'name: <global average>'`` for every meter, joined by the delimiter."""
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append('{}: {:.4f}'.format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Install (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress lines.

        A line is printed every ``print_freq`` items and for the last item; a
        total-time summary is printed once the iterable is exhausted.

        Args:
            iterable: Sized iterable (``len()`` is required).
            print_freq: Number of items between two progress lines.
            header: Optional prefix for every printed line.

        Yields:
            The items of ``iterable``, unchanged and in order.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        total = len(iterable)
        space_fmt = ':' + str(len(str(total))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        # Decide once whether the memory column exists so the template and the
        # values passed to it can never disagree.
        with_memory = torch.cuda.is_available()
        if with_memory:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Waiting time plus the time the consumer spent on the item.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                fields = {
                    'eta': str(datetime.timedelta(seconds=int(eta_seconds))),
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if with_memory:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total, **fields))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1): an empty iterable used to raise ZeroDivisionError here.
        print('{} Total time: {} ({:.4f} s / it)'.format(header, total_time_str, total_time / max(total, 1)))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class AttrDict(dict):
    """A ``dict`` whose items are also reachable as attributes."""

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``dict``."""
        super().__init__(*args, **kwargs)
        # Use the dict itself as the instance namespace, so ``d.key`` and
        # ``d['key']`` always read and write the same entry.
        self.__dict__ = self
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def setup_logger():
    """Configure root logging with a stream handler.

    The level is INFO when ``dist_utils.is_main_process()`` is true and WARN
    otherwise.
    """
    if dist_utils.is_main_process():
        verbosity = logging.INFO
    else:
        verbosity = logging.WARN
    stream_handler = logging.StreamHandler()
    logging.basicConfig(
        level=verbosity,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[stream_handler],
    )
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import math
|
|
9
|
+
|
|
10
|
+
from . import registry
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@registry.register_lr_scheduler('linear_warmup_step_lr')
class LinearWarmupStepLRScheduler:
    """Linear warmup during epoch 0, then a per-epoch multiplicative decay.

    Registered under the key ``'linear_warmup_step_lr'``.
    """

    def __init__(self,
                 optimizer,
                 max_epoch,
                 min_lr,
                 init_lr,
                 decay_rate=1,
                 warmup_start_lr=-1,
                 warmup_steps=0,
                 **kwargs):
        """
        Args:
            optimizer: Object whose ``param_groups`` receive the learning rate.
            max_epoch: Stored on the instance; not read by ``step``.
            min_lr: Lower bound applied by the step decay.
            init_lr: Rate reached at the end of warmup and base of the decay.
            decay_rate: Factor raised to the epoch number by the step decay.
            warmup_start_lr: Rate at warmup step 0; ``init_lr`` is used
                instead unless this is ``>= 0``.
            warmup_steps: Number of steps the warmup ramp spans.
            **kwargs: Accepted and ignored.
        """
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        if warmup_start_lr >= 0:
            self.warmup_start_lr = warmup_start_lr
        else:
            self.warmup_start_lr = init_lr

    def step(self, cur_epoch, cur_step):
        """Set the optimizer's learning rate for the given epoch and step."""
        if cur_epoch != 0:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )
            return
        # Warmup is only ever applied while in the first epoch.
        warmup_lr_schedule(
            step=cur_step,
            optimizer=self.optimizer,
            max_step=self.warmup_steps,
            init_lr=self.warmup_start_lr,
            max_lr=self.init_lr,
        )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@registry.register_lr_scheduler('linear_warmup_cosine_lr')
class LinearWarmupCosineLRScheduler:
    """Linear warmup during epoch 0, then a per-epoch cosine decay.

    Registered under the key ``'linear_warmup_cosine_lr'``.
    """

    def __init__(self, optimizer, max_epoch, min_lr, init_lr, warmup_steps=0, warmup_start_lr=-1, **kwargs):
        """
        Args:
            optimizer: Object whose ``param_groups`` receive the learning rate.
            max_epoch: Epoch count over which the cosine decay is spread.
            min_lr: Rate the cosine decay ends at.
            init_lr: Rate reached at the end of warmup and start of the decay.
            warmup_steps: Number of steps the warmup ramp spans.
            warmup_start_lr: Rate at warmup step 0; ``init_lr`` is used
                instead unless this is ``>= 0``.
            **kwargs: Accepted and ignored.
        """
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        if warmup_start_lr >= 0:
            self.warmup_start_lr = warmup_start_lr
        else:
            self.warmup_start_lr = init_lr

    def step(self, cur_epoch, cur_step):
        """Set the optimizer's learning rate for the given epoch and step."""
        if cur_epoch != 0:
            cosine_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )
            return
        # assuming the warmup iters less than one epoch
        warmup_lr_schedule(
            step=cur_step,
            optimizer=self.optimizer,
            max_step=self.warmup_steps,
            init_lr=self.warmup_start_lr,
            max_lr=self.init_lr,
        )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Set every param group's ``lr`` from a half-cosine over ``epoch / max_epoch``.

    The rate equals ``init_lr`` at epoch 0 and ``min_lr`` at ``max_epoch``.
    """
    half_range = (init_lr - min_lr) * 0.5
    new_lr = half_range * (1.0 + math.cos(math.pi * epoch / max_epoch)) + min_lr
    for group in optimizer.param_groups:
        group['lr'] = new_lr
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Set every param group's ``lr`` on a linear ramp from ``init_lr``, capped at ``max_lr``."""
    # A max_step below 1 is treated as 1, which also avoids dividing by zero.
    span = max(max_step, 1)
    ramped = init_lr + (max_lr - init_lr) * step / span
    new_lr = min(max_lr, ramped)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Set every param group's ``lr`` to ``init_lr * decay_rate**epoch``, floored at ``min_lr``."""
    decayed = init_lr * (decay_rate**epoch)
    new_lr = max(min_lr, decayed)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Registry:
    """Shared registry of classes, paths and arbitrary state.

    Everything is stored in the class-level ``mapping`` dict, so the class and
    all of its instances see the same registrations.
    """

    mapping = {
        'builder_name_mapping': {},
        'task_name_mapping': {},
        'processor_name_mapping': {},
        'model_name_mapping': {},
        'lr_scheduler_name_mapping': {},
        'runner_name_mapping': {},
        'state': {},
        'paths': {},
    }

    # NOTE: the ``register_builder`` and ``register_task`` decorators are
    # disabled in this copy (they depended on ``lavis.datasets`` and
    # ``lavis.tasks``). The builder/task lookup helpers below are kept and
    # only see entries written into ``mapping`` directly.

    @classmethod
    def register_model(cls, name):
        r"""Return a class decorator that registers a model with key 'name'

        Args:
            name: Key with which the model will be registered.

        Raises (when the decorator is applied):
            AssertionError: If the class does not inherit ``BaseModel``.
            KeyError: If ``name`` is already registered.

        Usage:

            from lavis.common.registry import registry
        """

        def wrap(model_cls):
            # Imported when the decorator runs rather than at module import.
            from ..models import BaseModel

            assert issubclass(model_cls, BaseModel), 'All models must inherit BaseModel class'
            if name in cls.mapping['model_name_mapping']:
                raise KeyError("Name '{}' already registered for {}.".format(name,
                                                                             cls.mapping['model_name_mapping'][name]))
            cls.mapping['model_name_mapping'][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Return a class decorator that registers a processor with key 'name'

        Args:
            name: Key with which the processor will be registered.

        Raises (when the decorator is applied):
            AssertionError: If the class does not inherit ``BaseProcessor``.
            KeyError: If ``name`` is already registered.

        Usage:

            from lavis.common.registry import registry
        """

        def wrap(processor_cls):
            # Imported when the decorator runs rather than at module import.
            from ..processors import BaseProcessor

            assert issubclass(processor_cls, BaseProcessor), 'All processors must inherit BaseProcessor class'
            if name in cls.mapping['processor_name_mapping']:
                raise KeyError("Name '{}' already registered for {}.".format(
                    name, cls.mapping['processor_name_mapping'][name]))
            cls.mapping['processor_name_mapping'][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Return a class decorator that registers an LR scheduler with key 'name'

        Args:
            name: Key with which the LR scheduler will be registered.

        Raises (when the decorator is applied):
            KeyError: If ``name`` is already registered.

        Usage:

            from lavis.common.registry import registry
        """

        def wrap(lr_sched_cls):
            if name in cls.mapping['lr_scheduler_name_mapping']:
                raise KeyError("Name '{}' already registered for {}.".format(
                    name, cls.mapping['lr_scheduler_name_mapping'][name]))
            cls.mapping['lr_scheduler_name_mapping'][name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Return a class decorator that registers a runner with key 'name'

        Args:
            name: Key with which the runner will be registered.

        Raises (when the decorator is applied):
            KeyError: If ``name`` is already registered.

        Usage:

            from lavis.common.registry import registry
        """

        def wrap(runner_cls):
            if name in cls.mapping['runner_name_mapping']:
                raise KeyError("Name '{}' already registered for {}.".format(name,
                                                                             cls.mapping['runner_name_mapping'][name]))
            cls.mapping['runner_name_mapping'][name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.
            path: The path to store; must be a ``str``.

        Raises:
            AssertionError: If ``path`` is not a ``str``.
            KeyError: If ``name`` is already registered.

        Usage:

            from lavis.common.registry import registry
        """
        assert isinstance(path, str), 'All path must be str.'
        if name in cls.mapping['paths']:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping['paths'][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        A dotted ``name`` such as ``"a.b.c"`` is stored as nested dicts under
        the registry state; missing intermediate dicts are created.

        Args:
            name: Key with which the item will be registered.
            obj: The item to store.

        Usage::

            from lavis.common.registry import registry

            registry.register("config", {})
        """
        path = name.split('.')
        current = cls.mapping['state']

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    @classmethod
    def get_builder_class(cls, name):
        """Return the builder registered under ``name``, or ``None``."""
        return cls.mapping['builder_name_mapping'].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        """Return the model class registered under ``name``, or ``None``."""
        return cls.mapping['model_name_mapping'].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        """Return the task registered under ``name``, or ``None``."""
        return cls.mapping['task_name_mapping'].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        """Return the processor class registered under ``name``, or ``None``."""
        return cls.mapping['processor_name_mapping'].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        """Return the LR scheduler class registered under ``name``, or ``None``."""
        return cls.mapping['lr_scheduler_name_mapping'].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        """Return the runner class registered under ``name``, or ``None``."""
        return cls.mapping['runner_name_mapping'].get(name, None)

    @classmethod
    def list_runners(cls):
        """Return the sorted keys of all registered runners."""
        return sorted(cls.mapping['runner_name_mapping'].keys())

    @classmethod
    def list_models(cls):
        """Return the sorted keys of all registered models."""
        return sorted(cls.mapping['model_name_mapping'].keys())

    @classmethod
    def list_tasks(cls):
        """Return the sorted keys of all registered tasks."""
        return sorted(cls.mapping['task_name_mapping'].keys())

    @classmethod
    def list_processors(cls):
        """Return the sorted keys of all registered processors."""
        return sorted(cls.mapping['processor_name_mapping'].keys())

    @classmethod
    def list_lr_schedulers(cls):
        """Return the sorted keys of all registered LR schedulers."""
        return sorted(cls.mapping['lr_scheduler_name_mapping'].keys())

    @classmethod
    def list_datasets(cls):
        """Return the sorted keys of all registered dataset builders."""
        return sorted(cls.mapping['builder_name_mapping'].keys())

    @classmethod
    def get_path(cls, name):
        """Return the path registered under ``name``, or ``None``."""
        return cls.mapping['paths'].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        Args:
            name (string): Key whose value needs to be retrieved; dots
                separate the levels of nested state.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't exist
                               will not be generated. Default: False
        """
        original_name = name
        value = cls.mapping['state']
        found = True
        for subname in name.split('.'):
            lookup = getattr(value, 'get', None)
            if not callable(lookup):
                # The path continues past a value that cannot be indexed by
                # key (e.g. a str or int): report it as missing instead of
                # raising AttributeError.
                value = default
                found = False
                break
            value = lookup(subname, default)
            if value is default:
                found = False
                break

        # Warn only when the lookup really failed; comparing the result with
        # ``default`` by equality gave false warnings for stored values that
        # merely equal the default. The warning goes to the object registered
        # under 'writer', if any.
        if 'writer' in cls.mapping['state'] and not found and no_warning is False:
            cls.mapping['state']['writer'].warning('Key {} is not present in registry, returning default value '
                                                   'of {}'.format(original_name, default))
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        Dotted names are resolved through the nested state the same way
        ``register`` stores them.

        Args:
            name: Key which needs to be removed.

        Returns:
            The removed item, or ``None`` if nothing is stored under ``name``.

        Usage::

            from lavis.common.registry import registry

            config = registry.unregister("config")
        """
        state = cls.mapping['state']
        # A literal top-level key (or a non-string key) keeps the old behaviour.
        if name in state or not isinstance(name, str):
            return state.pop(name, None)

        *parents, leaf = name.split('.')
        node = state
        for part in parents:
            node = node.get(part) if isinstance(node, dict) else None
            if node is None:
                return None
        if isinstance(node, dict):
            return node.pop(leaf, None)
        return None


registry = Registry()
|