evalscope 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of evalscope might be problematic.

Files changed (214)
  1. evalscope/arguments.py +2 -1
  2. evalscope/backend/rag_eval/__init__.py +1 -1
  3. evalscope/backend/rag_eval/backend_manager.py +21 -5
  4. evalscope/backend/rag_eval/cmteb/arguments.py +10 -0
  5. evalscope/backend/rag_eval/ragas/arguments.py +0 -1
  6. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +7 -2
  7. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +0 -5
  8. evalscope/backend/rag_eval/utils/embedding.py +49 -3
  9. evalscope/backend/rag_eval/utils/llm.py +4 -4
  10. evalscope/backend/vlm_eval_kit/backend_manager.py +4 -2
  11. evalscope/benchmarks/__init__.py +2 -2
  12. evalscope/benchmarks/aigc/__init__.py +0 -0
  13. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  14. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  15. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  16. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  17. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  18. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  19. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  20. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  21. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  22. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  23. evalscope/benchmarks/arc/arc_adapter.py +2 -2
  24. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  25. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  26. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  27. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  28. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  29. evalscope/benchmarks/data_adapter.py +21 -10
  30. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  31. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  32. evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
  33. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +1 -1
  34. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  35. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +5 -4
  36. evalscope/benchmarks/live_code_bench/testing_util.py +369 -550
  37. evalscope/benchmarks/maritime_bench/__init__.py +0 -0
  38. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +79 -0
  39. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  40. evalscope/benchmarks/mmlu/mmlu_adapter.py +8 -8
  41. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +1 -1
  42. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +1 -1
  43. evalscope/benchmarks/musr/musr_adapter.py +1 -1
  44. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  45. evalscope/benchmarks/utils.py +7 -16
  46. evalscope/cli/start_app.py +1 -1
  47. evalscope/collections/evaluator.py +20 -6
  48. evalscope/config.py +8 -4
  49. evalscope/constants.py +11 -0
  50. evalscope/evaluator/evaluator.py +2 -2
  51. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  52. evalscope/metrics/__init__.py +49 -4
  53. evalscope/metrics/llm_judge.py +1 -1
  54. evalscope/metrics/named_metrics.py +13 -0
  55. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  56. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  57. evalscope/metrics/t2v_metrics/constants.py +12 -0
  58. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  59. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  60. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  61. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  62. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  63. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  64. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  65. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  66. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  67. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  68. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  69. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  70. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  71. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  72. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  73. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  74. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  75. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  76. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  77. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  139. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  140. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  141. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  142. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  143. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  144. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  145. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  146. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  147. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  148. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  149. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  150. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  151. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  152. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  153. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  154. evalscope/metrics/t2v_metrics/score.py +78 -0
  155. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  156. evalscope/models/__init__.py +50 -14
  157. evalscope/models/adapters/__init__.py +17 -0
  158. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  159. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  160. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  161. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  162. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  163. evalscope/models/adapters/t2i_adapter.py +76 -0
  164. evalscope/models/custom/__init__.py +2 -1
  165. evalscope/models/custom/dummy_model.py +11 -13
  166. evalscope/models/local_model.py +82 -33
  167. evalscope/models/model.py +2 -42
  168. evalscope/models/register.py +26 -0
  169. evalscope/perf/arguments.py +24 -5
  170. evalscope/perf/benchmark.py +28 -42
  171. evalscope/perf/http_client.py +2 -3
  172. evalscope/perf/plugin/api/custom_api.py +1 -1
  173. evalscope/perf/plugin/api/openai_api.py +2 -2
  174. evalscope/perf/plugin/datasets/custom.py +4 -1
  175. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  176. evalscope/perf/plugin/datasets/line_by_line.py +4 -1
  177. evalscope/perf/plugin/datasets/longalpaca.py +4 -1
  178. evalscope/perf/plugin/datasets/openqa.py +4 -1
  179. evalscope/perf/plugin/datasets/random_dataset.py +13 -6
  180. evalscope/perf/utils/benchmark_util.py +14 -8
  181. evalscope/perf/utils/db_util.py +9 -3
  182. evalscope/perf/utils/log_utils.py +41 -0
  183. evalscope/report/__init__.py +1 -0
  184. evalscope/report/app.py +128 -78
  185. evalscope/report/app_arguments.py +11 -0
  186. evalscope/report/generator.py +1 -1
  187. evalscope/run.py +10 -3
  188. evalscope/summarizer.py +2 -1
  189. evalscope/third_party/thinkbench/eval.py +19 -7
  190. evalscope/utils/chat_service.py +2 -2
  191. evalscope/utils/import_utils.py +66 -0
  192. evalscope/utils/utils.py +48 -29
  193. evalscope/version.py +2 -2
  194. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/METADATA +37 -15
  195. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/RECORD +209 -96
  196. tests/aigc/__init__.py +1 -0
  197. tests/aigc/test_t2i.py +87 -0
  198. tests/cli/test_all.py +4 -4
  199. tests/cli/test_collection.py +2 -1
  200. tests/cli/test_run.py +19 -12
  201. tests/perf/test_perf.py +3 -3
  202. tests/rag/test_clip_benchmark.py +0 -1
  203. tests/rag/test_mteb.py +37 -8
  204. tests/rag/test_ragas.py +29 -26
  205. tests/vlm/test_vlmeval.py +37 -1
  206. evalscope/backend/vlm_eval_kit/custom_dataset.py +0 -46
  207. evalscope/benchmarks/live_code_bench/execute_utils.py +0 -267
  208. evalscope/metrics/code_metric.py +0 -98
  209. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  210. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  211. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/LICENSE +0 -0
  212. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/WHEEL +0 -0
  213. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/entry_points.txt +0 -0
  214. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py (new file)
@@ -0,0 +1,271 @@
+ import collections.abc
+ import math
+ import torch
+ import torch.nn.functional as F
+ from collections import OrderedDict
+ from itertools import repeat
+ from torch import nn
+
+ from ..common.dist_utils import download_cached_file
+ from ..models.eva_vit import convert_weights_to_fp16
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1):
+         super().__init__()
+
+         # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+         self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.relu1 = nn.ReLU(inplace=True)
+
+         self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.relu2 = nn.ReLU(inplace=True)
+
+         self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+         self.relu3 = nn.ReLU(inplace=True)
+
+         self.downsample = None
+         self.stride = stride
+
+         if stride > 1 or inplanes != planes * Bottleneck.expansion:
+             # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+             self.downsample = nn.Sequential(
+                 OrderedDict([('-1', nn.AvgPool2d(stride)),
+                              ('0', nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                              ('1', nn.BatchNorm2d(planes * self.expansion))]))
+
+     def forward(self, x: torch.Tensor):
+         identity = x
+
+         out = self.relu1(self.bn1(self.conv1(x)))
+         out = self.relu2(self.bn2(self.conv2(out)))
+         out = self.avgpool(out)
+         out = self.bn3(self.conv3(out))
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu3(out)
+         return out
+
+
+ class AttentionPool2d(nn.Module):
+
+     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+         super().__init__()
+         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+         self.num_heads = num_heads
+
+     def forward(self, x):
+         x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+         x, _ = F.multi_head_attention_forward(
+             query=x,
+             key=x,
+             value=x,
+             embed_dim_to_check=x.shape[-1],
+             num_heads=self.num_heads,
+             q_proj_weight=self.q_proj.weight,
+             k_proj_weight=self.k_proj.weight,
+             v_proj_weight=self.v_proj.weight,
+             in_proj_weight=None,
+             in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+             bias_k=None,
+             bias_v=None,
+             add_zero_attn=False,
+             dropout_p=0,
+             out_proj_weight=self.c_proj.weight,
+             out_proj_bias=self.c_proj.bias,
+             use_separate_proj_weight=True,
+             training=self.training,
+             need_weights=False)
+
+         return x[0]
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm to handle fp16."""
+
+     def forward(self, x: torch.Tensor):
+         orig_type = x.dtype
+         ret = super().forward(x.type(torch.float32))
+         return ret.type(orig_type)
+
+
+ class QuickGELU(nn.Module):
+
+     def forward(self, x: torch.Tensor):
+         return x * torch.sigmoid(1.702 * x)
+
+
+ class ResidualAttentionBlock(nn.Module):
+
+     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
+         super().__init__()
+
+         self.attn = nn.MultiheadAttention(d_model, n_head)
+         self.ln_1 = LayerNorm(d_model)
+         self.mlp = nn.Sequential(
+             OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), ('gelu', QuickGELU()),
+                          ('c_proj', nn.Linear(d_model * 4, d_model))]))
+         self.ln_2 = LayerNorm(d_model)
+         self.attn_mask = attn_mask
+
+         if use_grad_checkpointing:
+             from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+             self.attn = checkpoint_wrapper(self.attn)
+             self.mlp = checkpoint_wrapper(self.mlp)
+
+     def attention(self, x: torch.Tensor):
+         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+         return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+     def forward(self, x: torch.Tensor):
+         x = x + self.attention(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+
+     def __init__(self,
+                  width: int,
+                  layers: int,
+                  heads: int,
+                  attn_mask: torch.Tensor = None,
+                  use_grad_checkpointing=False):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.resblocks = nn.Sequential(
+             *[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)])
+
+     def forward(self, x: torch.Tensor):
+         return self.resblocks(x)
+
+
+ class VisionTransformer(nn.Module):
+
+     def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int,
+                  use_grad_checkpointing: bool):
+         super().__init__()
+         self.input_resolution = input_resolution
+         self.num_features = width
+         self.num_heads = heads
+         self.num_patches = (input_resolution // patch_size)**2
+         self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+         scale = width**-0.5
+         self.class_embedding = nn.Parameter(scale * torch.randn(width))
+         self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
+         self.ln_pre = LayerNorm(width)
+
+         self.transformer = Transformer(width, layers - 1, heads, use_grad_checkpointing=use_grad_checkpointing)
+
+         # self.ln_final = LayerNorm(width)
+
+     def forward(self, x: torch.Tensor):
+
+         x = self.conv1(x)  # shape = [*, width, grid, grid]
+         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+         x = torch.cat([
+             self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+             x
+         ], dim=1)  # shape = [*, grid ** 2 + 1, width]
+         x = x + self.positional_embedding.to(x.dtype)
+         x = self.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.transformer(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         # x = self.ln_final(x)
+         return x
+
+
+ # From PyTorch internals
+ def _ntuple(n):
+
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable):
+             return x
+         return tuple(repeat(x, n))
+
+     return parse
+
+
+ to_2tuple = _ntuple(2)
+
+
+ def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
+     # Rescale the grid of position embeddings when loading from state_dict
+     old_pos_embed = state_dict.get('positional_embedding', None)
+
+     grid_size = round((model.positional_embedding.shape[0] - 1)**0.5)
+     if old_pos_embed is None:
+         return
+     grid_size = to_2tuple(grid_size)
+     extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
+     new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+     if new_seq_len == old_pos_embed.shape[0]:
+         return
+
+     if extra_tokens:
+         pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+     else:
+         pos_emb_tok, pos_emb_img = None, old_pos_embed
+
+     old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+     print('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
+     pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+     pos_emb_img = F.interpolate(
+         pos_emb_img,
+         size=grid_size,
+         mode=interpolation,
+         align_corners=True,
+     )
+     pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+     if pos_emb_tok is not None:
+         new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+     else:
+         new_pos_embed = pos_emb_img
+     state_dict['positional_embedding'] = new_pos_embed
+
+
+ def create_clip_vit_L(img_size=224, use_checkpoint=False, precision='fp16'):
+     model = VisionTransformer(
+         input_resolution=img_size,
+         patch_size=14,
+         width=1024,
+         layers=22,
+         heads=16,
+         use_grad_checkpointing=use_checkpoint,
+     )
+     url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth'
+     cached_file = download_cached_file(url, check_hash=False, progress=True)
+     state_dict = torch.load(cached_file, map_location='cpu')
+     interpolate_pos_embed(model, state_dict)
+
+     incompatible_keys = model.load_state_dict(state_dict, strict=False)
+     # print(incompatible_keys)
+
+     if precision == 'fp16':
+         convert_weights_to_fp16(model)
+     return model
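
For context, here is a minimal usage sketch of the helper this new file exposes. It is an illustration only, assuming evalscope 0.15.0 is installed and the module path matches the file list above; calling create_clip_vit_L downloads the BLIP-2 CLIP ViT-L checkpoint from the URL shown in the diff, so it needs network access.

    # Minimal sketch (assumption: module path follows the packaged file layout above).
    # create_clip_vit_L builds the ViT-L/14 backbone and loads the LAVIS BLIP-2 weights.
    import torch
    from evalscope.metrics.t2v_metrics.models.vqascore_models.lavis.models.clip_vit import create_clip_vit_L

    vit = create_clip_vit_L(img_size=224, use_checkpoint=False, precision='fp32')  # fp32 so it also runs on CPU
    vit.eval()

    with torch.no_grad():
        dummy = torch.zeros(1, 3, 224, 224)  # one 224x224 RGB image
        feats = vit(dummy)                   # [1, num_patches + 1, 1024] class + patch tokens
    print(feats.shape)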