evalscope 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (214)
  1. evalscope/arguments.py +2 -1
  2. evalscope/backend/rag_eval/__init__.py +1 -1
  3. evalscope/backend/rag_eval/backend_manager.py +21 -5
  4. evalscope/backend/rag_eval/cmteb/arguments.py +10 -0
  5. evalscope/backend/rag_eval/ragas/arguments.py +0 -1
  6. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +7 -2
  7. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +0 -5
  8. evalscope/backend/rag_eval/utils/embedding.py +49 -3
  9. evalscope/backend/rag_eval/utils/llm.py +4 -4
  10. evalscope/backend/vlm_eval_kit/backend_manager.py +4 -2
  11. evalscope/benchmarks/__init__.py +2 -2
  12. evalscope/benchmarks/aigc/__init__.py +0 -0
  13. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  14. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  15. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  16. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  17. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  18. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  19. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  20. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  21. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  22. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  23. evalscope/benchmarks/arc/arc_adapter.py +2 -2
  24. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  25. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  26. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  27. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  28. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  29. evalscope/benchmarks/data_adapter.py +21 -10
  30. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  31. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  32. evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
  33. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +1 -1
  34. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  35. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +5 -4
  36. evalscope/benchmarks/live_code_bench/testing_util.py +369 -550
  37. evalscope/benchmarks/maritime_bench/__init__.py +0 -0
  38. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +79 -0
  39. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  40. evalscope/benchmarks/mmlu/mmlu_adapter.py +8 -8
  41. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +1 -1
  42. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +1 -1
  43. evalscope/benchmarks/musr/musr_adapter.py +1 -1
  44. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  45. evalscope/benchmarks/utils.py +7 -16
  46. evalscope/cli/start_app.py +1 -1
  47. evalscope/collections/evaluator.py +20 -6
  48. evalscope/config.py +8 -4
  49. evalscope/constants.py +11 -0
  50. evalscope/evaluator/evaluator.py +2 -2
  51. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  52. evalscope/metrics/__init__.py +49 -4
  53. evalscope/metrics/llm_judge.py +1 -1
  54. evalscope/metrics/named_metrics.py +13 -0
  55. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  56. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  57. evalscope/metrics/t2v_metrics/constants.py +12 -0
  58. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  59. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  60. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  61. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  62. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  63. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  64. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  65. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  66. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  67. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  68. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  69. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  70. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  71. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  72. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  73. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  74. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  75. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  76. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  77. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  139. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  140. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  141. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  142. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  143. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  144. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  145. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  146. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  147. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  148. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  149. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  150. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  151. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  152. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  153. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  154. evalscope/metrics/t2v_metrics/score.py +78 -0
  155. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  156. evalscope/models/__init__.py +50 -14
  157. evalscope/models/adapters/__init__.py +17 -0
  158. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  159. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  160. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  161. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  162. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  163. evalscope/models/adapters/t2i_adapter.py +76 -0
  164. evalscope/models/custom/__init__.py +2 -1
  165. evalscope/models/custom/dummy_model.py +11 -13
  166. evalscope/models/local_model.py +82 -33
  167. evalscope/models/model.py +2 -42
  168. evalscope/models/register.py +26 -0
  169. evalscope/perf/arguments.py +24 -5
  170. evalscope/perf/benchmark.py +28 -42
  171. evalscope/perf/http_client.py +2 -3
  172. evalscope/perf/plugin/api/custom_api.py +1 -1
  173. evalscope/perf/plugin/api/openai_api.py +2 -2
  174. evalscope/perf/plugin/datasets/custom.py +4 -1
  175. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  176. evalscope/perf/plugin/datasets/line_by_line.py +4 -1
  177. evalscope/perf/plugin/datasets/longalpaca.py +4 -1
  178. evalscope/perf/plugin/datasets/openqa.py +4 -1
  179. evalscope/perf/plugin/datasets/random_dataset.py +13 -6
  180. evalscope/perf/utils/benchmark_util.py +14 -8
  181. evalscope/perf/utils/db_util.py +9 -3
  182. evalscope/perf/utils/log_utils.py +41 -0
  183. evalscope/report/__init__.py +1 -0
  184. evalscope/report/app.py +128 -78
  185. evalscope/report/app_arguments.py +11 -0
  186. evalscope/report/generator.py +1 -1
  187. evalscope/run.py +10 -3
  188. evalscope/summarizer.py +2 -1
  189. evalscope/third_party/thinkbench/eval.py +19 -7
  190. evalscope/utils/chat_service.py +2 -2
  191. evalscope/utils/import_utils.py +66 -0
  192. evalscope/utils/utils.py +48 -29
  193. evalscope/version.py +2 -2
  194. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/METADATA +37 -15
  195. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/RECORD +209 -96
  196. tests/aigc/__init__.py +1 -0
  197. tests/aigc/test_t2i.py +87 -0
  198. tests/cli/test_all.py +4 -4
  199. tests/cli/test_collection.py +2 -1
  200. tests/cli/test_run.py +19 -12
  201. tests/perf/test_perf.py +3 -3
  202. tests/rag/test_clip_benchmark.py +0 -1
  203. tests/rag/test_mteb.py +37 -8
  204. tests/rag/test_ragas.py +29 -26
  205. tests/vlm/test_vlmeval.py +37 -1
  206. evalscope/backend/vlm_eval_kit/custom_dataset.py +0 -46
  207. evalscope/benchmarks/live_code_bench/execute_utils.py +0 -267
  208. evalscope/metrics/code_metric.py +0 -98
  209. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  210. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  211. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/LICENSE +0 -0
  212. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/WHEEL +0 -0
  213. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/entry_points.txt +0 -0
  214. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/top_level.txt +0 -0
File without changes
evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py ADDED
@@ -0,0 +1,79 @@
+from typing import Any
+
+from evalscope.benchmarks import Benchmark, DataAdapter
+from evalscope.constants import EvalType, OutputType
+from evalscope.metrics import exact_match
+from evalscope.utils.utils import ResponseParser
+
+SUBSET_LIST = ['default']
+
+
+@Benchmark.register(
+    name='maritime_bench',
+    pretty_name='MaritimeBench',
+    dataset_id='HiDolphin/MaritimeBench',
+    model_adapter=OutputType.GENERATION,
+    output_types=[OutputType.MULTIPLE_CHOICE, OutputType.GENERATION],
+    subset_list=SUBSET_LIST,
+    metric_list=['AverageAccuracy'],
+    eval_split='test',
+    prompt_template=
+    '题目来自于{subset_name}请回答单选题。要求只输出选项,不输出解释,将选项放在<>里,直接输出答案。示例:\n\n题目:在船舶主推进动力装置中,传动轴系在运转中承受以下复杂的应力和负荷,但不包括______。\n选项:\nA. 电磁力\nB. 压拉应力\nC. 弯曲应力\nD. 扭应力\n答:<A> 当前题目\n {query}',  # noqa: E501
+)
+class MaritimeBenchAdapter(DataAdapter):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.choices = ['A', 'B', 'C', 'D']
+
+    def gen_prompt(self, input_d: dict, subset_name: str, few_shot_list: list, **kwargs) -> Any:
+
+        prefix = ''
+        query = prefix + input_d['question'] + '\n'
+        available_choices = []
+        for option in self.choices:
+            if option in input_d and input_d[option]:
+                query += option + ':' + input_d[option] + '\n'
+                available_choices.append(option)
+
+        full_prompt = self.prompt_template.format(subset_name=subset_name, query=query)
+        return self.gen_prompt_data(full_prompt, choices=available_choices)
+
+    def get_gold_answer(self, input_d: dict) -> str:
+        """
+        Parse the raw input labels (gold).
+
+        Args:
+            input_d: input raw data. Depending on the dataset.
+
+        Returns:
+            The parsed input. e.g. gold answer ... Depending on the dataset.
+        """
+        return input_d['answer']
+
+    def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> str:
+        """
+        Parse the raw model prediction (pred).
+
+        Args:
+            pred: model prediction. Depending on the model.
+
+        Returns:
+            The parsed prediction. e.g. model answer... Depending on the model.
+        """
+
+        return ResponseParser.parse_bracketed_answer(result, options=self.choices)
+
+    def match(self, gold: Any, pred: Any) -> Any:
+        """
+        Match the gold answer with the predicted answer.
+
+        Args:
+            gold: The gold answer.
+            pred: The predicted answer.
+
+        Returns:
+            The result of the match.
+        """
+        return exact_match(gold=gold, pred=pred)
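A minimal usage sketch for the new benchmark (not part of the diff): the benchmark name 'maritime_bench' and the AverageAccuracy metric come from the adapter above, while the TaskConfig fields, the model id and the served endpoint are assumptions based on evalscope's usual service-evaluation entry points.

from evalscope import TaskConfig, run_task

task_cfg = TaskConfig(
    model='qwen2.5-7b-instruct',          # any model id served behind an OpenAI-compatible API (assumed)
    api_url='http://127.0.0.1:8000/v1',   # hypothetical endpoint
    eval_type='service',
    datasets=['maritime_bench'],          # registered name from the adapter above
    limit=10,                             # small smoke test
)
run_task(task_cfg=task_cfg)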
evalscope/benchmarks/math_500/math_500_adapter.py CHANGED
@@ -1,5 +1,5 @@
 from evalscope.benchmarks import Benchmark, DataAdapter
-from evalscope.metrics.math_parser import extract_answer, math_equal, strip_answer_string
+from evalscope.metrics import extract_answer, math_equal, strip_answer_string
 from evalscope.utils.logger import get_logger
 
 # flake8: noqa
evalscope/benchmarks/mmlu/mmlu_adapter.py CHANGED
@@ -137,7 +137,7 @@ SUBJECT_MAPPING = {
     name='mmlu',
     pretty_name='MMLU',
     dataset_id='modelscope/mmlu',
-    model_adapter=OutputType.MULTIPLE_CHOICE,
+    model_adapter=OutputType.GENERATION,
     output_types=[OutputType.MULTIPLE_CHOICE, OutputType.GENERATION],
     subset_list=SUBSET_LIST,
     metric_list=['AverageAccuracy'],
@@ -145,7 +145,7 @@ SUBJECT_MAPPING = {
     train_split='train',
     eval_split='test',
     prompt_template=
-    'Answer the following multiple choice question about {subset_name}. There is only one correct answer. The last line of your response should be in the format "Answer: LETTER" (without quotes), where LETTER is one of A, B, C, D. \n{query}',
+    """Answer the following multiple choice question about {subset_name}. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{query}""",  # noqa: E501
 )
 class MMLUAdapter(DataAdapter):
 
@@ -224,9 +224,8 @@ class MMLUAdapter(DataAdapter):
 
         context: str = '\n'.join(few_shot_prompts) + '\n'
         context += self._generate_prompt(input_d=input_d, include_answer=False)
-        query = context.strip() + self._generate_prompt(input_d=input_d, include_answer=False)
 
-        full_prompt = self.prompt_template.format(subset_name=self._format_subject(subset_name), query=query)
+        full_prompt = self.prompt_template.format(subset_name=self._format_subject(subset_name), query=context.strip())
 
         return self.gen_prompt_data(full_prompt)
 
@@ -249,7 +248,7 @@ class MMLUAdapter(DataAdapter):
         if self.model_adapter == OutputType.MULTIPLE_CHOICE:
             return result
         else:
-            return ResponseParser.parse_first_option(result)
+            return ResponseParser.parse_first_option(result, options=self.choices)
 
     def match(self, gold: str, pred: str) -> float:
         return exact_match(gold=gold, pred=pred)
@@ -260,11 +259,12 @@ class MMLUAdapter(DataAdapter):
 
         example: str = input_d['input']
         for j in range(len(self.choices)):
-            example += '\n{}. {}'.format(self.choices[j], input_choices[j])
+            example += f'\n{self.choices[j]}) {input_choices[j]}'
 
-        example += '\nAnswer:'
         if include_answer:
-            example += ' {}\n\n'.format(input_d['target'])
+            example += f"\nAnswer: {input_d['target']}\n\n"
+        else:
+            example += '\nAnswer: \n\n'
 
         return example
 
evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py CHANGED
@@ -92,7 +92,7 @@ class MMLUProAdapter(DataAdapter):
         if self.model_adapter == OutputType.MULTIPLE_CHOICE:
             return result
         else:
-            return ResponseParser.parse_first_option(result)
+            return ResponseParser.parse_first_option(result, options=self.choices)
 
     def match(self, gold: str, pred: str) -> float:
         """
evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py CHANGED
@@ -164,7 +164,7 @@ class MMLUReduxAdapter(DataAdapter):
         if self.model_adapter == OutputType.MULTIPLE_CHOICE:
             return result
         else:
-            return ResponseParser.parse_first_option(result)
+            return ResponseParser.parse_first_option(result, options=self.choices)
 
     def match(self, gold: str, pred: str) -> float:
         """
evalscope/benchmarks/musr/musr_adapter.py CHANGED
@@ -62,7 +62,7 @@ class MuSRAdapter(DataAdapter):
         if self.model_adapter == OutputType.MULTIPLE_CHOICE:
             return result
         else:
-            return ResponseParser.parse_first_option(result)
+            return ResponseParser.parse_first_option(result, options=self.choices)
 
     def match(self, gold: str, pred: str) -> float:
         """
evalscope/benchmarks/simple_qa/simple_qa_adapter.py CHANGED
@@ -3,8 +3,7 @@ from collections import defaultdict
 from typing import Any, List
 
 from evalscope.benchmarks import Benchmark, DataAdapter
-from evalscope.metrics import Metric, mean, metric_registry
-from evalscope.metrics.llm_judge import LLMJudge
+from evalscope.metrics import LLMJudge, Metric, mean, metric_registry
 from evalscope.utils.logger import get_logger
 
 # flake8: noqa
evalscope/benchmarks/utils.py CHANGED
@@ -1,6 +1,6 @@
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from functools import wraps
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 from evalscope.constants import EvalType
 from evalscope.utils.filters import Filter
@@ -9,30 +9,21 @@ from evalscope.utils.filters import Filter
 @dataclass
 class PromptData:
     data: List[str]
-    index: Optional[int] = 0
+    index: Optional[Union[int, str]] = 0
     system_prompt: Optional[str] = None
     multi_choices: Optional[List[str]] = None
+    id: Optional[str] = None
 
     def to_dict(self) -> Dict:
-        if self.multi_choices is None:
-            return {
-                'data': self.data,
-                'index': self.index,
-                'system_prompt': self.system_prompt,
-            }
-        else:
-            return {
-                'data': self.data,
-                'index': self.index,
-                'system_prompt': self.system_prompt,
-                'multi_choices': self.multi_choices,
-            }
+        return {k: v for k, v in asdict(self).items() if v is not None}
 
 
 def preprocess_decorator(func):
 
     @wraps(func)
     def wrapper(self, result: str, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT):
+        if result is None:
+            result = ''
         filters = self.config_kwargs.get('filters', None)
         if filters:
             # Apply filters to the result
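The to_dict() rewrite above collapses the two hand-written branches into dataclasses.asdict() plus a None-filter. A minimal, self-contained sketch of the resulting behaviour:

from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Union

@dataclass
class PromptData:
    data: List[str]
    index: Optional[Union[int, str]] = 0
    system_prompt: Optional[str] = None
    multi_choices: Optional[List[str]] = None
    id: Optional[str] = None

    def to_dict(self) -> Dict:
        # Keep only fields that were actually set; None values are dropped.
        return {k: v for k, v in asdict(self).items() if v is not None}

print(PromptData(data=['Hello'], id='q-1').to_dict())
# -> {'data': ['Hello'], 'index': 0, 'id': 'q-1'}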
evalscope/cli/start_app.py CHANGED
@@ -21,7 +21,7 @@ class StartAppCMD(CLICommand):
     def define_args(parsers: ArgumentParser):
         """ define args for create pipeline template command.
         """
-        from evalscope.report.app import add_argument
+        from evalscope.report import add_argument
 
         parser = parsers.add_parser(StartAppCMD.name)
         add_argument(parser)
evalscope/collections/evaluator.py CHANGED
@@ -1,8 +1,10 @@
 import json
 import os
 import pandas as pd
+import random
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from copy import deepcopy
 from tabulate import tabulate
 from tqdm import tqdm
 from typing import List
@@ -10,7 +12,7 @@ from typing import List
 from evalscope.benchmarks import Benchmark, DataAdapter
 from evalscope.collections.sampler import DatasetEntry
 from evalscope.config import TaskConfig
-from evalscope.constants import AnswerKeys, DumpMode, EvalType
+from evalscope.constants import AnswerKeys, DataCollection, DumpMode, EvalType
 from evalscope.evaluator import Evaluator
 from evalscope.models import initialize_model_adapter
 from evalscope.report import ReportGenerator
@@ -65,11 +67,12 @@ class EvaluatorCollection:
         self.evaluators = self._initialize_evaluators()
 
     def load(self) -> tuple[list[DatasetEntry], str]:
-        dataset_name = os.path.basename(self.data_adapter.dataset_id).split('.')[0]
+        dataset_name = os.path.splitext(os.path.basename(self.data_adapter.dataset_id))[0]
         raw_dataset = self.data_adapter.load()
-        # limit the dataset
+        # random limit the dataset
        if self.task_cfg.limit:
-            raw_dataset = raw_dataset[:self.task_cfg.limit]
+            raw_dataset = random.sample(raw_dataset,
+                                        self.task_cfg.limit) if len(raw_dataset) > self.task_cfg.limit else raw_dataset
         # index dataset
         datasets = []
         for sample in raw_dataset:
@@ -95,10 +98,17 @@
 
     def _initialize_evaluators(self):
         evaluators = {}
+        # load dataset args
+        dataset_args = deepcopy(self.task_cfg.dataset_args)
+        common_args = dataset_args.get(DataCollection.NAME, {})
         for dataset_name in self.dataset_name_map.keys():
             benchmark = Benchmark.get(dataset_name)
             model_adapter = initialize_model_adapter(self.task_cfg, benchmark, self.model)
-            data_adapter = benchmark.get_data_adapter()
+            # update dataset args
+            cur_dataset_args = dataset_args.get(dataset_name, {})
+            cur_dataset_args.update(common_args)
+            # get data adapter
+            data_adapter = benchmark.get_data_adapter(cur_dataset_args)
             evaluators[dataset_name] = SimpleEvaluator(dataset_name, data_adapter, model_adapter, self.task_cfg,
                                                        self.outputs)
         return evaluators
@@ -174,6 +184,7 @@
         os.makedirs(os.path.dirname(report_file_path), exist_ok=True)
         with open(report_file_path, 'w', encoding='utf-8') as f:
             json.dump(report.to_dict(), f, ensure_ascii=False, indent=4)
+        return report
 
     def _filter_answer(self, pred_file_path):
         answer_dict = defaultdict(dict)
@@ -184,12 +195,14 @@
                 index = answer.get(AnswerKeys.INDEX)
                 answer_dict[index] = answer
                 indices.add(index)
+
             data = []
             for sample in self.dataset:
                 if sample.index not in indices:
                     data.append(sample)
             data_map = self._init_name_map(data)
 
+            logger.info(f'Reuse from {pred_file_path}. Loaded {len(indices)} samples, remain {len(data)} samples.')
            return answer_dict, data, data_map
         return answer_dict, self.dataset, self.dataset_name_map
 
@@ -274,4 +287,5 @@
         answers = self.get_answers()
         reviews = self.get_reviews(answers)
         scores = self.get_scores(reviews)
-        self.get_report(scores)
+        report = self.get_report(scores)
+        return report
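The _initialize_evaluators() change above merges per-dataset arguments with a shared entry keyed by DataCollection.NAME, and the shared entry wins on conflicts because it is applied last. A small sketch of that merge; the key name 'data_collection' and the argument values are illustrative assumptions, not taken from the diff:

from copy import deepcopy

COLLECTION_KEY = 'data_collection'   # stand-in for DataCollection.NAME (assumed value)

dataset_args = {
    COLLECTION_KEY: {'local_path': 'outputs/collection.jsonl'},      # shared by every sub-dataset
    'mmlu': {'few_shot_num': 0, 'local_path': 'ignored.jsonl'},
}

def merged_args(dataset_name: str) -> dict:
    cur = deepcopy(dataset_args).get(dataset_name, {})
    cur.update(dataset_args.get(COLLECTION_KEY, {}))   # common args override per-dataset args
    return cur

print(merged_args('mmlu'))
# -> {'few_shot_num': 0, 'local_path': 'outputs/collection.jsonl'}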
evalscope/config.py CHANGED
@@ -4,13 +4,12 @@ import copy
 import json
 import os
 from argparse import Namespace
-from collections import OrderedDict
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union
 
 from evalscope.constants import (DEFAULT_DATASET_CACHE_DIR, DEFAULT_WORK_DIR, EvalBackend, EvalStage, EvalType, HubType,
-                                 JudgeStrategy, OutputType)
-from evalscope.models.custom import CustomModel
+                                 JudgeStrategy, ModelTask, OutputType)
+from evalscope.models import CustomModel, DummyCustomModel
 from evalscope.utils import gen_hash
 from evalscope.utils.io_utils import dict_to_yaml, json_to_dict, yaml_to_dict
 from evalscope.utils.logger import get_logger
@@ -36,6 +35,7 @@ class TaskConfig:
     model: Union[str, 'CustomModel', None] = None
     model_id: Optional[str] = None
     model_args: Optional[Dict] = field(default_factory=lambda: DEFAULT_MODEL_ARGS | {})
+    model_task: Optional[str] = ModelTask.TEXT_GENERATION
 
     # Template-related arguments
     template_type: Optional[str] = None  # Deprecated, will be removed in v1.0.0.
@@ -79,6 +79,10 @@ class TaskConfig:
     judge_model_args: Optional[Dict] = field(default_factory=lambda: {})
 
     def __post_init__(self):
+        if self.model is None:
+            self.model = DummyCustomModel()
+            self.eval_type = EvalType.CUSTOM
+
         if (not self.model_id) and self.model:
             if isinstance(self.model, CustomModel):
                 self.model_id = self.model.config.get('model_id', 'custom_model')
@@ -212,7 +216,7 @@ def parse_task_config(task_cfg) -> TaskConfig:
         logger.info('Args: Task config is provided with CommandLine type.')
         task_cfg = TaskConfig.from_args(task_cfg)
     elif isinstance(task_cfg, str):
-        extension = task_cfg.split('.')[-1]
+        extension = os.path.splitext(task_cfg)[-1]
         logger.info(f'Args: Task config is provided with {extension} file type.')
         if extension in ['yaml', 'yml']:
             task_cfg = TaskConfig.from_yaml(task_cfg)
evalscope/constants.py CHANGED
@@ -1,4 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+# flake8: noqa
+import os
+
+os.environ['MODELSCOPE_LOG_LEVEL'] = '40'  # Set default log level to ERROR
+
 from modelscope.utils.constant import DEFAULT_REPOSITORY_REVISION
 from modelscope.utils.file_utils import get_dataset_cache_root, get_model_cache_root
 
@@ -145,6 +150,7 @@ class OutputType:
     GENERATION = 'generation'  # for text generation tasks and general tasks
     MULTIPLE_CHOICE = 'multiple_choice_logits'  # for multiple choice tasks
     CONTINUOUS = 'continuous_logits'  # for continuous tasks
+    IMAGE_GENERATION = 'image_generation'  # for image generation tasks
 
 
 class EvalBackend:
@@ -164,3 +170,8 @@ class JudgeStrategy:
     RULE = 'rule'
     LLM = 'llm'
     LLM_RECALL = 'llm_recall'
+
+
+class ModelTask:
+    TEXT_GENERATION = 'text_generation'
+    IMAGE_GENERATION = 'image_generation'
evalscope/evaluator/evaluator.py CHANGED
@@ -66,7 +66,7 @@ class Evaluator(object):
         if self.task_cfg.judge_strategy == JudgeStrategy.RULE:
             self.judge = None
         else:
-            from evalscope.metrics.llm_judge import LLMJudge
+            from evalscope.metrics import LLMJudge
             self.judge = LLMJudge(**self.task_cfg.judge_model_args)
 
     def load_dataset(self):
@@ -281,7 +281,7 @@
         os.makedirs(os.path.dirname(review_file_path), exist_ok=True)
 
         if self.use_cache and os.path.exists(review_file_path):
-            logger.warning(f'Ignore use_cache={self.use_cache}, updating the review file: {review_file_path} ...')
+            logger.info(f'Updating the review file: {review_file_path} ...')
             os.remove(review_file_path)
 
         def process_single_review(answer_d):
evalscope/evaluator/reviewer/auto_reviewer.py CHANGED
@@ -11,7 +11,7 @@ from functools import partial
 from typing import Any, List, Tuple
 
 from evalscope.constants import ArenaMode, EvalConfigKeys, FnCompletionParser, PositionBiasMitigation
-from evalscope.models.model import OpenAIModel
+from evalscope.models import OpenAIModel
 from evalscope.utils import completion_parsers, random_seeded_choice
 from evalscope.utils.arena_utils import get_battle_pairs, merge_ques_ans, shuffle_pairwise_preferences
 from evalscope.utils.io_utils import dump_jsonl_data, jsonl_to_list
evalscope/metrics/__init__.py CHANGED
@@ -1,5 +1,50 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from evalscope.metrics.metrics import (bleu_ngram_one_sample, exact_match, macro_mean, mean, micro_mean,
-                                       simple_f1_score, weighted_mean)
-from evalscope.metrics.named_metrics import *
-from evalscope.metrics.rouge_metric import compute_rouge_score_one_sample_zh
+from typing import TYPE_CHECKING
+
+from evalscope.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+    from .llm_judge import LLMJudge
+    from .math_parser import extract_answer, math_equal, strip_answer_string
+    from .metrics import (bleu_ngram_one_sample, exact_match, macro_mean, mean, micro_mean, simple_f1_score,
+                          weighted_mean)
+    from .named_metrics import Metric, metric_registry
+    from .rouge_metric import compute_rouge_score_one_sample_zh
+
+else:
+    _import_structure = {
+        'metrics': [
+            'bleu_ngram_one_sample',
+            'exact_match',
+            'macro_mean',
+            'mean',
+            'micro_mean',
+            'simple_f1_score',
+            'weighted_mean',
+        ],
+        'named_metrics': [
+            'Metric',
+            'metric_registry',
+        ],
+        'rouge_metric': [
+            'compute_rouge_score_one_sample_zh',
+        ],
+        'llm_judge': [
+            'LLMJudge',
+        ],
+        'math_parser': [
+            'extract_answer',
+            'math_equal',
+            'strip_answer_string',
+        ],
+    }
+
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
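evalscope/metrics/__init__.py now defers its imports through _LazyModule (added in evalscope/utils/import_utils.py, whose body is not shown in this diff). Below is an illustrative sketch of the general lazy-import pattern, not the actual evalscope implementation: submodules are imported only when one of their exported names is first accessed.

import importlib
import types

class LazyModule(types.ModuleType):
    """Hypothetical minimal lazy module: maps exported names to submodules."""

    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # e.g. {'metrics': ['mean', 'exact_match'], 'llm_judge': ['LLMJudge']}
        self._name_to_submodule = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}
        self.__all__ = list(self._name_to_submodule)

    def __getattr__(self, attr: str):
        if attr not in self._name_to_submodule:
            raise AttributeError(f'module {self.__name__!r} has no attribute {attr!r}')
        submodule = importlib.import_module(f'.{self._name_to_submodule[attr]}', self.__name__)
        value = getattr(submodule, attr)
        setattr(self, attr, value)   # cache so later lookups skip __getattr__
        return value

The real _LazyModule also receives module_spec and extra_objects, as the call in the diff above shows.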
evalscope/metrics/llm_judge.py CHANGED
@@ -54,7 +54,7 @@ class LLMJudge:
         self.prompt_template = prompt_template or os.environ.get('JUDGE_PROMPT_TEMPLATE', DEFAULT_PROMPT_TEMPLATE)
         self.generation_config = generation_config
 
-        from evalscope.models.server_adapter import ServerModelAdapter
+        from evalscope.models import ServerModelAdapter
 
         # Initialize ServerModelAdapter
         self.server_adapter = ServerModelAdapter(api_url=self.api_url, model_id=self.model_id, api_key=self.api_key)
evalscope/metrics/named_metrics.py CHANGED
@@ -3,6 +3,8 @@ from functools import partial
 from typing import Callable, Dict
 
 from evalscope.metrics.metrics import mean, pass_at_k, weighted_mean
+from evalscope.metrics.t2v_metrics import (blip2_score, clip_flant5_score, clip_score, fga_blip2_score, hpsv2_1_score,
+                                           hpsv2_score, image_reward_score, mps_score, pick_score)
 
 
 @dataclass
@@ -40,3 +42,14 @@ metric_registry.register(Metric(name='WeightedAverageBLEU', object=weighted_mean
 metric_registry.register(Metric(name='AveragePass@1', object=mean))
 for k in range(1, 17):
     metric_registry.register(Metric(name=f'Pass@{k}', object=partial(pass_at_k, k=k)))
+
+# t2v_metrics
+metric_registry.register(Metric(name='VQAScore', object=clip_flant5_score))
+metric_registry.register(Metric(name='PickScore', object=pick_score))
+metric_registry.register(Metric(name='CLIPScore', object=clip_score))
+metric_registry.register(Metric(name='BLIPv2Score', object=blip2_score))
+metric_registry.register(Metric(name='HPSv2Score', object=hpsv2_score))
+metric_registry.register(Metric(name='HPSv2.1Score', object=hpsv2_1_score))
+metric_registry.register(Metric(name='ImageRewardScore', object=image_reward_score))
+metric_registry.register(Metric(name='FGA_BLIP2Score', object=fga_blip2_score))
+metric_registry.register(Metric(name='MPS', object=mps_score))
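The text-to-image metrics above are registered with factory functions as their objects, so the underlying scoring model is only built when the metric is actually used. The same Metric/metric_registry API can register additional metrics; a short sketch with a hypothetical metric name:

from evalscope.metrics import Metric, mean, metric_registry

# Register a custom aggregate metric under a new name; 'mean' is the same
# aggregation object used by several built-in metrics in this file.
metric_registry.register(Metric(name='MyAverageScore', object=mean))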
evalscope/metrics/t2v_metrics/__init__.py ADDED
@@ -0,0 +1,66 @@
+from __future__ import absolute_import, division, print_function
+
+from .clipscore import CLIPScore, list_all_clipscore_models
+from .constants import CACHE_DIR
+from .itmscore import ITMScore, list_all_itmscore_models
+from .vqascore import VQAScore, list_all_vqascore_models
+
+
+def list_all_models():
+    return list_all_vqascore_models() + list_all_clipscore_models() + list_all_itmscore_models()
+
+
+def get_score_model(model='clip-flant5-xxl', device='cuda', cache_dir=CACHE_DIR, **kwargs):
+    if model in list_all_vqascore_models():
+        return VQAScore(model, device=device, cache_dir=cache_dir, **kwargs)
+    elif model in list_all_clipscore_models():
+        return CLIPScore(model, device=device, cache_dir=cache_dir, **kwargs)
+    elif model in list_all_itmscore_models():
+        return ITMScore(model, device=device, cache_dir=cache_dir, **kwargs)
+    else:
+        raise NotImplementedError()
+
+
+def clip_flant5_score():
+    clip_flant5_score = VQAScore(model='clip-flant5-xxl')
+    return clip_flant5_score
+
+
+def pick_score():
+    pick_score = CLIPScore(model='pickscore-v1')
+    return pick_score
+
+
+def clip_score():
+    clip_score = CLIPScore(model='openai:ViT-L-14-336')
+    return clip_score
+
+
+def blip2_score():
+    blip_itm_score = ITMScore(model='blip2-itm')
+    return blip_itm_score
+
+
+def hpsv2_score():
+    hpsv2_score = CLIPScore(model='hpsv2')
+    return hpsv2_score
+
+
+def hpsv2_1_score():
+    hpsv2_1_score = CLIPScore(model='hpsv2.1')
+    return hpsv2_1_score
+
+
+def image_reward_score():
+    image_reward_score = ITMScore(model='image-reward-v1')
+    return image_reward_score
+
+
+def fga_blip2_score():
+    fga_blip2_score = ITMScore(model='fga_blip2')
+    return fga_blip2_score
+
+
+def mps_score():
+    mps_score = CLIPScore(model='mps')
+    return mps_score
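A hedged usage sketch for the factory above: the model names and the get_score_model signature come from this file, while the scoring call (images=..., texts=...) follows the upstream t2v_metrics convention and is an assumption here, since score.py is not shown in this diff.

from evalscope.metrics.t2v_metrics import get_score_model, list_all_models

print(list_all_models())   # all registered VQAScore / CLIPScore / ITMScore model names

model = get_score_model(model='pickscore-v1', device='cuda')
# Assumed call signature (mirrors upstream t2v_metrics): one score per image/text pair.
scores = model(images=['outputs/sample_0.png'], texts=['a red bicycle leaning against a wall'])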
evalscope/metrics/t2v_metrics/clipscore.py ADDED
@@ -0,0 +1,14 @@
+from typing import List
+
+from .constants import CACHE_DIR
+from .models.clipscore_models import get_clipscore_model, list_all_clipscore_models
+from .score import Score
+
+
+class CLIPScore(Score):
+
+    def prepare_scoremodel(self, model='openai:ViT-L/14', device='cuda', cache_dir=CACHE_DIR):
+        return get_clipscore_model(model, device=device, cache_dir=cache_dir)
+
+    def list_all_models(self) -> List[str]:
+        return list_all_clipscore_models()
evalscope/metrics/t2v_metrics/constants.py ADDED
@@ -0,0 +1,12 @@
+import os
+from modelscope.utils.file_utils import get_model_cache_root
+
+CACHE_DIR = get_model_cache_root()
+os.environ['TORCH_HOME'] = CACHE_DIR  # set timm cache dir
+
+# For CLIP-FlanT5
+CONTEXT_LEN = 2048
+SYSTEM_MSG = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = '<image>'
evalscope/metrics/t2v_metrics/itmscore.py ADDED
@@ -0,0 +1,14 @@
+from typing import List
+
+from .constants import CACHE_DIR
+from .models.itmscore_models import get_itmscore_model, list_all_itmscore_models
+from .score import Score
+
+
+class ITMScore(Score):
+
+    def prepare_scoremodel(self, model='blip2-itm', device='cuda', cache_dir=CACHE_DIR):
+        return get_itmscore_model(model, device=device, cache_dir=cache_dir)
+
+    def list_all_models(self) -> List[str]:
+        return list_all_itmscore_models()
File without changes
evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py ADDED
@@ -0,0 +1,30 @@
+from ...constants import CACHE_DIR
+from .clip_model import CLIP_MODELS, CLIPScoreModel
+from .hpsv2_model import HPSV2_MODELS, HPSV2ScoreModel
+from .mps_model import MPS_MODELS, MPSModel
+from .pickscore_model import PICKSCORE_MODELS, PickScoreModel
+
+ALL_CLIP_MODELS = [
+    CLIP_MODELS,
+    HPSV2_MODELS,
+    PICKSCORE_MODELS,
+    MPS_MODELS,
+]
+
+
+def list_all_clipscore_models():
+    return [model for models in ALL_CLIP_MODELS for model in models]
+
+
+def get_clipscore_model(model_name, device='cuda', cache_dir=CACHE_DIR):
+    assert model_name in list_all_clipscore_models()
+    if model_name in CLIP_MODELS:
+        return CLIPScoreModel(model_name, device=device, cache_dir=cache_dir)
+    elif model_name in HPSV2_MODELS:
+        return HPSV2ScoreModel(model_name, device=device, cache_dir=cache_dir)
+    elif model_name in PICKSCORE_MODELS:
+        return PickScoreModel(model_name, device=device, cache_dir=cache_dir)
+    elif model_name in MPS_MODELS:
+        return MPSModel(model_name, device=device, cache_dir=cache_dir)
+    else:
+        raise NotImplementedError()