evalscope 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of evalscope has been flagged as potentially problematic.

Files changed (178)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  24. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  25. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  26. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  27. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  28. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  29. evalscope/benchmarks/utils.py +7 -16
  30. evalscope/cli/start_app.py +1 -1
  31. evalscope/collections/evaluator.py +16 -4
  32. evalscope/config.py +7 -3
  33. evalscope/constants.py +11 -0
  34. evalscope/evaluator/evaluator.py +2 -2
  35. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  36. evalscope/metrics/__init__.py +49 -4
  37. evalscope/metrics/llm_judge.py +1 -1
  38. evalscope/metrics/named_metrics.py +13 -0
  39. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  40. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  41. evalscope/metrics/t2v_metrics/constants.py +12 -0
  42. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  43. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  44. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  53. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  60. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  61. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  62. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  138. evalscope/metrics/t2v_metrics/score.py +78 -0
  139. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  140. evalscope/models/__init__.py +50 -14
  141. evalscope/models/adapters/__init__.py +17 -0
  142. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  143. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  144. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  145. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  146. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  147. evalscope/models/adapters/t2i_adapter.py +76 -0
  148. evalscope/models/custom/__init__.py +2 -1
  149. evalscope/models/custom/dummy_model.py +11 -13
  150. evalscope/models/local_model.py +82 -33
  151. evalscope/models/model.py +2 -42
  152. evalscope/models/register.py +26 -0
  153. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  154. evalscope/perf/utils/benchmark_util.py +2 -2
  155. evalscope/perf/utils/db_util.py +8 -2
  156. evalscope/report/__init__.py +1 -0
  157. evalscope/report/app.py +117 -67
  158. evalscope/report/app_arguments.py +11 -0
  159. evalscope/report/generator.py +1 -1
  160. evalscope/run.py +3 -3
  161. evalscope/third_party/thinkbench/eval.py +19 -7
  162. evalscope/utils/chat_service.py +2 -2
  163. evalscope/utils/import_utils.py +66 -0
  164. evalscope/utils/utils.py +12 -4
  165. evalscope/version.py +2 -2
  166. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/METADATA +18 -1
  167. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/RECORD +175 -63
  168. tests/aigc/__init__.py +1 -0
  169. tests/aigc/test_t2i.py +87 -0
  170. tests/cli/test_run.py +11 -5
  171. tests/perf/test_perf.py +2 -1
  172. evalscope/metrics/code_metric.py +0 -98
  173. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  174. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  175. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/LICENSE +0 -0
  176. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/WHEEL +0 -0
  177. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/entry_points.txt +0 -0
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,503 @@
+# Based on EVA, BEIT, timm and DeiT code bases
+# https://github.com/baaivision/EVA
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/microsoft/unilm/tree/master/beit
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+
+try:
+    from timm.layers import drop_path, to_2tuple, trunc_normal_
+except ImportError:
+    pass
+
+from ..common.dist_utils import download_cached_file
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000,
+        'input_size': (3, 224, 224),
+        'pool_size': None,
+        'crop_pct': .9,
+        'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5),
+        'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # commit this for the orignal BERT implement
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 window_size=None,
+                 attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        if window_size:
+            self.window_size = window_size
+            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+            self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
+                                                                         num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+            # cls to token & token 2 cls & cls to cls
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(window_size[0])
+            coords_w = torch.arange(window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer('relative_position_index', relative_position_index)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, rel_pos_bias=None):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            relative_position_bias = \
+                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                    self.window_size[0] * self.window_size[1] + 1,
+                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+            attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            attn = attn + rel_pos_bias
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 init_values=None,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm,
+                 window_size=None,
+                 attn_head_dim=None):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            window_size=window_size,
+            attn_head_dim=attn_head_dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values is not None and init_values > 0:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rel_pos_bias=None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x, **kwargs):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class RelativePositionBias(nn.Module):
+
+    def __init__(self, window_size, num_heads):
+        super().__init__()
+        self.window_size = window_size
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
+                                                                     num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token 2 cls & cls to cls
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(window_size[0])
+        coords_w = torch.arange(window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = \
+            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = self.num_relative_distance - 3
+        relative_position_index[0:, 0] = self.num_relative_distance - 2
+        relative_position_index[0, 0] = self.num_relative_distance - 1
+
+        self.register_buffer('relative_position_index', relative_position_index)
+
+        # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+    def forward(self):
+        relative_position_bias = \
+            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                self.window_size[0] * self.window_size[1] + 1,
+                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 in_chans=3,
+                 num_classes=1000,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_layer=nn.LayerNorm,
+                 init_values=None,
+                 use_abs_pos_emb=True,
+                 use_rel_pos_bias=False,
+                 use_shared_rel_pos_bias=False,
+                 use_mean_pooling=True,
+                 init_scale=0.001,
+                 use_checkpoint=False):
+        super().__init__()
+        self.image_size = img_size
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        if use_abs_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            self.pos_embed = None
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        if use_shared_rel_pos_bias:
+            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+        else:
+            self.rel_pos_bias = None
+        self.use_checkpoint = use_checkpoint
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                init_values=init_values,
+                window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)
+        ])
+        # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+        # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+        # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        # trunc_normal_(self.mask_token, std=.02)
+        # if isinstance(self.head, nn.Linear):
+        #     trunc_normal_(self.head.weight, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+        # if isinstance(self.head, nn.Linear):
+        #     self.head.weight.data.mul_(init_scale)
+        #     self.head.bias.data.mul_(init_scale)
+
+    def fix_init_weight(self):
+
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, rel_pos_bias, use_reentrant=False)
+            else:
+                x = blk(x, rel_pos_bias)
+        return x
+
+
+        # x = self.norm(x)
+
+        # if self.fc_norm is not None:
+        #     t = x[:, 1:, :]
+        #     return self.fc_norm(t.mean(1))
+        # else:
+        #     return x[:, 0]
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        # x = self.head(x)
+        return x
+
+    def get_intermediate_layers(self, x):
+        x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        features = []
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        for blk in self.blocks:
+            x = blk(x, rel_pos_bias)
+            features.append(x)
+
+        return features
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches**0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print('Position interpolate from %dx%d to %dx%d' % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def convert_weights_to_fp16(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+
+        # if isinstance(l, (nn.MultiheadAttention, Attention)):
+        #     for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+        #         tensor = getattr(l, attr)
+        #         if tensor is not None:
+        #             tensor.data = tensor.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+
+
+def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision='fp16'):
+    model = VisionTransformer(
+        img_size=img_size,
+        patch_size=14,
+        use_mean_pooling=False,
+        embed_dim=1408,
+        depth=39,
+        num_heads=1408 // 88,
+        mlp_ratio=4.3637,
+        qkv_bias=True,
+        drop_path_rate=drop_path_rate,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        use_checkpoint=use_checkpoint,
+    )
+    url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth'
+    cached_file = download_cached_file(url, check_hash=False, progress=True)
+    state_dict = torch.load(cached_file, map_location='cpu')
+    interpolate_pos_embed(model, state_dict)
+
+    incompatible_keys = model.load_state_dict(state_dict, strict=False)
+    # print(incompatible_keys)
+
+    if precision == 'fp16':
+        # model.to("cuda")
+        convert_weights_to_fp16(model)
+    return model
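
For context, the line count of this hunk matches the new eva_vit.py entry (no. 129) in the file list, i.e. the EVA ViT encoder vendored under the new t2v_metrics package. The sketch below is not from the package documentation; it shows roughly how the `create_eva_vit_g` factory added here could be exercised, assuming the module path from entry 129, that torch and timm are installed, and a willingness to download the multi-gigabyte checkpoint hard-coded in the function above.

import torch

# Module path assumed from entry 129 of the file list; not an official evalscope API.
from evalscope.metrics.t2v_metrics.models.vqascore_models.lavis.models.eva_vit import create_eva_vit_g

# precision='fp32' keeps this sketch CPU-friendly; the default 'fp16' halves all weights.
model = create_eva_vit_g(img_size=224, drop_path_rate=0.0, precision='fp32')
model.eval()

with torch.no_grad():
    images = torch.randn(1, 3, 224, 224)  # one dummy 224x224 RGB image
    tokens = model(images)                # CLS token + 16x16 patch tokens
    print(tokens.shape)                   # expected: torch.Size([1, 257, 1408])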