evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of evalscope might be problematic. Click here for more details.

Files changed (181)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
  24. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  25. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  26. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  27. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  28. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  29. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  30. evalscope/benchmarks/utils.py +7 -16
  31. evalscope/cli/start_app.py +1 -1
  32. evalscope/collections/evaluator.py +16 -4
  33. evalscope/config.py +7 -3
  34. evalscope/constants.py +11 -0
  35. evalscope/evaluator/evaluator.py +9 -3
  36. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  37. evalscope/metrics/__init__.py +49 -4
  38. evalscope/metrics/llm_judge.py +1 -1
  39. evalscope/metrics/named_metrics.py +13 -0
  40. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  41. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  42. evalscope/metrics/t2v_metrics/constants.py +12 -0
  43. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  44. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  53. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  60. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  61. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  62. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  139. evalscope/metrics/t2v_metrics/score.py +78 -0
  140. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  141. evalscope/models/__init__.py +50 -14
  142. evalscope/models/adapters/__init__.py +17 -0
  143. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  144. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  145. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  146. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  147. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  148. evalscope/models/adapters/t2i_adapter.py +76 -0
  149. evalscope/models/custom/__init__.py +2 -1
  150. evalscope/models/custom/dummy_model.py +11 -13
  151. evalscope/models/local_model.py +82 -33
  152. evalscope/models/model.py +2 -42
  153. evalscope/models/register.py +26 -0
  154. evalscope/perf/benchmark.py +4 -3
  155. evalscope/perf/main.py +4 -2
  156. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  157. evalscope/perf/utils/benchmark_util.py +2 -2
  158. evalscope/perf/utils/db_util.py +16 -8
  159. evalscope/report/__init__.py +1 -0
  160. evalscope/report/app.py +117 -67
  161. evalscope/report/app_arguments.py +11 -0
  162. evalscope/report/generator.py +1 -1
  163. evalscope/run.py +3 -3
  164. evalscope/third_party/thinkbench/eval.py +19 -7
  165. evalscope/utils/chat_service.py +2 -2
  166. evalscope/utils/import_utils.py +66 -0
  167. evalscope/utils/utils.py +12 -4
  168. evalscope/version.py +2 -2
  169. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
  170. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
  171. tests/aigc/__init__.py +1 -0
  172. tests/aigc/test_t2i.py +87 -0
  173. tests/cli/test_run.py +20 -7
  174. tests/perf/test_perf.py +6 -3
  175. evalscope/metrics/code_metric.py +0 -98
  176. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  177. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
  179. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
  180. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
  181. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,188 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ import torch
12
+ import torch.distributed as dist
13
+ from collections import defaultdict, deque
14
+
15
+ from . import dist_utils
16
+
17
+
18
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    Windowed statistics (``median``, ``avg``, ``max``, ``value``) use only the
    last ``window_size`` values passed to :meth:`update`. ``global_avg`` is the
    ``n``-weighted mean of every update. Before anything has been recorded,
    every statistic is ``nan`` instead of raising. This lets a freshly created
    meter still be formatted with ``str()``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = '{median:.4f} ({global_avg:.4f})'
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Append ``value`` to the window and add it ``n`` times to the global totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across all processes of the default group.

        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        # NCCL can only reduce CUDA tensors; other backends (gloo, mpi) reduce
        # CPU tensors, so do not demand a GPU when the backend does not need one.
        device = 'cuda' if dist.get_backend() == 'nccl' else 'cpu'
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the current window (``nan`` if the window is empty)."""
        if not self.deque:
            return float('nan')
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the current window as float32 (``nan`` if the window is empty)."""
        if not self.deque:
            return float('nan')
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over all updates (``nan`` if nothing has been counted)."""
        if self.count == 0:
            return float('nan')
        return self.total / self.count

    @property
    def max(self):
        """Largest value in the current window (``nan`` if the window is empty)."""
        # ``max`` here resolves to the builtin: class-body names are not in method scope.
        return max(self.deque) if self.deque else float('nan')

    @property
    def value(self):
        """Most recently recorded value (``nan`` if the window is empty)."""
        return self.deque[-1] if self.deque else float('nan')

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
79
+
80
+
81
class MetricLogger(object):
    """A named collection of :class:`SmoothedValue` meters with progress logging.

    Meters are created on first use by :meth:`update` (via ``defaultdict``) and
    can be read back as attributes, e.g. ``metric_logger.loss``.
    """

    def __init__(self, delimiter='\t'):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword into the meter of the same name.

        ``torch.Tensor`` values are converted with ``.item()``; the resulting
        value must be a ``float`` or ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. ``meters`` is read via
        # ``__dict__`` so that an instance whose ``__init__`` has not run (e.g.
        # during copy/pickle reconstruction) raises AttributeError instead of
        # recursing forever through ``self.meters`` -> ``__getattr__``.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join('{}: {}'.format(name, str(meter)) for name, meter in self.meters.items())

    def global_avg(self):
        """Return ``name: global_avg`` for every meter, joined by the delimiter."""
        return self.delimiter.join('{}: {:.4f}'.format(name, meter.global_avg) for name, meter in self.meters.items())

    def synchronize_between_processes(self):
        """Call ``synchronize_between_processes`` on every meter."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A progress line is printed on iteration 0, every ``print_freq``
        iterations after that, and on the last iteration. A total-time summary
        is printed once the iteration finishes. ``iterable`` must support
        ``len()``.
        """
        if not header:
            header = ''
        num_items = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # Time spent waiting on ``iterable`` (e.g. data loading).
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time, including the caller's work on ``obj``.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if torch.cuda.is_available():
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_items, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard against an empty iterable, which previously raised ZeroDivisionError.
        print('{} Total time: {} ({:.4f} s / it)'.format(header, total_time_str, total_time / max(num_items, 1)))
174
+
175
+
176
class AttrDict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    Attribute access is forwarded to the mapping through ``__getattr__`` /
    ``__setattr__`` / ``__delattr__``. It no longer aliases
    ``self.__dict__ = self``, because that alias makes every instance reference
    itself, so instances could only be reclaimed by the cyclic garbage
    collector. Reading an attribute that is also a ``dict`` method name
    (e.g. ``keys``) returns the method; use item access for such keys.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so real attributes/methods win.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
181
+
182
+
183
def setup_logger():
    """Configure root logging: INFO on the main process, WARN on every other rank.

    Uses ``logging.basicConfig`` with a single stream handler and a
    ``time [LEVEL] message`` format.
    """
    is_main = dist_utils.is_main_process()
    log_level = logging.INFO if is_main else logging.WARN
    stream_handler = logging.StreamHandler()
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[stream_handler],
    )
@@ -0,0 +1,106 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from . import registry
11
+
12
+
13
+ @registry.register_lr_scheduler('linear_warmup_step_lr')
14
class LinearWarmupStepLRScheduler:
    """Linear LR warmup during epoch 0, then per-epoch exponential step decay.

    Epoch 0 ramps the LR from ``warmup_start_lr`` to ``init_lr`` over
    ``warmup_steps`` steps. Every later epoch sets it to
    ``max(min_lr, init_lr * decay_rate**epoch)``. A negative
    ``warmup_start_lr`` (the default -1) means "start the warmup at
    ``init_lr``". Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 optimizer,
                 max_epoch,
                 min_lr,
                 init_lr,
                 decay_rate=1,
                 warmup_start_lr=-1,
                 warmup_steps=0,
                 **kwargs):
        self.optimizer = optimizer

        self.max_epoch, self.min_lr = max_epoch, min_lr
        self.decay_rate = decay_rate
        self.init_lr, self.warmup_steps = init_lr, warmup_steps

        if warmup_start_lr >= 0:
            start_lr = warmup_start_lr
        else:
            start_lr = init_lr
        self.warmup_start_lr = start_lr

    def step(self, cur_epoch, cur_step):
        """Update the optimizer's LR for (``cur_epoch``, ``cur_step``)."""
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
            return
        step_lr_schedule(
            epoch=cur_epoch,
            optimizer=self.optimizer,
            init_lr=self.init_lr,
            min_lr=self.min_lr,
            decay_rate=self.decay_rate,
        )
53
+
54
+
55
+ @registry.register_lr_scheduler('linear_warmup_cosine_lr')
56
class LinearWarmupCosineLRScheduler:
    """Linear LR warmup during epoch 0, then per-epoch cosine annealing.

    Epoch 0 ramps the LR from ``warmup_start_lr`` to ``init_lr`` over
    ``warmup_steps`` steps. The warmup is assumed to fit inside the first
    epoch. Later epochs follow a cosine curve from ``init_lr`` down to
    ``min_lr`` at ``max_epoch``. A negative ``warmup_start_lr`` (default -1)
    means "start the warmup at ``init_lr``". Extra keyword arguments are
    ignored.
    """

    def __init__(self, optimizer, max_epoch, min_lr, init_lr, warmup_steps=0, warmup_start_lr=-1, **kwargs):
        self.optimizer = optimizer

        self.max_epoch, self.min_lr = max_epoch, min_lr
        self.init_lr, self.warmup_steps = init_lr, warmup_steps

        if warmup_start_lr >= 0:
            start_lr = warmup_start_lr
        else:
            start_lr = init_lr
        self.warmup_start_lr = start_lr

    def step(self, cur_epoch, cur_step):
        """Update the optimizer's LR for (``cur_epoch``, ``cur_step``)."""
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
            return
        cosine_lr_schedule(
            epoch=cur_epoch,
            optimizer=self.optimizer,
            max_epoch=self.max_epoch,
            init_lr=self.init_lr,
            min_lr=self.min_lr,
        )
86
+
87
+
88
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Set every param group's LR by cosine annealing from ``init_lr`` to ``min_lr``.

    ``epoch == 0`` gives ``init_lr`` and ``epoch == max_epoch`` gives ``min_lr``.
    """
    cosine = math.cos(math.pi * epoch / max_epoch)
    new_lr = (init_lr - min_lr) * 0.5 * (1.0 + cosine) + min_lr
    for group in optimizer.param_groups:
        group['lr'] = new_lr
93
+
94
+
95
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Linearly ramp every param group's LR from ``init_lr`` toward ``max_lr``.

    The ramp reaches ``max_lr`` at ``step == max_step``. The result is
    ``min(max_lr, ramp)``, so steps beyond ``max_step`` stay at ``max_lr``.
    """
    # A max_step of 0 is treated as 1 to avoid dividing by zero.
    denominator = max(max_step, 1)
    ramped = init_lr + (max_lr - init_lr) * step / denominator
    new_lr = min(max_lr, ramped)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
100
+
101
+
102
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Set every param group's LR to ``init_lr * decay_rate**epoch``, floored at ``min_lr``."""
    decayed = init_lr * (decay_rate**epoch)
    new_lr = max(min_lr, decayed)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
@@ -0,0 +1,307 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
class Registry:
    """Process-wide registry of named classes, filesystem paths and arbitrary state.

    All data lives in the class-level ``mapping`` dict and every method is a
    classmethod, so every ``Registry`` instance shares the same contents.
    ``mapping['state']`` is a nested dict addressed with dotted keys
    (``"a.b.c"``) by :meth:`register`, :meth:`get` and :meth:`unregister`.
    """

    mapping = {
        'builder_name_mapping': {},
        'task_name_mapping': {},
        'processor_name_mapping': {},
        'model_name_mapping': {},
        'lr_scheduler_name_mapping': {},
        'runner_name_mapping': {},
        'state': {},
        'paths': {},
    }

    # NOTE: upstream LAVIS also defines ``register_builder``/``register_task``
    # (and ``get_trainer_class``); they are intentionally omitted here. The
    # builder/task mappings are kept so their getters/listers still work.

    @classmethod
    def register_model(cls, name):
        r"""Return a class decorator that registers a model under key ``name``.

        Args:
            name: Key with which the model class will be registered.

        Raises:
            KeyError: if ``name`` is already registered.

        Usage:

            @registry.register_model('my_model')
            class MyModel(BaseModel): ...
        """

        def wrap(model_cls):
            from ..models import BaseModel

            assert issubclass(model_cls, BaseModel), 'All models must inherit BaseModel class'
            if name in cls.mapping['model_name_mapping']:
                raise KeyError("Name '{}' already registered for {}.".format(name,
                                                                            cls.mapping['model_name_mapping'][name]))
            cls.mapping['model_name_mapping'][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Return a class decorator that registers a processor under key ``name``.

        Args:
            name: Key with which the processor class will be registered.

        Raises:
            KeyError: if ``name`` is already registered.

        Usage:

            @registry.register_processor('my_processor')
            class MyProcessor(BaseProcessor): ...
        """

        def wrap(processor_cls):
            from ..processors import BaseProcessor

            assert issubclass(processor_cls, BaseProcessor), 'All processors must inherit BaseProcessor class'
            if name in cls.mapping['processor_name_mapping']:
                raise KeyError("Name '{}' already registered for {}.".format(
                    name, cls.mapping['processor_name_mapping'][name]))
            cls.mapping['processor_name_mapping'][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Return a class decorator that registers an LR scheduler under key ``name``.

        Args:
            name: Key with which the scheduler class will be registered.

        Raises:
            KeyError: if ``name`` is already registered.

        Usage:

            @registry.register_lr_scheduler('my_scheduler')
            class MyScheduler: ...
        """

        def wrap(lr_sched_cls):
            if name in cls.mapping['lr_scheduler_name_mapping']:
                raise KeyError("Name '{}' already registered for {}.".format(
                    name, cls.mapping['lr_scheduler_name_mapping'][name]))
            cls.mapping['lr_scheduler_name_mapping'][name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Return a class decorator that registers a runner under key ``name``.

        Args:
            name: Key with which the runner class will be registered.

        Raises:
            KeyError: if ``name`` is already registered.

        Usage:

            @registry.register_runner('my_runner')
            class MyRunner: ...
        """

        def wrap(runner_cls):
            if name in cls.mapping['runner_name_mapping']:
                raise KeyError("Name '{}' already registered for {}.".format(name,
                                                                            cls.mapping['runner_name_mapping'][name]))
            cls.mapping['runner_name_mapping'][name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a filesystem path string under key ``name``.

        Args:
            name: Key with which the path will be registered.
            path: The path; must be a ``str``.

        Raises:
            KeyError: if ``name`` is already registered.
        """
        assert isinstance(path, str), 'All path must be str.'
        if name in cls.mapping['paths']:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping['paths'][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Store ``obj`` in the state tree under the dotted key ``name``.

        Missing intermediate levels are created as empty dicts; an existing
        value at the final key is overwritten.

        Args:
            name: Dotted key, e.g. ``"config"`` or ``"run.output_dir"``.
            obj: Any object.

        Usage::

            registry.register("config", {})
        """
        path = name.split('.')
        current = cls.mapping['state']

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    @classmethod
    def get_builder_class(cls, name):
        return cls.mapping['builder_name_mapping'].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping['model_name_mapping'].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        return cls.mapping['task_name_mapping'].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        return cls.mapping['processor_name_mapping'].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping['lr_scheduler_name_mapping'].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        return cls.mapping['runner_name_mapping'].get(name, None)

    @classmethod
    def list_runners(cls):
        return sorted(cls.mapping['runner_name_mapping'].keys())

    @classmethod
    def list_models(cls):
        return sorted(cls.mapping['model_name_mapping'].keys())

    @classmethod
    def list_tasks(cls):
        return sorted(cls.mapping['task_name_mapping'].keys())

    @classmethod
    def list_processors(cls):
        return sorted(cls.mapping['processor_name_mapping'].keys())

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping['lr_scheduler_name_mapping'].keys())

    @classmethod
    def list_datasets(cls):
        return sorted(cls.mapping['builder_name_mapping'].keys())

    @classmethod
    def get_path(cls, name):
        return cls.mapping['paths'].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from the state tree by dotted key ``name``.

        Args:
            name (string): Dotted key whose value needs to be retrieved.
            default: Returned when the key is not present. Default: None
            no_warning (bool): If True, do not emit a warning through the
                registered ``'writer'`` (if any) when the key is missing.
                Default: False

        Returns:
            The stored value, or ``default`` when any part of the path is
            missing or an intermediate value is not mapping-like.
        """
        missing = object()  # private sentinel: distinguishes "absent" from a stored value equal to ``default``
        value = cls.mapping['state']
        for subname in name.split('.'):
            # Intermediate values may be non-mappings (e.g. ints); treat as absent
            # instead of raising AttributeError on ``.get``.
            getter = getattr(value, 'get', None)
            value = getter(subname, missing) if callable(getter) else missing
            if value is missing:
                break

        if value is not missing:
            return value

        if 'writer' in cls.mapping['state'] and not no_warning:
            cls.mapping['state']['writer'].warning('Key {} is not present in registry, returning default value '
                                                   'of {}'.format(name, default))
        return default

    @classmethod
    def unregister(cls, name):
        r"""Remove and return the item stored under dotted key ``name``.

        Args:
            name: Dotted key which needs to be removed (same form as in
                :meth:`register`).

        Returns:
            The removed value, or ``None`` if the key was not present.

        Usage::

            config = registry.unregister("config")
        """
        *parents, leaf = name.split('.')
        current = cls.mapping['state']
        for part in parents:
            current = current.get(part)
            if not isinstance(current, dict):
                return None
        return current.pop(leaf, None)
305
+
306
+
307
# Module-level instance used by the ``@registry.register_*`` decorators.
# Registry keeps everything in the class attribute ``mapping`` and all of its
# methods are classmethods, so this instance shares state with the class itself.
registry = Registry()