evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (302)
  1. evalscope/__init__.py +4 -1
  2. evalscope/api/benchmark/__init__.py +3 -0
  3. evalscope/api/benchmark/adapters/__init__.py +5 -0
  4. evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
  5. evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
  6. evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
  7. evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
  8. evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
  9. evalscope/api/benchmark/benchmark.py +356 -0
  10. evalscope/api/benchmark/meta.py +121 -0
  11. evalscope/api/dataset/__init__.py +2 -0
  12. evalscope/api/dataset/dataset.py +349 -0
  13. evalscope/api/dataset/loader.py +262 -0
  14. evalscope/api/dataset/utils.py +143 -0
  15. evalscope/api/evaluator/__init__.py +3 -0
  16. evalscope/api/evaluator/cache.py +378 -0
  17. evalscope/api/evaluator/evaluator.py +56 -0
  18. evalscope/api/evaluator/state.py +275 -0
  19. evalscope/api/filter/__init__.py +1 -0
  20. evalscope/api/filter/filter.py +72 -0
  21. evalscope/api/messages/__init__.py +12 -0
  22. evalscope/api/messages/chat_message.py +243 -0
  23. evalscope/api/messages/content.py +102 -0
  24. evalscope/api/messages/utils.py +35 -0
  25. evalscope/api/metric/__init__.py +2 -0
  26. evalscope/api/metric/metric.py +55 -0
  27. evalscope/api/metric/scorer.py +113 -0
  28. evalscope/api/mixin/__init__.py +1 -0
  29. evalscope/api/mixin/llm_judge_mixin.py +168 -0
  30. evalscope/api/model/__init__.py +12 -0
  31. evalscope/api/model/generate_config.py +155 -0
  32. evalscope/api/model/model.py +386 -0
  33. evalscope/api/model/model_output.py +285 -0
  34. evalscope/api/registry.py +182 -0
  35. evalscope/api/tool/__init__.py +3 -0
  36. evalscope/api/tool/tool_call.py +101 -0
  37. evalscope/api/tool/tool_info.py +173 -0
  38. evalscope/api/tool/utils.py +64 -0
  39. evalscope/app/app.py +3 -0
  40. evalscope/app/ui/app_ui.py +2 -1
  41. evalscope/app/ui/multi_model.py +50 -25
  42. evalscope/app/ui/single_model.py +26 -14
  43. evalscope/app/utils/data_utils.py +43 -27
  44. evalscope/app/utils/env_utils.py +12 -0
  45. evalscope/app/utils/text_utils.py +14 -14
  46. evalscope/app/utils/visualization.py +9 -4
  47. evalscope/arguments.py +7 -10
  48. evalscope/backend/opencompass/api_meta_template.py +2 -1
  49. evalscope/backend/opencompass/backend_manager.py +6 -5
  50. evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
  51. evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
  52. evalscope/backend/rag_eval/ragas/task_template.py +2 -1
  53. evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
  54. evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
  55. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
  56. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
  57. evalscope/backend/rag_eval/utils/embedding.py +10 -1
  58. evalscope/backend/rag_eval/utils/llm.py +13 -12
  59. evalscope/benchmarks/__init__.py +0 -2
  60. evalscope/benchmarks/aime/aime24_adapter.py +38 -40
  61. evalscope/benchmarks/aime/aime25_adapter.py +34 -40
  62. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
  63. evalscope/benchmarks/arc/arc_adapter.py +34 -147
  64. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
  65. evalscope/benchmarks/arena_hard/utils.py +37 -1
  66. evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
  67. evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
  68. evalscope/benchmarks/bfcl/generation.py +222 -0
  69. evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
  70. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
  71. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
  72. evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
  73. evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
  74. evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
  75. evalscope/benchmarks/docmath/utils.py +4 -5
  76. evalscope/benchmarks/drop/drop_adapter.py +88 -40
  77. evalscope/benchmarks/frames/frames_adapter.py +136 -52
  78. evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
  79. evalscope/benchmarks/general_arena/utils.py +23 -27
  80. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
  81. evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
  82. evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
  83. evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
  84. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
  85. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
  86. evalscope/benchmarks/hle/hle_adapter.py +127 -93
  87. evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
  88. evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
  89. evalscope/benchmarks/ifeval/instructions.py +109 -64
  90. evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
  91. evalscope/benchmarks/ifeval/instructions_util.py +2 -3
  92. evalscope/benchmarks/ifeval/utils.py +6 -7
  93. evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
  94. evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
  95. evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
  96. evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
  97. evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
  98. evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
  99. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
  100. evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
  101. evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
  102. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
  103. evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
  104. evalscope/benchmarks/math_vista/__init__.py +0 -0
  105. evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
  106. evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
  107. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
  108. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
  109. evalscope/benchmarks/mmmu/__init__.py +0 -0
  110. evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
  111. evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
  112. evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
  113. evalscope/benchmarks/musr/musr_adapter.py +33 -64
  114. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
  115. evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
  116. evalscope/benchmarks/race/race_adapter.py +33 -119
  117. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
  118. evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
  119. evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
  120. evalscope/benchmarks/super_gpqa/utils.py +2 -1
  121. evalscope/benchmarks/tau_bench/generation.py +147 -0
  122. evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
  123. evalscope/benchmarks/text2image/__init__.py +0 -0
  124. evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
  125. evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
  126. evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
  127. evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
  128. evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
  129. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
  130. evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
  131. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
  132. evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
  133. evalscope/cli/cli.py +2 -0
  134. evalscope/cli/start_app.py +7 -1
  135. evalscope/cli/start_perf.py +7 -1
  136. evalscope/cli/start_server.py +6 -3
  137. evalscope/collections/__init__.py +2 -10
  138. evalscope/collections/sampler.py +10 -10
  139. evalscope/collections/schema.py +13 -11
  140. evalscope/config.py +157 -57
  141. evalscope/constants.py +37 -61
  142. evalscope/evaluator/__init__.py +1 -1
  143. evalscope/evaluator/evaluator.py +275 -419
  144. evalscope/filters/__init__.py +2 -0
  145. evalscope/filters/extraction.py +126 -0
  146. evalscope/filters/selection.py +57 -0
  147. evalscope/metrics/__init__.py +13 -13
  148. evalscope/metrics/llm_judge.py +47 -33
  149. evalscope/metrics/math_parser.py +27 -22
  150. evalscope/metrics/metric.py +307 -0
  151. evalscope/metrics/metrics.py +22 -18
  152. evalscope/metrics/t2v_metrics/__init__.py +0 -52
  153. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
  154. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
  155. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
  156. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
  157. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
  158. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
  159. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
  160. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
  161. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
  162. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
  163. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
  164. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
  165. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
  166. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
  167. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
  168. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
  169. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
  170. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
  171. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
  172. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
  173. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
  174. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
  175. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
  176. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
  177. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
  178. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
  179. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
  180. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
  181. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
  182. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
  183. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
  184. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
  185. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
  186. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
  187. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
  188. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
  189. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
  190. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
  191. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
  192. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
  193. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
  194. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
  195. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
  196. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
  197. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
  198. evalscope/models/__init__.py +6 -29
  199. evalscope/models/image_edit_model.py +125 -0
  200. evalscope/models/mockllm.py +65 -0
  201. evalscope/models/model_apis.py +67 -0
  202. evalscope/models/modelscope.py +455 -0
  203. evalscope/models/openai_compatible.py +126 -0
  204. evalscope/models/text2image_model.py +124 -0
  205. evalscope/models/utils/openai.py +701 -0
  206. evalscope/perf/benchmark.py +4 -1
  207. evalscope/perf/http_client.py +4 -2
  208. evalscope/perf/plugin/api/custom_api.py +5 -4
  209. evalscope/perf/plugin/api/openai_api.py +11 -9
  210. evalscope/perf/plugin/datasets/custom.py +2 -1
  211. evalscope/perf/plugin/datasets/flickr8k.py +1 -1
  212. evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
  213. evalscope/perf/plugin/datasets/line_by_line.py +2 -1
  214. evalscope/perf/plugin/datasets/longalpaca.py +2 -1
  215. evalscope/perf/plugin/datasets/openqa.py +4 -2
  216. evalscope/perf/utils/benchmark_util.py +15 -10
  217. evalscope/perf/utils/db_util.py +9 -6
  218. evalscope/perf/utils/local_server.py +11 -3
  219. evalscope/perf/utils/rich_display.py +16 -10
  220. evalscope/report/__init__.py +2 -3
  221. evalscope/report/combinator.py +18 -12
  222. evalscope/report/generator.py +51 -35
  223. evalscope/report/{utils.py → report.py} +8 -6
  224. evalscope/run.py +33 -47
  225. evalscope/summarizer.py +1 -1
  226. evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
  227. evalscope/utils/__init__.py +21 -2
  228. evalscope/utils/chat_service.py +3 -2
  229. evalscope/utils/deprecation_utils.py +12 -1
  230. evalscope/utils/function_utils.py +29 -0
  231. evalscope/utils/import_utils.py +23 -1
  232. evalscope/utils/io_utils.py +142 -6
  233. evalscope/utils/json_schema.py +208 -0
  234. evalscope/utils/logger.py +51 -12
  235. evalscope/utils/model_utils.py +11 -7
  236. evalscope/utils/multi_choices.py +288 -0
  237. evalscope/utils/url_utils.py +65 -0
  238. evalscope/version.py +2 -2
  239. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
  240. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
  241. tests/benchmark/test_eval.py +385 -0
  242. tests/benchmark/test_image_edit.py +65 -0
  243. tests/{aigc → benchmark}/test_t2i.py +22 -4
  244. tests/benchmark/test_vlm.py +80 -0
  245. tests/cli/test_all.py +85 -47
  246. tests/cli/test_collection.py +20 -8
  247. tests/cli/test_custom.py +22 -15
  248. tests/cli/test_reasoning.py +81 -0
  249. tests/common.py +73 -0
  250. tests/perf/test_perf.py +4 -2
  251. tests/rag/test_clip_benchmark.py +0 -2
  252. evalscope/benchmarks/aigc/t2i/base.py +0 -56
  253. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
  254. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
  255. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
  256. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
  257. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
  258. evalscope/benchmarks/arc/ai2_arc.py +0 -151
  259. evalscope/benchmarks/benchmark.py +0 -81
  260. evalscope/benchmarks/ceval/ceval_exam.py +0 -146
  261. evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
  262. evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
  263. evalscope/benchmarks/competition_math/competition_math.py +0 -79
  264. evalscope/benchmarks/data_adapter.py +0 -528
  265. evalscope/benchmarks/filters.py +0 -59
  266. evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
  267. evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
  268. evalscope/benchmarks/humaneval/humaneval.py +0 -79
  269. evalscope/benchmarks/mmlu/mmlu.py +0 -160
  270. evalscope/benchmarks/mmlu/samples.jsonl +0 -5
  271. evalscope/benchmarks/process_bench/critique_template.txt +0 -13
  272. evalscope/benchmarks/race/race.py +0 -104
  273. evalscope/benchmarks/race/samples.jsonl +0 -5
  274. evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
  275. evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
  276. evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
  277. evalscope/benchmarks/utils.py +0 -60
  278. evalscope/collections/evaluator.py +0 -375
  279. evalscope/metrics/completion_parsers.py +0 -227
  280. evalscope/metrics/named_metrics.py +0 -55
  281. evalscope/models/adapters/__init__.py +0 -14
  282. evalscope/models/adapters/base_adapter.py +0 -84
  283. evalscope/models/adapters/bfcl_adapter.py +0 -246
  284. evalscope/models/adapters/chat_adapter.py +0 -207
  285. evalscope/models/adapters/choice_adapter.py +0 -222
  286. evalscope/models/adapters/custom_adapter.py +0 -71
  287. evalscope/models/adapters/server_adapter.py +0 -236
  288. evalscope/models/adapters/t2i_adapter.py +0 -79
  289. evalscope/models/adapters/tau_bench_adapter.py +0 -189
  290. evalscope/models/custom/__init__.py +0 -4
  291. evalscope/models/custom/custom_model.py +0 -50
  292. evalscope/models/custom/dummy_model.py +0 -99
  293. evalscope/models/local_model.py +0 -128
  294. evalscope/models/register.py +0 -41
  295. tests/cli/test_run.py +0 -489
  296. /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
  297. /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
  298. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
  299. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
  300. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
  301. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
  302. /tests/{aigc → benchmark}/__init__.py +0 -0
@@ -0,0 +1,307 @@
+ from collections import defaultdict
+ from typing import List
+
+ from evalscope.api.metric import Aggregator, AggScore, Metric, SampleScore, T2IMetric
+ from evalscope.api.registry import register_aggregation, register_metric
+ from .metrics import mean
+
+
+ @register_metric(name='exact_match')
+ class ExactMatch(Metric):
+
+     def apply(self, predictions, references):
+         return [float(prediction == reference) for prediction, reference in zip(predictions, references)]
+
+
+ @register_metric(name='acc')
+ class Accuracy(ExactMatch):
+
+     def __init__(self, allow_inclusion: bool = False, numeric: bool = False):
+         self.allow_inclusion = allow_inclusion
+         self.numeric = numeric
+
+     def apply(self, predictions, references):
+         if self.allow_inclusion:
+             results = []
+             for prediction, reference in zip(predictions, references):
+                 if prediction and prediction in reference:
+                     results.append(1.0)
+                 else:
+                     results.append(0.0)
+             return results
+         elif self.numeric:
+             from .math_parser import extract_answer, math_equal, strip_answer_string
+
+             results = []
+             for prediction, reference in zip(predictions, references):
+                 pred_answer = strip_answer_string(extract_answer(prediction))
+                 ref_answer = strip_answer_string(reference)
+                 results.append(float(math_equal(pred_answer, ref_answer)))
+
+             return results
+         else:
+             return super().apply(predictions, references)
+
+
+ @register_metric(name='numeric_match')
+ class NumericMatch(Metric):
+
+     def apply(self, predictions, references):
+         return [float(prediction == reference) for prediction, reference in zip(predictions, references)]
+
+
+ @register_metric(name='math_acc')
+ class MathAcc(Metric):
+
+     def apply(self, predictions, references):
+         from .math_parser import extract_answer, math_equal, strip_answer_string
+
+         results = []
+         for prediction, reference in zip(predictions, references):
+             pred_answer = strip_answer_string(extract_answer(prediction))
+             ref_answer = strip_answer_string(reference)
+             results.append(float(math_equal(pred_answer, ref_answer)))
+
+         return results
+
+
+ @register_metric(name='multi_choice_acc')
+ class MultiChoiceAcc(Metric):
+
+     def apply(self, predictions, references):
+         """
+         Calculate accuracy for multiple-choice questions.
+
+         Args:
+             predictions (List[str]): List of predicted answers.
+             references (List[str]): List of correct answers.
+
+         Returns:
+             List[float]: List of accuracy scores (1.0 for correct, 0.0 for incorrect).
+         """
+         res = []
+         for prediction, reference in zip(predictions, references):
+             prediction = set(prediction.strip().upper())
+             reference = set(reference.strip().upper())
+             # if the prediction has answer that not in reference, it is wrong
+             if not prediction.issubset(reference):
+                 res.append(0.0)
+                 continue
+             common = prediction.intersection(reference)
+             res.append(len(common) / len(reference) if reference else 0.0)
+         return res
+
+
+ # ##################
+ # T2I Metrics ######
+ ####################
+ @register_metric(name='VQAScore')
+ class VQAScore(T2IMetric):
+
+     def _init_once(self, model: str = 'clip-flant5-xxl'):
+         from .t2v_metrics.vqascore import VQAScore
+         self.model = VQAScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ @register_metric(name='PickScore')
+ class PickScore(T2IMetric):
+
+     def _init_once(self, model: str = 'pickscore-v1'):
+         from .t2v_metrics.clipscore import CLIPScore
+         self.model = CLIPScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ @register_metric(name='CLIPScore')
+ class CLIPScore(T2IMetric):
+
+     def _init_once(self, model: str = 'openai:ViT-L-14-336'):
+         from .t2v_metrics.clipscore import CLIPScore
+         self.model = CLIPScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ @register_metric(name='BLIPv2Score')
+ class BLIPv2Score(T2IMetric):
+
+     def _init_once(self, model: str = 'blip2-itm'):
+         from .t2v_metrics.itmscore import ITMScore
+         self.model = ITMScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ @register_metric(name='HPSv2Score')
+ class HPSv2Score(T2IMetric):
+
+     def _init_once(self, model: str = 'hpsv2'):
+         from .t2v_metrics.clipscore import CLIPScore
+         self.model = CLIPScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ @register_metric(name='HPSv2.1Score')
+ class HPSv2_1Score(T2IMetric):
+
+     def _init_once(self, model: str = 'hpsv2.1'):
+         from .t2v_metrics.clipscore import CLIPScore
+         self.model = CLIPScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ @register_metric(name='ImageRewardScore')
+ class ImageRewardScore(T2IMetric):
+
+     def _init_once(self, model: str = 'image-reward-v1'):
+         from .t2v_metrics.itmscore import ITMScore
+         self.model = ITMScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ @register_metric(name='FGA_BLIP2Score')
+ class FGA_BLIP2Score(T2IMetric):
+
+     def _init_once(self, model: str = 'fga_blip2'):
+         from .t2v_metrics.itmscore import ITMScore
+         self.model = ITMScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ @register_metric(name='MPS')
+ class MPS(T2IMetric):
+
+     def _init_once(self, model: str = 'mps'):
+         from .t2v_metrics.clipscore import CLIPScore
+         self.model = CLIPScore(model=model)
+
+     def apply(self, images: List[str], texts: List[str], **kwargs) -> List[float]:
+         return self.model(images, texts, **kwargs)
+
+
+ # ##################
+ # Aggregators ######
+ # ##################
+ @register_aggregation(name='mean')
+ class Mean(Aggregator):
+
+     name = 'mean'
+
+     def __call__(self, scores: List[SampleScore]) -> List[AggScore]:
+         """Aggregate scores by computing the mean for each metric.
+
+         Args:
+             scores: List of sample scores to aggregate
+
+         Returns:
+             List of aggregated scores with mean values
+         """
+         if not scores:
+             return []
+
+         # Group score values by metric name
+         metric_values = defaultdict(list)
+         metric_sample_ids = defaultdict(list)
+
+         for score in scores:
+
+             for metric_name, value in score.score.value.items():
+                 metric_values[metric_name].append(value)
+                 metric_sample_ids[metric_name].append(score.sample_id)
+
+         # Calculate mean for each metric
+         aggregated_scores = []
+         for metric_name, values in metric_values.items():
+             if values:  # Only process non-empty value lists
+                 aggregated_scores.append(
+                     AggScore(
+                         score=mean(values),
+                         metric_name=metric_name,
+                         aggregation_name=self.name,
+                         num=len(values),
+                         ids=metric_sample_ids[metric_name]
+                     )
+                 )
+
+         return aggregated_scores
+
+
+ @register_aggregation(name='pass_at_k')
+ class PassAtK(Aggregator):
+
+     def __init__(self, k: int = 1):
+         self.k = k
+         self.name = f'pass_at_{k}'
+
+     def __call__(self, scores: List[SampleScore]) -> List[AggScore]:
+         """Aggregate scores by computing the pass@k for each metric using group_id.
+
+         Args:
+             scores: List of sample scores to aggregate
+
+         Returns:
+             List of aggregated scores with pass@k values
+         """
+         if not scores:
+             return []
+
+         import numpy as np
+
+         from .metrics import calculate_pass_at_k
+
+         # Group scores by metric name and group_id
+         metric_groups = defaultdict(lambda: defaultdict(list))
+
+         for score in scores:
+             group_id = getattr(score, 'group_id', score.sample_id)  # fallback to sample_id if no group_id
+
+             for metric_name, value in score.score.value.items():
+                 metric_groups[metric_name][group_id].append(float(value))
+
+         # Calculate pass@k for each metric
+         aggregated_scores = []
+         for metric_name, groups in metric_groups.items():
+             if not groups:
+                 continue
+
+             # Calculate pass@k for each group (problem)
+             num_samples = []
+             num_correct = []
+             all_sample_ids = []
+
+             for group_id, group_values in groups.items():
+                 num_samples.append(len(group_values))
+                 num_correct.append(sum(group_values))  # count how many passed in this group
+                 all_sample_ids.extend([f'{group_id}_{i}' for i in range(len(group_values))])
+
+             if num_samples:
+                 # Use the calculate_pass_at_k function from metrics
+                 pass_at_k_values = calculate_pass_at_k(num_samples, num_correct, self.k)
+                 overall_pass_at_k = float(np.mean(pass_at_k_values))
+
+                 aggregated_scores.append(
+                     AggScore(
+                         score=overall_pass_at_k,
+                         metric_name=f'pass@{self.k}',
+                         aggregation_name='',
+                         num=len(scores),
+                         ids=all_sample_ids
+                     )
+                 )
+
+         return aggregated_scores
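The hunk above is the core of the new metric registry. As a quick orientation, the sketch below exercises a few of the registered metric classes directly through the apply() signatures shown in the hunk; it is illustrative only, and how metrics are normally resolved through register_metric and the registry at evaluation time, or how SampleScore objects are built for the aggregators, is not shown here and may differ.

# Illustrative sketch only; uses just the apply() signatures visible above.
from evalscope.metrics.metric import Accuracy, ExactMatch, MultiChoiceAcc

predictions = ['B', 'C']
references = ['B. Paris', 'A. London']

print(ExactMatch().apply(predictions, references))                    # [0.0, 0.0] -- strict string equality
print(Accuracy(allow_inclusion=True).apply(predictions, references))  # [1.0, 0.0] -- 'B' occurs in 'B. Paris'
print(MultiChoiceAcc().apply(['AC'], ['ABC']))                        # [0.666...] -- partial credit for a correct subset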
@@ -191,7 +191,7 @@ def bleu(items):
      return sacrebleu.corpus_bleu(preds, refs).score


- def bleu_ngram_one_sample(predict, reference):
+ def bleu_ngram_one_sample(predict: str, reference: str):
      """
      Calculate BLEU-1, BLEU-2, BLEU-3, and BLEU-4 scores

@@ -322,11 +322,11 @@ def bootstrap_stderr(f, xs, iters):

      print('bootstrapping for stddev:', f.__name__)
      for bootstrap in tqdm(
-             pool.imap(
-                 _bootstrap_internal(f, chunk_size),
-                 [(i, xs) for i in range(iters // chunk_size)],
-             ),
-             total=iters // chunk_size,
+         pool.imap(
+             _bootstrap_internal(f, chunk_size),
+             [(i, xs) for i in range(iters // chunk_size)],
+         ),
+         total=iters // chunk_size,
  ):
      # sample w replacement
      res.extend(bootstrap)
@@ -361,15 +361,17 @@ def yesno(x):
          return 'no'


- def compute_elo(battles,
-                 col_model_a='model_a',
-                 col_model_b='model_b',
-                 col_win='win',
-                 tie_values=['tie', 'tie (bothbad)'],
-                 k=32,
-                 scale=400,
-                 base=10,
-                 init_rating=1000):
+ def compute_elo(
+     battles,
+     col_model_a='model_a',
+     col_model_b='model_b',
+     col_win='win',
+     tie_values=['tie', 'tie (bothbad)'],
+     k=32,
+     scale=400,
+     base=10,
+     init_rating=1000
+ ):
      rating = defaultdict(lambda: init_rating)

      for rd, model_a, model_b, win in battles[[col_model_a, col_model_b, col_win]].itertuples():
@@ -434,9 +436,11 @@ def calculate_arc_accuracy(question_answers: Dict[str, str], predictions: Dict[s
      return score / len(question_answers)


- def calculate_pass_at_k(num_samples: Union[int, List[int], np.ndarray],
-                         num_correct: Union[List[int], np.ndarray],
-                         k: int = 1) -> np.ndarray:
+ def calculate_pass_at_k(
+     num_samples: Union[int, List[int], np.ndarray],
+     num_correct: Union[List[int], np.ndarray],
+     k: int = 1
+ ) -> np.ndarray:
      """
      Estimates pass@k of each problem and returns them in an array.
      Examples:
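The re-wrapped calculate_pass_at_k above is the helper that the new PassAtK aggregator (earlier hunk) feeds with per-problem sample counts and correct counts. For readers unfamiliar with the quantity, the sketch below is the standard unbiased pass@k estimator popularized by the HumanEval paper, written independently for illustration; it takes the same kind of inputs (n samples, c correct, k) but is not a copy of evalscope's implementation.

# Independent illustration of the standard unbiased pass@k estimator
# (1 - C(n-c, k) / C(n, k)); not evalscope's code.
import numpy as np


def pass_at_k(n: int, c: int, k: int) -> float:
    """Probability that at least one of k samples drawn from n is correct,
    given that c of the n samples are correct."""
    if n - c < k:
        return 1.0
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


# 10 generations per problem, 3 of them pass, k=1  ->  0.3
print(pass_at_k(10, 3, 1))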
@@ -1,52 +0,0 @@
- def clip_flant5_score():
-     from .vqascore import VQAScore
-     clip_flant5_score = VQAScore(model='clip-flant5-xxl')
-     return clip_flant5_score
-
-
- def pick_score():
-     from .clipscore import CLIPScore
-     pick_score = CLIPScore(model='pickscore-v1')
-     return pick_score
-
-
- def clip_score():
-     from .clipscore import CLIPScore
-     clip_score = CLIPScore(model='openai:ViT-L-14-336')
-     return clip_score
-
-
- def blip2_score():
-     from .itmscore import ITMScore
-     blip_itm_score = ITMScore(model='blip2-itm')
-     return blip_itm_score
-
-
- def hpsv2_score():
-     from .clipscore import CLIPScore
-     hpsv2_score = CLIPScore(model='hpsv2')
-     return hpsv2_score
-
-
- def hpsv2_1_score():
-     from .clipscore import CLIPScore
-     hpsv2_1_score = CLIPScore(model='hpsv2.1')
-     return hpsv2_1_score
-
-
- def image_reward_score():
-     from .itmscore import ITMScore
-     image_reward_score = ITMScore(model='image-reward-v1')
-     return image_reward_score
-
-
- def fga_blip2_score():
-     from .itmscore import ITMScore
-     fga_blip2_score = ITMScore(model='fga_blip2')
-     return fga_blip2_score
-
-
- def mps_score():
-     from .clipscore import CLIPScore
-     mps_score = CLIPScore(model='mps')
-     return mps_score
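These deleted factory functions are superseded by the registered T2IMetric classes in the new evalscope/metrics/metric.py shown earlier; the comment sketch below maps old names to new ones. The usage lines are an assumption based only on the apply() signatures visible in these hunks, not a verified migration recipe.

# Mapping inferred from the two hunks (old factory -> new registered class, same default model):
#   clip_flant5_score()  -> VQAScore          ('clip-flant5-xxl')
#   pick_score()         -> PickScore         ('pickscore-v1')
#   clip_score()         -> CLIPScore         ('openai:ViT-L-14-336')
#   blip2_score()        -> BLIPv2Score       ('blip2-itm')
#   hpsv2_score()        -> HPSv2Score        ('hpsv2')
#   hpsv2_1_score()      -> HPSv2_1Score      ('hpsv2.1')
#   image_reward_score() -> ImageRewardScore  ('image-reward-v1')
#   fga_blip2_score()    -> FGA_BLIP2Score    ('fga_blip2')
#   mps_score()          -> MPS               ('mps')
#
# 0.17.1 (removed):            scorer = clip_score(); scores = scorer(images, texts)
# 1.0.1 (assumed equivalent):  scores = CLIPScore().apply(images, texts)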
@@ -27,7 +27,8 @@ class XCLIPModel(HFCLIPModel):
          # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          text_outputs = self.text_model(
@@ -63,7 +64,8 @@ class XCLIPModel(HFCLIPModel):
          # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          vision_outputs = self.vision_model(
@@ -178,15 +178,9 @@ class ParallelTransformerBlock(nn.Module):

  class CrossAttention(nn.Module):

-     def __init__(self,
-                  dim,
-                  *,
-                  context_dim=None,
-                  dim_head=64,
-                  heads=12,
-                  parallel_ff=False,
-                  ff_mult=4,
-                  norm_context=False):
+     def __init__(
+         self, dim, *, context_dim=None, dim_head=64, heads=12, parallel_ff=False, ff_mult=4, norm_context=False
+     ):
          super().__init__()
          self.heads = heads
          self.scale = dim_head**-0.5
@@ -205,8 +199,8 @@ class CrossAttention(nn.Module):
          ff_inner_dim = ff_mult * dim

          self.ff = nn.Sequential(
-             nn.Linear(dim, ff_inner_dim
-                       * 2, bias=False), SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False)) if parallel_ff else None
+             nn.Linear(dim, ff_inner_dim * 2, bias=False), SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False)
+         ) if parallel_ff else None

      def forward(self, x, context, mask):
          """
@@ -273,9 +267,11 @@ class Cross_model(nn.Module):
              self.layers.append(
                  nn.ModuleList([
                      Residual(
-                         CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
+                         CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)
+                     ),
                      Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
-                 ]))
+                 ])
+             )

      def forward(self, query_tokens, context_tokens, mask):

@@ -86,7 +86,8 @@ class CLIPScoreModel(ScoreModel):
          model_file_path = download_open_clip_model(self.arch, self.pretrained, self.cache_dir)

          self.model, _, self.preprocess = open_clip.create_model_and_transforms(
-             self.arch, pretrained=model_file_path, device=self.device)
+             self.arch, pretrained=model_file_path, device=self.device
+         )
          self.tokenizer = open_clip.get_tokenizer(self.arch)
          self.model.eval()

@@ -44,11 +44,12 @@ class HPSV2ScoreModel(ScoreModel):
              image_std=None,
              image_resize_mode='longest',
              aug_cfg={},
-             output_dict=True)
+             output_dict=True
+         )

          # update weight
          model_weight_path = download_file('AI-ModelScope/HPSv2', HPS_VERSION_MAP[self.model_name], self.cache_dir)
-         checkpoint = torch.load(model_weight_path, map_location=self.device)
+         checkpoint = torch.load(model_weight_path, map_location=self.device, weights_only=False)
          self.model.load_state_dict(checkpoint['state_dict'])
          self.tokenizer = open_clip.get_tokenizer(self.arch)
          self.model.eval()
@@ -29,7 +29,8 @@ class MPSModel(ScoreModel):

          config = download_file('AI-ModelScope/MPS', file_name='config.json', cache_dir=self.cache_dir)
          model_pretrained_path = download_file(
-             'AI-ModelScope/MPS', file_name='MPS_overall_state_dict.pt', cache_dir=self.cache_dir)  # modelscope model
+             'AI-ModelScope/MPS', file_name='MPS_overall_state_dict.pt', cache_dir=self.cache_dir
+         )  # modelscope model
          model_weight = torch.load(model_pretrained_path, weights_only=True, map_location='cpu')

          self.model = CLIPModel(config=CLIPConfig.from_json_file(config))
@@ -31,8 +31,8 @@ class PickScoreModel(ScoreModel):
          """Load the image(s), and return a tensor (no preprocessing!!) put on self.device
          """
          image = [self.image_loader(x) for x in image]
-         image = self.processor(
-             images=image, padding=True, truncation=True, max_length=77, return_tensors='pt').to(self.device)
+         image = self.processor(images=image, padding=True, truncation=True, max_length=77,
+                                return_tensors='pt').to(self.device)
          # image = torch.stack(image, dim=0).to(self.device)
          return image

@@ -66,7 +66,8 @@ class BLIP2ITMScoreModel(ScoreModel):
          query_att = torch.ones(query_token.size()[:-1], dtype=torch.long).to(query_token.device)

          text_input = self.model.tokenizer(
-             texts, padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device)
+             texts, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+         ).to(self.device)

          attention_mask_all = torch.cat([query_att, text_input.attention_mask], dim=1)
          output_itm = self.model.Qformer.bert(
@@ -42,10 +42,12 @@ class FGA_BLIP2ScoreModel(ScoreModel):
          # load model
          self.variant = FGA_BLIP2_MODELS[self.model_name]['variant']
          self.model, self.vis_processors, self.text_processors = load_model_and_preprocess(
-             'fga_blip2', self.variant, is_eval=True, device=self.device)
+             'fga_blip2', self.variant, is_eval=True, device=self.device
+         )
          # load pretrained weights
          model_weight_path = download_file(
-             'AI-ModelScope/FGA-BLIP2', file_name='fga_blip2.pth', cache_dir=self.cache_dir)
+             'AI-ModelScope/FGA-BLIP2', file_name='fga_blip2.pth', cache_dir=self.cache_dir
+         )
          self.model.load_checkpoint(model_weight_path)
          self.model.eval()

@@ -47,7 +47,8 @@ class MLP(nn.Module):
              nn.Dropout(0.1),
              nn.Linear(64, 16),
              #nn.ReLU(),
-             nn.Linear(16, 1))
+             nn.Linear(16, 1)
+         )

          # initial MLP param
          for name, param in self.layers.named_parameters():
@@ -100,7 +101,8 @@ class ImageReward(nn.Module):

          # text encode
          text_input = self.blip.tokenizer(
-             prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device)
+             prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+         ).to(self.device)

          # image encode
          if isinstance(image, Image.Image):
@@ -109,7 +111,8 @@
              pil_image = Image.open(image)
          else:
              raise TypeError(
-                 r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')
+                 r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.'
+             )

          image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
          image_embeds = self.blip.visual_encoder(image)
@@ -133,7 +136,8 @@
      def inference_rank(self, prompt, generations_list):

          text_input = self.blip.tokenizer(
-             prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device)
+             prompt, padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+         ).to(self.device)

          txt_set = []
          for generation in generations_list:
@@ -145,7 +149,8 @@
              pil_image = Image.open(generation)
          else:
              raise TypeError(
-                 r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')
+                 r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.'
+             )
          image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
          image_embeds = self.blip.visual_encoder(image)

@@ -30,7 +30,8 @@ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop
              num_heads=12,
              use_grad_checkpointing=use_grad_checkpointing,
              ckpt_layer=ckpt_layer,
-             drop_path_rate=0 or drop_path_rate)
+             drop_path_rate=0 or drop_path_rate
+         )
      elif vit == 'large':
          vision_width = 1024
          visual_encoder = VisionTransformer(
@@ -41,7 +42,8 @@ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop
              num_heads=16,
              use_grad_checkpointing=use_grad_checkpointing,
              ckpt_layer=ckpt_layer,
-             drop_path_rate=0.1 or drop_path_rate)
+             drop_path_rate=0.1 or drop_path_rate
+         )
      return visual_encoder, vision_width


@@ -53,7 +53,8 @@ class ImageRewardScoreModel(ScoreModel):
          images = self.load_images(images)
          for index in range(len(texts)):
              text_input = self.model.blip.tokenizer(
-                 texts[index], padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device)
+                 texts[index], padding='max_length', truncation=True, max_length=35, return_tensors='pt'
+             ).to(self.device)
              image_embeds = self.model.blip.visual_encoder(images[index].unsqueeze(0))
              image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
              text_output = self.model.blip.text_encoder(