evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (181)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
  24. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  25. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  26. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  27. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  28. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  29. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  30. evalscope/benchmarks/utils.py +7 -16
  31. evalscope/cli/start_app.py +1 -1
  32. evalscope/collections/evaluator.py +16 -4
  33. evalscope/config.py +7 -3
  34. evalscope/constants.py +11 -0
  35. evalscope/evaluator/evaluator.py +9 -3
  36. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  37. evalscope/metrics/__init__.py +49 -4
  38. evalscope/metrics/llm_judge.py +1 -1
  39. evalscope/metrics/named_metrics.py +13 -0
  40. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  41. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  42. evalscope/metrics/t2v_metrics/constants.py +12 -0
  43. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  44. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  53. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  60. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  61. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  62. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  139. evalscope/metrics/t2v_metrics/score.py +78 -0
  140. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  141. evalscope/models/__init__.py +50 -14
  142. evalscope/models/adapters/__init__.py +17 -0
  143. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  144. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  145. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  146. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  147. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  148. evalscope/models/adapters/t2i_adapter.py +76 -0
  149. evalscope/models/custom/__init__.py +2 -1
  150. evalscope/models/custom/dummy_model.py +11 -13
  151. evalscope/models/local_model.py +82 -33
  152. evalscope/models/model.py +2 -42
  153. evalscope/models/register.py +26 -0
  154. evalscope/perf/benchmark.py +4 -3
  155. evalscope/perf/main.py +4 -2
  156. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  157. evalscope/perf/utils/benchmark_util.py +2 -2
  158. evalscope/perf/utils/db_util.py +16 -8
  159. evalscope/report/__init__.py +1 -0
  160. evalscope/report/app.py +117 -67
  161. evalscope/report/app_arguments.py +11 -0
  162. evalscope/report/generator.py +1 -1
  163. evalscope/run.py +3 -3
  164. evalscope/third_party/thinkbench/eval.py +19 -7
  165. evalscope/utils/chat_service.py +2 -2
  166. evalscope/utils/import_utils.py +66 -0
  167. evalscope/utils/utils.py +12 -4
  168. evalscope/version.py +2 -2
  169. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
  170. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
  171. tests/aigc/__init__.py +1 -0
  172. tests/aigc/test_t2i.py +87 -0
  173. tests/cli/test_run.py +20 -7
  174. tests/perf/test_perf.py +6 -3
  175. evalscope/metrics/code_metric.py +0 -98
  176. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  177. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
  179. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
  180. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
  181. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0

evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py
@@ -0,0 +1,132 @@
+ import torch
+ from dataclasses import dataclass
+ from torch import einsum, nn
+ from transformers import CLIPConfig
+ from transformers import CLIPModel as HFCLIPModel
+ from typing import Any, Optional, Tuple, Union
+
+ from .base_model import BaseModelConfig
+ from .cross_modeling import Cross_model
+
+
+ class XCLIPModel(HFCLIPModel):
+
+     def __init__(self, config: CLIPConfig):
+         super().__init__(config)
+
+     def get_text_features(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> torch.FloatTensor:
+
+         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # pooled_output = text_outputs[1]
+         # text_features = self.text_projection(pooled_output)
+         last_hidden_state = text_outputs[0]
+         text_features = self.text_projection(last_hidden_state)
+
+         pooled_output = text_outputs[1]
+         text_features_EOS = self.text_projection(pooled_output)
+
+         # del last_hidden_state, text_outputs
+         # gc.collect()
+
+         return text_features, text_features_EOS
+
+     def get_image_features(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> torch.FloatTensor:
+
+         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # pooled_output = vision_outputs[1]  # pooled_output
+         # image_features = self.visual_projection(pooled_output)
+         last_hidden_state = vision_outputs[0]
+         image_features = self.visual_projection(last_hidden_state)
+
+         return image_features
+
+
+ @dataclass
+ class ClipModelConfig(BaseModelConfig):
+     _target_: str = 'trainer.models.clip_model.CLIPModel'
+     pretrained_model_name_or_path: str = 'openai/clip-vit-base-patch32'
+
+
+ class CLIPModel(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.model = XCLIPModel._from_config(config)
+         self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
+
+     def get_text_features(self, *args, **kwargs):
+         return self.model.get_text_features(*args, **kwargs)
+
+     def get_image_features(self, *args, **kwargs):
+         return self.model.get_image_features(*args, **kwargs)
+
+     def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
+         outputs = ()
+
+         text_f, text_EOS = self.model.get_text_features(text_inputs)  # B*77*1024
+         outputs += text_EOS,
+
+         image_f = self.model.get_image_features(image_inputs.half())  # 2B*257*1024
+         condition_f, _ = self.model.get_text_features(condition_inputs)  # B*5*1024
+
+         sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
+         sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
+         sim_text_condition = sim_text_condition / sim_text_condition.max()
+         mask = torch.where(sim_text_condition > 0.01, 0, float('-inf'))  # B*1*77
+
+         mask = mask.repeat(1, image_f.shape[1], 1)  # B*257*77
+         bc = int(image_f.shape[0] / 2)
+
+         sim0 = self.cross_model(image_f[:bc, :, :], text_f, mask.half())
+         sim1 = self.cross_model(image_f[bc:, :, :], text_f, mask.half())
+         outputs += sim0[:, 0, :],
+         outputs += sim1[:, 0, :],
+
+         return outputs
+
+     @property
+     def logit_scale(self):
+         return self.model.logit_scale
+
+     def save(self, path):
+         self.model.save_pretrained(path)

evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py
@@ -0,0 +1,286 @@
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+ from torch import einsum, nn
+
+ # helper functions
+
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ # normalization
+ # they use layernorm without bias, something that pytorch does not offer
+
+
+ class LayerNorm(nn.Module):
+
+     def __init__(self, dim):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.register_buffer('bias', torch.zeros(dim))
+
+     def forward(self, x):
+         return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)
+
+
+ # residual
+
+
+ class Residual(nn.Module):
+
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+
+     def forward(self, x, *args, **kwargs):
+         return self.fn(x, *args, **kwargs) + x
+
+
+ # rotary positional embedding
+ # https://arxiv.org/abs/2104.09864
+
+
+ class RotaryEmbedding(nn.Module):
+
+     def __init__(self, dim):
+         super().__init__()
+         inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer('inv_freq', inv_freq)
+
+     def forward(self, max_seq_len, *, device):
+         seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
+         freqs = einsum('i , j -> i j', seq, self.inv_freq)
+         return torch.cat((freqs, freqs), dim=-1)
+
+
+ def rotate_half(x):
+     x = rearrange(x, '... (j d) -> ... j d', j=2)
+     x1, x2 = x.unbind(dim=-2)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(pos, t):
+     return (t * pos.cos()) + (rotate_half(t) * pos.sin())
+
+
+ # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
+ # https://arxiv.org/abs/2002.05202
+
+
+ class SwiGLU(nn.Module):
+
+     def forward(self, x):
+         x, gate = x.chunk(2, dim=-1)
+         return F.silu(gate) * x
+
+
+ # parallel attention and feedforward with residual
+ # discovered by Wang et al + EleutherAI from GPT-J fame
+
+
+ class ParallelTransformerBlock(nn.Module):
+
+     def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
+         super().__init__()
+         self.norm = LayerNorm(dim)
+
+         attn_inner_dim = dim_head * heads
+         ff_inner_dim = dim * ff_mult
+         self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
+
+         self.heads = heads
+         self.scale = dim_head**-0.5
+         self.rotary_emb = RotaryEmbedding(dim_head)
+
+         self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
+         self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
+
+         self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False))
+
+         self.register_buffer('pos_emb', None, persistent=False)
+
+     def get_rotary_embedding(self, n, device):
+         if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
+             return self.pos_emb[:n]
+
+         pos_emb = self.rotary_emb(n, device=device)
+         self.register_buffer('pos_emb', pos_emb, persistent=False)
+         return pos_emb
+
+     def forward(self, x, attn_mask=None):
+         """
+         einstein notation
+         b - batch
+         h - heads
+         n, i, j - sequence length (base sequence length, source, target)
+         d - feature dimension
+         """
+
+         n, device, h = x.shape[1], x.device, self.heads
+
+         # pre layernorm
+
+         x = self.norm(x)
+
+         # attention queries, keys, values, and feedforward inner
+
+         q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
+
+         # split heads
+         # they use multi-query single-key-value attention, yet another Noam Shazeer paper
+         # they found no performance loss past a certain scale, and more efficient decoding obviously
+         # https://arxiv.org/abs/1911.02150
+
+         q = rearrange(q, 'b n (h d) -> b h n d', h=h)
+
+         # rotary embeddings
+
+         positions = self.get_rotary_embedding(n, device)
+         q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
+
+         # scale
+
+         q = q * self.scale
+
+         # similarity
+
+         sim = einsum('b h i d, b j d -> b h i j', q, k)
+
+         # extra attention mask - for masking out attention from text CLS token to padding
+
+         if exists(attn_mask):
+             attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
+             sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
+
+         # attention
+
+         sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+         attn = sim.softmax(dim=-1)
+
+         # aggregate values
+
+         out = einsum('b h i j, b j d -> b h i d', attn, v)
+
+         # merge heads
+
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.attn_out(out) + self.ff_out(ff)
+
+
+ # cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
+
+
+ class CrossAttention(nn.Module):
+
+     def __init__(self,
+                  dim,
+                  *,
+                  context_dim=None,
+                  dim_head=64,
+                  heads=12,
+                  parallel_ff=False,
+                  ff_mult=4,
+                  norm_context=False):
+         super().__init__()
+         self.heads = heads
+         self.scale = dim_head**-0.5
+         inner_dim = heads * dim_head
+         context_dim = default(context_dim, dim)
+
+         self.norm = LayerNorm(dim)
+         self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
+
+         self.to_q = nn.Linear(dim, inner_dim, bias=False)
+         self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
+         self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+         # whether to have parallel feedforward
+
+         ff_inner_dim = ff_mult * dim
+
+         self.ff = nn.Sequential(
+             nn.Linear(dim, ff_inner_dim
+                       * 2, bias=False), SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False)) if parallel_ff else None
+
+     def forward(self, x, context, mask):
+         """
+         einstein notation
+         b - batch
+         h - heads
+         n, i, j - sequence length (base sequence length, source, target)
+         d - feature dimension
+         """
+
+         # pre-layernorm, for queries and context
+
+         x = self.norm(x)
+         context = self.context_norm(context)
+
+         # get queries
+
+         q = self.to_q(x)
+         q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
+
+         # scale
+
+         q = q * self.scale
+
+         # get key / values
+
+         k, v = self.to_kv(context).chunk(2, dim=-1)
+
+         # query / key similarity
+
+         sim = einsum('b h i d, b j d -> b h i j', q, k)
+
+         # attention
+         mask = mask.unsqueeze(1).repeat(1, self.heads, 1, 1)
+         sim = sim + mask  # context mask
+         sim = sim - sim.amax(dim=-1, keepdim=True)
+         attn = sim.softmax(dim=-1)
+
+         # aggregate
+
+         out = einsum('b h i j, b j d -> b h i d', attn, v)
+
+         # merge and combine heads
+
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         out = self.to_out(out)
+
+         # add parallel feedforward (for multimodal layers)
+
+         if exists(self.ff):
+             out = out + self.ff(x)
+
+         return out
+
+
+ class Cross_model(nn.Module):
+
+     def __init__(self, dim=512, layer_num=4, dim_head=64, heads=8, ff_mult=4):
+         super().__init__()
+
+         self.layers = nn.ModuleList([])
+
+         for ind in range(layer_num):
+             self.layers.append(
+                 nn.ModuleList([
+                     Residual(
+                         CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
+                     Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
+                 ]))
+
+     def forward(self, query_tokens, context_tokens, mask):
+
+         for cross_attn, self_attn_ff in self.layers:
+             query_tokens = cross_attn(query_tokens, context_tokens, mask)
+             query_tokens = self_attn_ff(query_tokens)
+
+         return query_tokens
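
The following is a minimal, hypothetical sketch (not part of the package) showing how the `Cross_model` stack above can be exercised with random tensors; the shapes follow the comments in the MPS `CLIPModel.forward` (image tokens of size 257x1024, text tokens of size 77x1024), and the additive mask convention (0 keeps a position, -inf blocks it) matches the `sim = sim + mask` line in `CrossAttention`.

```python
# Hypothetical standalone check of the Cross_model cross-attention stack added in this diff.
# The import path mirrors the new file layout; torch and einops are required.
import torch

from evalscope.metrics.t2v_metrics.models.clipscore_models.build_mps_model.cross_modeling import Cross_model

model = Cross_model(dim=1024, layer_num=4, heads=16)

query = torch.randn(2, 257, 1024)   # e.g. image tokens (B x 257 x 1024)
context = torch.randn(2, 77, 1024)  # e.g. text tokens (B x 77 x 1024)
mask = torch.zeros(2, 257, 77)      # additive attention mask: 0 = attend, -inf = block

out = model(query, context, mask)
print(out.shape)  # torch.Size([2, 257, 1024])
```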

evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py
@@ -0,0 +1,114 @@
+ import os
+ import torch
+ from typing import List
+
+ from ...constants import CACHE_DIR
+ from ..model import ScoreModel
+
+ CLIP_MODELS = [
+     'openai:RN50', 'yfcc15m:RN50', 'cc12m:RN50', 'openai:RN101', 'yfcc15m:RN101', 'openai:RN50x4', 'openai:RN50x16',
+     'openai:RN50x64', 'openai:ViT-B-32', 'laion400m_e31:ViT-B-32', 'laion400m_e32:ViT-B-32', 'laion2b_e16:ViT-B-32',
+     'laion2b_s34b_b79k:ViT-B-32', 'datacomp_xl_s13b_b90k:ViT-B-32', 'datacomp_m_s128m_b4k:ViT-B-32',
+     'commonpool_m_clip_s128m_b4k:ViT-B-32', 'commonpool_m_laion_s128m_b4k:ViT-B-32',
+     'commonpool_m_image_s128m_b4k:ViT-B-32', 'commonpool_m_text_s128m_b4k:ViT-B-32',
+     'commonpool_m_basic_s128m_b4k:ViT-B-32', 'commonpool_m_s128m_b4k:ViT-B-32', 'datacomp_s_s13m_b4k:ViT-B-32',
+     'commonpool_s_clip_s13m_b4k:ViT-B-32', 'commonpool_s_laion_s13m_b4k:ViT-B-32',
+     'commonpool_s_image_s13m_b4k:ViT-B-32', 'commonpool_s_text_s13m_b4k:ViT-B-32',
+     'commonpool_s_basic_s13m_b4k:ViT-B-32', 'commonpool_s_s13m_b4k:ViT-B-32', 'metaclip_400m:ViT-B-32',
+     'metaclip_fullcc:ViT-B-32', 'datacomp_s34b_b86k:ViT-B-32-256', 'openai:ViT-B-16', 'laion400m_e31:ViT-B-16',
+     'laion400m_e32:ViT-B-16', 'laion2b_s34b_b88k:ViT-B-16', 'datacomp_xl_s13b_b90k:ViT-B-16',
+     'datacomp_l_s1b_b8k:ViT-B-16', 'commonpool_l_clip_s1b_b8k:ViT-B-16', 'commonpool_l_laion_s1b_b8k:ViT-B-16',
+     'commonpool_l_image_s1b_b8k:ViT-B-16', 'commonpool_l_text_s1b_b8k:ViT-B-16', 'commonpool_l_basic_s1b_b8k:ViT-B-16',
+     'commonpool_l_s1b_b8k:ViT-B-16', 'dfn2b:ViT-B-16', 'metaclip_400m:ViT-B-16', 'metaclip_fullcc:ViT-B-16',
+     'laion400m_e31:ViT-B-16-plus-240', 'laion400m_e32:ViT-B-16-plus-240', 'openai:ViT-L-14', 'laion400m_e31:ViT-L-14',
+     'laion400m_e32:ViT-L-14', 'laion2b_s32b_b82k:ViT-L-14', 'datacomp_xl_s13b_b90k:ViT-L-14',
+     'commonpool_xl_clip_s13b_b90k:ViT-L-14', 'commonpool_xl_laion_s13b_b90k:ViT-L-14',
+     'commonpool_xl_s13b_b90k:ViT-L-14', 'metaclip_400m:ViT-L-14', 'metaclip_fullcc:ViT-L-14', 'dfn2b:ViT-L-14',
+     'dfn2b_s39b:ViT-L-14', 'openai:ViT-L-14-336', 'laion2b_s32b_b79k:ViT-H-14', 'metaclip_fullcc:ViT-H-14',
+     'metaclip_altogether:ViT-H-14', 'dfn5b:ViT-H-14', 'dfn5b:ViT-H-14-378', 'laion2b_s12b_b42k:ViT-g-14',
+     'laion2b_s34b_b88k:ViT-g-14', 'laion2b_s39b_b160k:ViT-bigG-14', 'metaclip_fullcc:ViT-bigG-14',
+     'laion2b_s12b_b32k:roberta-ViT-B-32', 'laion5b_s13b_b90k:xlm-roberta-base-ViT-B-32',
+     'frozen_laion5b_s13b_b90k:xlm-roberta-large-ViT-H-14', 'laion400m_s13b_b51k:convnext_base',
+     'laion2b_s13b_b82k:convnext_base_w', 'laion2b_s13b_b82k_augreg:convnext_base_w',
+     'laion_aesthetic_s13b_b82k:convnext_base_w', 'laion_aesthetic_s13b_b82k:convnext_base_w_320',
+     'laion_aesthetic_s13b_b82k_augreg:convnext_base_w_320', 'laion2b_s26b_b102k_augreg:convnext_large_d',
+     'laion2b_s29b_b131k_ft:convnext_large_d_320', 'laion2b_s29b_b131k_ft_soup:convnext_large_d_320',
+     'laion2b_s34b_b82k_augreg:convnext_xxlarge', 'laion2b_s34b_b82k_augreg_rewind:convnext_xxlarge',
+     'laion2b_s34b_b82k_augreg_soup:convnext_xxlarge', 'laion2b_s13b_b90k:coca_ViT-B-32',
+     'mscoco_finetuned_laion2b_s13b_b90k:coca_ViT-B-32', 'laion2b_s13b_b90k:coca_ViT-L-14',
+     'mscoco_finetuned_laion2b_s13b_b90k:coca_ViT-L-14', 'laion400m_s11b_b41k:EVA01-g-14',
+     'merged2b_s11b_b114k:EVA01-g-14-plus', 'merged2b_s8b_b131k:EVA02-B-16', 'merged2b_s4b_b131k:EVA02-L-14',
+     'merged2b_s6b_b61k:EVA02-L-14-336', 'laion2b_s4b_b115k:EVA02-E-14', 'laion2b_s9b_b144k:EVA02-E-14-plus',
+     'webli:ViT-B-16-SigLIP', 'webli:ViT-B-16-SigLIP-256', 'webli:ViT-B-16-SigLIP-i18n-256', 'webli:ViT-B-16-SigLIP-384',
+     'webli:ViT-B-16-SigLIP-512', 'webli:ViT-L-16-SigLIP-256', 'webli:ViT-L-16-SigLIP-384', 'webli:ViT-SO400M-14-SigLIP',
+     'webli:ViT-SO400M-16-SigLIP-i18n-256', 'webli:ViT-SO400M-14-SigLIP-378', 'webli:ViT-SO400M-14-SigLIP-384',
+     'webli:ViT-B-32-SigLIP2-256', 'webli:ViT-B-16-SigLIP2', 'webli:ViT-B-16-SigLIP2-256', 'webli:ViT-B-16-SigLIP2-384',
+     'webli:ViT-B-16-SigLIP2-512', 'webli:ViT-L-16-SigLIP2-256', 'webli:ViT-L-16-SigLIP2-384',
+     'webli:ViT-L-16-SigLIP2-512', 'webli:ViT-SO400M-14-SigLIP2', 'webli:ViT-SO400M-14-SigLIP2-378',
+     'webli:ViT-SO400M-16-SigLIP2-256', 'webli:ViT-SO400M-16-SigLIP2-384', 'webli:ViT-SO400M-16-SigLIP2-512',
+     'webli:ViT-gopt-16-SigLIP2-256', 'webli:ViT-gopt-16-SigLIP2-384', 'datacomp1b:ViT-L-14-CLIPA',
+     'datacomp1b:ViT-L-14-CLIPA-336', 'datacomp1b:ViT-H-14-CLIPA', 'laion2b:ViT-H-14-CLIPA-336',
+     'datacomp1b:ViT-H-14-CLIPA-336', 'datacomp1b:ViT-bigG-14-CLIPA', 'datacomp1b:ViT-bigG-14-CLIPA-336',
+     'v1:nllb-clip-base', 'v1:nllb-clip-large', 'v1:nllb-clip-base-siglip', 'mrl:nllb-clip-base-siglip',
+     'v1:nllb-clip-large-siglip', 'mrl:nllb-clip-large-siglip', 'datacompdr:MobileCLIP-S1', 'datacompdr:MobileCLIP-S2',
+     'datacompdr:MobileCLIP-B', 'datacompdr_lt:MobileCLIP-B', 'datacomp1b:ViTamin-S', 'datacomp1b:ViTamin-S-LTT',
+     'datacomp1b:ViTamin-B', 'datacomp1b:ViTamin-B-LTT', 'datacomp1b:ViTamin-L', 'datacomp1b:ViTamin-L-256',
+     'datacomp1b:ViTamin-L-336', 'datacomp1b:ViTamin-L-384', 'datacomp1b:ViTamin-L2', 'datacomp1b:ViTamin-L2-256',
+     'datacomp1b:ViTamin-L2-336', 'datacomp1b:ViTamin-L2-384', 'datacomp1b:ViTamin-XL-256', 'datacomp1b:ViTamin-XL-336',
+     'datacomp1b:ViTamin-XL-384', 'openai:RN50-quickgelu', 'yfcc15m:RN50-quickgelu', 'cc12m:RN50-quickgelu',
+     'openai:RN101-quickgelu', 'yfcc15m:RN101-quickgelu', 'openai:RN50x4-quickgelu', 'openai:RN50x16-quickgelu',
+     'openai:RN50x64-quickgelu', 'openai:ViT-B-32-quickgelu', 'laion400m_e31:ViT-B-32-quickgelu',
+     'laion400m_e32:ViT-B-32-quickgelu', 'metaclip_400m:ViT-B-32-quickgelu', 'metaclip_fullcc:ViT-B-32-quickgelu',
+     'openai:ViT-B-16-quickgelu', 'dfn2b:ViT-B-16-quickgelu', 'metaclip_400m:ViT-B-16-quickgelu',
+     'metaclip_fullcc:ViT-B-16-quickgelu', 'openai:ViT-L-14-quickgelu', 'metaclip_400m:ViT-L-14-quickgelu',
+     'metaclip_fullcc:ViT-L-14-quickgelu', 'dfn2b:ViT-L-14-quickgelu', 'openai:ViT-L-14-336-quickgelu',
+     'metaclip_fullcc:ViT-H-14-quickgelu', 'dfn5b:ViT-H-14-quickgelu', 'dfn5b:ViT-H-14-378-quickgelu',
+     'metaclip_fullcc:ViT-bigG-14-quickgelu'
+ ]  # noqa: E501
+
+
+ class CLIPScoreModel(ScoreModel):
+     "A wrapper for OpenCLIP models (including openAI's CLIP, OpenCLIP, DatacompCLIP)"
+
+     def __init__(self, model_name='openai:ViT-L-14', device='cuda', cache_dir=CACHE_DIR):
+         assert model_name in CLIP_MODELS
+         super().__init__(model_name=model_name, device=device, cache_dir=cache_dir)
+
+     def load_model(self):
+         """Load the model, tokenizer, image transform
+         """
+         import open_clip
+
+         from ..utils import download_open_clip_model
+
+         self.pretrained, self.arch = self.model_name.split(':')
+         # load model from modelscope
+         model_file_path = download_open_clip_model(self.arch, self.pretrained, self.cache_dir)
+
+         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+             self.arch, pretrained=model_file_path, device=self.device)
+         self.tokenizer = open_clip.get_tokenizer(self.arch)
+         self.model.eval()
+
+     def load_images(self, image: List[str]) -> torch.Tensor:
+         """Load the image(s), and return a tensor (after preprocessing) put on self.device
+         """
+         image = [self.image_loader(x) for x in image]
+         image = [self.preprocess(x) for x in image]
+         image = torch.stack(image, dim=0).to(self.device)
+         return image
+
+     @torch.no_grad()
+     def forward(self, images: List[str], texts: List[str]) -> torch.Tensor:
+         """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
+         """
+         assert len(images) == len(texts)
+         image = self.load_images(images)
+         text = self.tokenizer(texts).to(self.device)
+         image_features = self.model.encode_image(image)
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         text_features = self.model.encode_text(text)
+         text_features /= text_features.norm(dim=-1, keepdim=True)
+
+         # return cosine similarity as scores
+         return (image_features * text_features).sum(dim=-1)
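
A hedged usage sketch for the new `CLIPScoreModel` follows. The class name and its `forward(images, texts)` signature come from the diff above; the image path is a placeholder, a CUDA device is assumed, and `load_model` is expected to fetch the OpenCLIP weights via the package's ModelScope download helper.

```python
# Hypothetical usage of CLIPScoreModel; 'cat.png' is a placeholder image path.
from evalscope.metrics.t2v_metrics.models.clipscore_models.clip_model import CLIPScoreModel

model = CLIPScoreModel(model_name='openai:ViT-L-14', device='cuda')
# Calling the instance dispatches to forward() (ScoreModel is assumed to be an nn.Module),
# returning one cosine-similarity score per (image, text) pair.
scores = model(images=['cat.png'], texts=['a photo of a cat'])
print(scores)
```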

evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py
@@ -0,0 +1,86 @@
+ import torch
+ from typing import List
+
+ from ...constants import CACHE_DIR
+ from ..model import ScoreModel
+
+ HPSV2_MODELS = ['hpsv2', 'hpsv2.1']
+ HPS_VERSION_MAP = {
+     'hpsv2': 'HPS_v2_compressed.pt',
+     'hpsv2.1': 'HPS_v2.1_compressed.pt',
+ }
+
+
+ class HPSV2ScoreModel(ScoreModel):
+     'A wrapper for HPSv2 models '
+
+     def __init__(self, model_name='openai:ViT-L-14', device='cuda', cache_dir=CACHE_DIR):
+         assert model_name in HPSV2_MODELS
+         super().__init__(model_name=model_name, device=device, cache_dir=cache_dir)
+
+     def load_model(self):
+         """Load the model, tokenizer, image transform
+         """
+         import open_clip
+
+         from ..utils import download_file, download_open_clip_model
+
+         self.pretrained, self.arch = 'laion2B-s32B-b79K:ViT-H-14'.split(':')
+         # load model from modelscope
+         model_file_path = download_open_clip_model(self.arch, self.pretrained, self.cache_dir)
+
+         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+             self.arch,
+             pretrained=model_file_path,
+             precision='amp',
+             device=self.device,
+             jit=False,
+             force_quick_gelu=False,
+             force_custom_text=False,
+             force_patch_dropout=False,
+             force_image_size=None,
+             pretrained_image=False,
+             image_mean=None,
+             image_std=None,
+             image_resize_mode='longest',
+             aug_cfg={},
+             output_dict=True)
+
+         # update weight
+         model_weight_path = download_file('AI-ModelScope/HPSv2', HPS_VERSION_MAP[self.model_name], self.cache_dir)
+         checkpoint = torch.load(model_weight_path, map_location=self.device)
+         self.model.load_state_dict(checkpoint['state_dict'])
+         self.tokenizer = open_clip.get_tokenizer(self.arch)
+         self.model.eval()
+
+     def load_images(self, image: List[str]):
+         """Load the image(s), and return a tensor (after preprocessing) put on self.device
+         """
+         images = [self.image_loader(x) for x in image]
+         return images
+
+     @torch.no_grad()
+     def forward(self, images: List[str], texts: List[str]) -> torch.Tensor:
+         """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
+         """
+         assert len(images) == len(texts)
+         images = self.load_images(images)
+         scores = torch.zeros(len(images), dtype=torch.float16).to(self.device)
+         for i in range(len(images)):
+             caption = texts[i]
+             image = images[i]
+             # Process the image
+             image = self.preprocess(image).unsqueeze(0).to(device=self.device, non_blocking=True)
+             # Process the prompt
+             text = self.tokenizer([caption]).to(device=self.device, non_blocking=True)  # Updated to use texts[i]
+             # Calculate the HPS
+             with torch.amp.autocast(device_type=self.device):
+                 outputs = self.model(image, text)
+                 image_features, text_features = outputs['image_features'], outputs['text_features']
+                 logits_per_image = image_features @ text_features.T
+
+                 hps_score = torch.diagonal(logits_per_image).cpu().numpy()
+             scores[i] = float(hps_score[0])
+
+         # return cosine similarity as scores
+         return scores
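
A similar hedged sketch for the new `HPSV2ScoreModel`. Note that `model_name` must be one of `HPSV2_MODELS` (`'hpsv2'` or `'hpsv2.1'`), so the `'openai:ViT-L-14'` default shown in `__init__` would fail the assertion and a value has to be passed explicitly; the image path below is a placeholder and a CUDA device is assumed.

```python
# Hypothetical usage of HPSV2ScoreModel; weights are fetched from ModelScope inside load_model.
from evalscope.metrics.t2v_metrics.models.clipscore_models.hpsv2_model import HPSV2ScoreModel

model = HPSV2ScoreModel(model_name='hpsv2.1', device='cuda')
scores = model(images=['generated_0.png'], texts=['a watercolor painting of a fox'])
print(scores)  # float16 tensor with one human-preference score per (image, text) pair
```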