evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (302)
  1. evalscope/__init__.py +4 -1
  2. evalscope/api/benchmark/__init__.py +3 -0
  3. evalscope/api/benchmark/adapters/__init__.py +5 -0
  4. evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
  5. evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
  6. evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
  7. evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
  8. evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
  9. evalscope/api/benchmark/benchmark.py +356 -0
  10. evalscope/api/benchmark/meta.py +121 -0
  11. evalscope/api/dataset/__init__.py +2 -0
  12. evalscope/api/dataset/dataset.py +349 -0
  13. evalscope/api/dataset/loader.py +262 -0
  14. evalscope/api/dataset/utils.py +143 -0
  15. evalscope/api/evaluator/__init__.py +3 -0
  16. evalscope/api/evaluator/cache.py +378 -0
  17. evalscope/api/evaluator/evaluator.py +56 -0
  18. evalscope/api/evaluator/state.py +275 -0
  19. evalscope/api/filter/__init__.py +1 -0
  20. evalscope/api/filter/filter.py +72 -0
  21. evalscope/api/messages/__init__.py +12 -0
  22. evalscope/api/messages/chat_message.py +243 -0
  23. evalscope/api/messages/content.py +102 -0
  24. evalscope/api/messages/utils.py +35 -0
  25. evalscope/api/metric/__init__.py +2 -0
  26. evalscope/api/metric/metric.py +55 -0
  27. evalscope/api/metric/scorer.py +113 -0
  28. evalscope/api/mixin/__init__.py +1 -0
  29. evalscope/api/mixin/llm_judge_mixin.py +168 -0
  30. evalscope/api/model/__init__.py +12 -0
  31. evalscope/api/model/generate_config.py +155 -0
  32. evalscope/api/model/model.py +386 -0
  33. evalscope/api/model/model_output.py +285 -0
  34. evalscope/api/registry.py +182 -0
  35. evalscope/api/tool/__init__.py +3 -0
  36. evalscope/api/tool/tool_call.py +101 -0
  37. evalscope/api/tool/tool_info.py +173 -0
  38. evalscope/api/tool/utils.py +64 -0
  39. evalscope/app/app.py +3 -0
  40. evalscope/app/ui/app_ui.py +2 -1
  41. evalscope/app/ui/multi_model.py +50 -25
  42. evalscope/app/ui/single_model.py +26 -14
  43. evalscope/app/utils/data_utils.py +43 -27
  44. evalscope/app/utils/env_utils.py +12 -0
  45. evalscope/app/utils/text_utils.py +14 -14
  46. evalscope/app/utils/visualization.py +9 -4
  47. evalscope/arguments.py +7 -10
  48. evalscope/backend/opencompass/api_meta_template.py +2 -1
  49. evalscope/backend/opencompass/backend_manager.py +6 -5
  50. evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
  51. evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
  52. evalscope/backend/rag_eval/ragas/task_template.py +2 -1
  53. evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
  54. evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
  55. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
  56. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
  57. evalscope/backend/rag_eval/utils/embedding.py +10 -1
  58. evalscope/backend/rag_eval/utils/llm.py +13 -12
  59. evalscope/benchmarks/__init__.py +0 -2
  60. evalscope/benchmarks/aime/aime24_adapter.py +38 -40
  61. evalscope/benchmarks/aime/aime25_adapter.py +34 -40
  62. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
  63. evalscope/benchmarks/arc/arc_adapter.py +34 -147
  64. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
  65. evalscope/benchmarks/arena_hard/utils.py +37 -1
  66. evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
  67. evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
  68. evalscope/benchmarks/bfcl/generation.py +222 -0
  69. evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
  70. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
  71. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
  72. evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
  73. evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
  74. evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
  75. evalscope/benchmarks/docmath/utils.py +4 -5
  76. evalscope/benchmarks/drop/drop_adapter.py +88 -40
  77. evalscope/benchmarks/frames/frames_adapter.py +136 -52
  78. evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
  79. evalscope/benchmarks/general_arena/utils.py +23 -27
  80. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
  81. evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
  82. evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
  83. evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
  84. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
  85. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
  86. evalscope/benchmarks/hle/hle_adapter.py +127 -93
  87. evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
  88. evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
  89. evalscope/benchmarks/ifeval/instructions.py +109 -64
  90. evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
  91. evalscope/benchmarks/ifeval/instructions_util.py +2 -3
  92. evalscope/benchmarks/ifeval/utils.py +6 -7
  93. evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
  94. evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
  95. evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
  96. evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
  97. evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
  98. evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
  99. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
  100. evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
  101. evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
  102. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
  103. evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
  104. evalscope/benchmarks/math_vista/__init__.py +0 -0
  105. evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
  106. evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
  107. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
  108. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
  109. evalscope/benchmarks/mmmu/__init__.py +0 -0
  110. evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
  111. evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
  112. evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
  113. evalscope/benchmarks/musr/musr_adapter.py +33 -64
  114. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
  115. evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
  116. evalscope/benchmarks/race/race_adapter.py +33 -119
  117. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
  118. evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
  119. evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
  120. evalscope/benchmarks/super_gpqa/utils.py +2 -1
  121. evalscope/benchmarks/tau_bench/generation.py +147 -0
  122. evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
  123. evalscope/benchmarks/text2image/__init__.py +0 -0
  124. evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
  125. evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
  126. evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
  127. evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
  128. evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
  129. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
  130. evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
  131. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
  132. evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
  133. evalscope/cli/cli.py +2 -0
  134. evalscope/cli/start_app.py +7 -1
  135. evalscope/cli/start_perf.py +7 -1
  136. evalscope/cli/start_server.py +6 -3
  137. evalscope/collections/__init__.py +2 -10
  138. evalscope/collections/sampler.py +10 -10
  139. evalscope/collections/schema.py +13 -11
  140. evalscope/config.py +157 -57
  141. evalscope/constants.py +37 -61
  142. evalscope/evaluator/__init__.py +1 -1
  143. evalscope/evaluator/evaluator.py +275 -419
  144. evalscope/filters/__init__.py +2 -0
  145. evalscope/filters/extraction.py +126 -0
  146. evalscope/filters/selection.py +57 -0
  147. evalscope/metrics/__init__.py +13 -13
  148. evalscope/metrics/llm_judge.py +47 -33
  149. evalscope/metrics/math_parser.py +27 -22
  150. evalscope/metrics/metric.py +307 -0
  151. evalscope/metrics/metrics.py +22 -18
  152. evalscope/metrics/t2v_metrics/__init__.py +0 -52
  153. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
  154. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
  155. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
  156. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
  157. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
  158. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
  159. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
  160. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
  161. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
  162. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
  163. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
  164. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
  165. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
  166. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
  167. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
  168. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
  169. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
  170. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
  171. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
  172. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
  173. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
  174. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
  175. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
  176. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
  177. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
  178. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
  179. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
  180. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
  181. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
  182. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
  183. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
  184. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
  185. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
  186. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
  187. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
  188. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
  189. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
  190. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
  191. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
  192. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
  193. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
  194. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
  195. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
  196. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
  197. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
  198. evalscope/models/__init__.py +6 -29
  199. evalscope/models/image_edit_model.py +125 -0
  200. evalscope/models/mockllm.py +65 -0
  201. evalscope/models/model_apis.py +67 -0
  202. evalscope/models/modelscope.py +455 -0
  203. evalscope/models/openai_compatible.py +126 -0
  204. evalscope/models/text2image_model.py +124 -0
  205. evalscope/models/utils/openai.py +701 -0
  206. evalscope/perf/benchmark.py +4 -1
  207. evalscope/perf/http_client.py +4 -2
  208. evalscope/perf/plugin/api/custom_api.py +5 -4
  209. evalscope/perf/plugin/api/openai_api.py +11 -9
  210. evalscope/perf/plugin/datasets/custom.py +2 -1
  211. evalscope/perf/plugin/datasets/flickr8k.py +1 -1
  212. evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
  213. evalscope/perf/plugin/datasets/line_by_line.py +2 -1
  214. evalscope/perf/plugin/datasets/longalpaca.py +2 -1
  215. evalscope/perf/plugin/datasets/openqa.py +4 -2
  216. evalscope/perf/utils/benchmark_util.py +15 -10
  217. evalscope/perf/utils/db_util.py +9 -6
  218. evalscope/perf/utils/local_server.py +11 -3
  219. evalscope/perf/utils/rich_display.py +16 -10
  220. evalscope/report/__init__.py +2 -3
  221. evalscope/report/combinator.py +18 -12
  222. evalscope/report/generator.py +51 -35
  223. evalscope/report/{utils.py → report.py} +8 -6
  224. evalscope/run.py +33 -47
  225. evalscope/summarizer.py +1 -1
  226. evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
  227. evalscope/utils/__init__.py +21 -2
  228. evalscope/utils/chat_service.py +3 -2
  229. evalscope/utils/deprecation_utils.py +12 -1
  230. evalscope/utils/function_utils.py +29 -0
  231. evalscope/utils/import_utils.py +23 -1
  232. evalscope/utils/io_utils.py +142 -6
  233. evalscope/utils/json_schema.py +208 -0
  234. evalscope/utils/logger.py +51 -12
  235. evalscope/utils/model_utils.py +11 -7
  236. evalscope/utils/multi_choices.py +288 -0
  237. evalscope/utils/url_utils.py +65 -0
  238. evalscope/version.py +2 -2
  239. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
  240. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
  241. tests/benchmark/test_eval.py +385 -0
  242. tests/benchmark/test_image_edit.py +65 -0
  243. tests/{aigc → benchmark}/test_t2i.py +22 -4
  244. tests/benchmark/test_vlm.py +80 -0
  245. tests/cli/test_all.py +85 -47
  246. tests/cli/test_collection.py +20 -8
  247. tests/cli/test_custom.py +22 -15
  248. tests/cli/test_reasoning.py +81 -0
  249. tests/common.py +73 -0
  250. tests/perf/test_perf.py +4 -2
  251. tests/rag/test_clip_benchmark.py +0 -2
  252. evalscope/benchmarks/aigc/t2i/base.py +0 -56
  253. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
  254. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
  255. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
  256. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
  257. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
  258. evalscope/benchmarks/arc/ai2_arc.py +0 -151
  259. evalscope/benchmarks/benchmark.py +0 -81
  260. evalscope/benchmarks/ceval/ceval_exam.py +0 -146
  261. evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
  262. evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
  263. evalscope/benchmarks/competition_math/competition_math.py +0 -79
  264. evalscope/benchmarks/data_adapter.py +0 -528
  265. evalscope/benchmarks/filters.py +0 -59
  266. evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
  267. evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
  268. evalscope/benchmarks/humaneval/humaneval.py +0 -79
  269. evalscope/benchmarks/mmlu/mmlu.py +0 -160
  270. evalscope/benchmarks/mmlu/samples.jsonl +0 -5
  271. evalscope/benchmarks/process_bench/critique_template.txt +0 -13
  272. evalscope/benchmarks/race/race.py +0 -104
  273. evalscope/benchmarks/race/samples.jsonl +0 -5
  274. evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
  275. evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
  276. evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
  277. evalscope/benchmarks/utils.py +0 -60
  278. evalscope/collections/evaluator.py +0 -375
  279. evalscope/metrics/completion_parsers.py +0 -227
  280. evalscope/metrics/named_metrics.py +0 -55
  281. evalscope/models/adapters/__init__.py +0 -14
  282. evalscope/models/adapters/base_adapter.py +0 -84
  283. evalscope/models/adapters/bfcl_adapter.py +0 -246
  284. evalscope/models/adapters/chat_adapter.py +0 -207
  285. evalscope/models/adapters/choice_adapter.py +0 -222
  286. evalscope/models/adapters/custom_adapter.py +0 -71
  287. evalscope/models/adapters/server_adapter.py +0 -236
  288. evalscope/models/adapters/t2i_adapter.py +0 -79
  289. evalscope/models/adapters/tau_bench_adapter.py +0 -189
  290. evalscope/models/custom/__init__.py +0 -4
  291. evalscope/models/custom/custom_model.py +0 -50
  292. evalscope/models/custom/dummy_model.py +0 -99
  293. evalscope/models/local_model.py +0 -128
  294. evalscope/models/register.py +0 -41
  295. tests/cli/test_run.py +0 -489
  296. /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
  297. /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
  298. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
  299. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
  300. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
  301. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
  302. /tests/{aigc → benchmark}/__init__.py +0 -0
@@ -58,8 +58,9 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
  def get_model(self):
  return self # for compatibility with LlavaMetaForCausalLM
 
- def prepare_inputs_labels_for_multimodal(self, input_ids, attention_mask, decoder_attention_mask, past_key_values,
- labels, images):
+ def prepare_inputs_labels_for_multimodal(
+ self, input_ids, attention_mask, decoder_attention_mask, past_key_values, labels, images
+ ):
  # The labels are now separated from the input_ids.
  vision_tower = self.get_vision_tower()
  if vision_tower is None or images is None or input_ids.shape[1] == 1:
@@ -103,10 +104,12 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
  _input_embeds_lengths = []
  for cur_new_embed in new_input_embeds:
  _input_embeds_lengths.append(cur_new_embed.shape[0])
- cur_new_embed = torch.cat((cur_new_embed,
- torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
- dtype=cur_new_embed.dtype,
- device=cur_new_embed.device)),
+ cur_new_embed = torch.cat((
+ cur_new_embed,
+ torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
+ dtype=cur_new_embed.dtype,
+ device=cur_new_embed.device)
+ ),
  dim=0)
  new_input_embeds_align.append(cur_new_embed)
  new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
@@ -123,7 +126,8 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
  dtype=attention_mask.dtype,
  device=attention_mask.device)
  cur_new_attention_mask = torch.cat(
- (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
+ (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0
+ )
  new_attention_mask.append(cur_new_attention_mask)
  attention_mask = torch.stack(new_attention_mask, dim=0)
  assert attention_mask.shape == new_input_embeds.shape[:2]
@@ -135,7 +139,8 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
  (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]),
  True,
  dtype=attention_mask.dtype,
- device=attention_mask.device)
+ device=attention_mask.device
+ )
  attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
  assert attention_mask.shape == new_input_embeds.shape[:2]
 
@@ -204,7 +209,8 @@ class CLIPT5ForConditionalGeneration(T5ForConditionalGeneration):
  ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
  if inputs_embeds is None:
@@ -44,12 +44,14 @@ class CLIPVisionTower(nn.Module):
  image_features = []
  for image in images:
  image_forward_out = self.vision_tower(
- image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
+ )
  image_feature = self.feature_select(image_forward_out).to(image.dtype)
  image_features.append(image_feature)
  else:
  image_forward_outs = self.vision_tower(
- images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
+ )
  image_features = self.feature_select(image_forward_outs).to(images.dtype)
 
  return image_features
@@ -98,7 +98,8 @@ class CLIPT5Model(VQAScoreModel):
  mmprojector_repo=mmprojector_repo,
  mmprojector_name=mmprojector_name,
  device=self.device,
- cache_dir=self.cache_dir)
+ cache_dir=self.cache_dir
+ )
 
  def load_images(self, image: List[str]) -> torch.Tensor:
  """Load the image(s), and return a tensor (after preprocessing) put on self.device
@@ -115,11 +116,13 @@ class CLIPT5Model(VQAScoreModel):
 
  @torch.no_grad()
  @torch.autocast(device_type='cuda', dtype=torch.bfloat16)
- def forward(self,
- images: List[str],
- texts: List[str],
- question_template: str = default_question_template,
- answer_template: str = default_answer_template) -> torch.Tensor:
+ def forward(
+ self,
+ images: List[str],
+ texts: List[str],
+ question_template: str = default_question_template,
+ answer_template: str = default_answer_template
+ ) -> torch.Tensor:
  """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
  """
  assert len(images) == len(texts), 'Number of images and texts must match'
@@ -139,7 +142,8 @@ class CLIPT5Model(VQAScoreModel):
  labels = [t5_tokenizer_image_token(ans, self.tokenizer, return_tensors='pt') for ans in answers]
 
  input_ids = torch.nn.utils.rnn.pad_sequence(
- input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
  labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
  input_ids = input_ids[:, :self.tokenizer.model_max_length]
  labels = labels[:, :self.tokenizer.model_max_length]
@@ -169,8 +173,8 @@ class CLIPT5Model(VQAScoreModel):
  lm_prob = torch.zeros(logits.shape[0])
  loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
  for k in range(lm_prob.shape[0]):
- lm_prob[k] = (
- -loss_fct(logits[k], labels[k])).exp() # exp to cancel the log and get raw prob between 0 and 1
+ lm_prob[k] = (-loss_fct(logits[k],
+ labels[k])).exp() # exp to cancel the log and get raw prob between 0 and 1
  return lm_prob
 
  @torch.no_grad()
@@ -191,7 +195,8 @@ class CLIPT5Model(VQAScoreModel):
 
  input_ids = [t5_tokenizer_image_token(qs, self.tokenizer, return_tensors='pt') for qs in questions]
  input_ids = torch.nn.utils.rnn.pad_sequence(
- input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
  input_ids = input_ids[:, :self.tokenizer.model_max_length]
 
  attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
@@ -1,6 +1,5 @@
  import base64
  import os
- import tiktoken
  import torch
  from openai import OpenAI
  from typing import List
@@ -42,6 +41,8 @@ class GPT4VModel(VQAScoreModel):
  def load_model(self):
  """Load the model, tokenizer, image transform
  """
+ import tiktoken
+
  self.tokenizer = tiktoken.encoding_for_model(self.model_name)
  self.client = OpenAI(api_key=self.openai_key)
  # self.candidate_answers = GPT4V_MODELS[self.model_name]['candidate_answers']
@@ -122,11 +123,13 @@ class GPT4VModel(VQAScoreModel):
  print(completion.choices[0].logprobs.content[0].top_logprobs)
  return torch.Tensor([0.0])
 
- def forward(self,
- images: List[str],
- texts: List[str],
- question_template: str = default_question_template,
- answer_template: str = default_answer_template) -> torch.Tensor:
+ def forward(
+ self,
+ images: List[str],
+ texts: List[str],
+ question_template: str = default_question_template,
+ answer_template: str = default_answer_template
+ ) -> torch.Tensor:
  """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
  """
  assert len(images) == len(texts), 'Number of images and texts must match'
@@ -227,8 +227,8 @@ class ConfigValidator:
  """
  for k, v in config.items():
  assert (
- k
- in self.arguments), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
+ k in self.arguments
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
 
  if self.arguments[k].type is not None:
  try:
@@ -17,6 +17,8 @@ def getAttMap(img, attMap, blur=True, overlap=True):
  attMapV = cmap(attMap)
  attMapV = np.delete(attMapV, 3, 2)
  if overlap:
- attMap = (1 * (1 - attMap**0.7).reshape(attMap.shape + (1, )) * img +
- (attMap**0.7).reshape(attMap.shape + (1, )) * attMapV)
+ attMap = (
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1, )) * img +
+ (attMap**0.7).reshape(attMap.shape + (1, )) * attMapV
+ )
  return attMap
@@ -155,7 +155,8 @@ class MetricLogger(object):
  time=str(iter_time),
  data=str(data_time),
  memory=torch.cuda.max_memory_allocated() / MB,
- ))
+ )
+ )
  else:
  print(
  log_msg.format(
@@ -165,7 +166,8 @@ class MetricLogger(object):
  meters=str(self),
  time=str(iter_time),
  data=str(data_time),
- ))
+ )
+ )
  i += 1
  end = time.time()
  total_time = time.time() - start_time
@@ -13,15 +13,9 @@ from . import registry
  @registry.register_lr_scheduler('linear_warmup_step_lr')
  class LinearWarmupStepLRScheduler:
 
- def __init__(self,
- optimizer,
- max_epoch,
- min_lr,
- init_lr,
- decay_rate=1,
- warmup_start_lr=-1,
- warmup_steps=0,
- **kwargs):
+ def __init__(
+ self, optimizer, max_epoch, min_lr, init_lr, decay_rate=1, warmup_start_lr=-1, warmup_steps=0, **kwargs
+ ):
  self.optimizer = optimizer
 
  self.max_epoch = max_epoch
@@ -96,8 +96,9 @@ class Registry:
 
  assert issubclass(model_cls, BaseModel), 'All models must inherit BaseModel class'
  if name in cls.mapping['model_name_mapping']:
- raise KeyError("Name '{}' already registered for {}.".format(name,
- cls.mapping['model_name_mapping'][name]))
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(name, cls.mapping['model_name_mapping'][name])
+ )
  cls.mapping['model_name_mapping'][name] = model_cls
  return model_cls
 
@@ -120,8 +121,9 @@ class Registry:
 
  assert issubclass(processor_cls, BaseProcessor), 'All processors must inherit BaseProcessor class'
  if name in cls.mapping['processor_name_mapping']:
- raise KeyError("Name '{}' already registered for {}.".format(
- name, cls.mapping['processor_name_mapping'][name]))
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(name, cls.mapping['processor_name_mapping'][name])
+ )
  cls.mapping['processor_name_mapping'][name] = processor_cls
  return processor_cls
 
@@ -141,8 +143,9 @@ class Registry:
 
  def wrap(lr_sched_cls):
  if name in cls.mapping['lr_scheduler_name_mapping']:
- raise KeyError("Name '{}' already registered for {}.".format(
- name, cls.mapping['lr_scheduler_name_mapping'][name]))
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(name, cls.mapping['lr_scheduler_name_mapping'][name])
+ )
  cls.mapping['lr_scheduler_name_mapping'][name] = lr_sched_cls
  return lr_sched_cls
 
@@ -162,8 +165,9 @@ class Registry:
 
  def wrap(runner_cls):
  if name in cls.mapping['runner_name_mapping']:
- raise KeyError("Name '{}' already registered for {}.".format(name,
- cls.mapping['runner_name_mapping'][name]))
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(name, cls.mapping['runner_name_mapping'][name])
+ )
  cls.mapping['runner_name_mapping'][name] = runner_cls
  return runner_cls
 
@@ -285,8 +289,10 @@ class Registry:
  break
 
  if ('writer' in cls.mapping['state'] and value == default and no_warning is False):
- cls.mapping['state']['writer'].warning('Key {} is not present in registry, returning default value '
- 'of {}'.format(original_name, default))
+ cls.mapping['state']['writer'].warning(
+ 'Key {} is not present in registry, returning default value '
+ 'of {}'.format(original_name, default)
+ )
  return value
 
  @classmethod
@@ -178,8 +178,9 @@ class VQA:
  for ann in anns:
  quesId = ann['question_id']
  if res.dataset['task_type'] == 'Multiple Choice':
- assert (ann['answer']
- in self.qqa[quesId]['multiple_choices']), 'predicted answer is not one of the multiple choices'
+ assert (
+ ann['answer'] in self.qqa[quesId]['multiple_choices']
+ ), 'predicted answer is not one of the multiple choices'
  qaAnn = self.qa[quesId]
  ann['image_id'] = qaAnn['image_id']
  ann['question_type'] = qaAnn['question_type']
@@ -10,6 +10,7 @@
  __author__ = 'aagrawal'
 
  import re
+
  # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
  # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
  import sys
@@ -312,7 +313,8 @@ class VQAEval:
  progress = 1
  status = 'Done...\r\n'
  block = int(round(barLength * progress))
- text = '\rFinshed Percent: [{0}] {1}% {2}'.format('#' * block + '-' * (barLength - block), int(progress * 100),
- status)
+ text = '\rFinshed Percent: [{0}] {1}% {2}'.format(
+ '#' * block + '-' * (barLength - block), int(progress * 100), status
+ )
  sys.stdout.write(text)
  sys.stdout.flush()
@@ -166,10 +166,12 @@ def load_model_and_preprocess(name, model_type, is_eval=False, device='cpu'):
  vis_processors, txt_processors = load_preprocess(preprocess_cfg)
  else:
  vis_processors, txt_processors = None, None
- logging.info(f"""No default preprocess for model {name} ({model_type}).
+ logging.info(
+ f"""No default preprocess for model {name} ({model_type}).
  This can happen if the model is not finetuned on downstream datasets,
  or it is not intended for direct use without finetuning.
- """)
+ """
+ )
 
  if device == 'cpu' or device == torch.device('cpu'):
  model = model.float()
@@ -195,8 +197,10 @@ class ModelZoo:
  }
 
  def __str__(self) -> str:
- return ('=' * 50 + '\n' + f"{'Architectures':<30} {'Types'}\n" + '=' * 50 + '\n'
- + '\n'.join([f"{name:<30} {', '.join(types)}" for name, types in self.model_zoo.items()]))
+ return (
+ '=' * 50 + '\n' + f"{'Architectures':<30} {'Types'}\n" + '=' * 50 + '\n'
+ + '\n'.join([f"{name:<30} {', '.join(types)}" for name, types in self.model_zoo.items()])
+ )
 
  def __iter__(self):
  return iter(self.model_zoo.items())
@@ -19,13 +19,23 @@ from torch import Tensor, device, dtype, nn
  from torch.nn import CrossEntropyLoss
  from transformers.activations import ACT2FN
  from transformers.file_utils import ModelOutput
- from transformers.modeling_outputs import (BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPoolingAndCrossAttentions,
- CausalLMOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput,
- NextSentencePredictorOutput, QuestionAnsweringModelOutput,
- SequenceClassifierOutput, TokenClassifierOutput)
- from transformers.modeling_utils import (PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices,
- prune_linear_layer)
+ from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+ )
  from transformers.models.bert.configuration_bert import BertConfig
  from transformers.utils import logging
  from typing import Any, Dict, Optional, Tuple
@@ -89,8 +99,10 @@ class BertSelfAttention(nn.Module):
  super().__init__()
  self.config = config
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, 'embedding_size'):
- raise ValueError('The hidden size (%d) is not a multiple of the number of attention '
- 'heads (%d)' % (config.hidden_size, config.num_attention_heads))
+ raise ValueError(
+ 'The hidden size (%d) is not a multiple of the number of attention '
+ 'heads (%d)' % (config.hidden_size, config.num_attention_heads)
+ )
 
  self.num_attention_heads = config.num_attention_heads
  self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -366,8 +378,9 @@ class BertLayer(nn.Module):
  query_attention_output = attention_output[:, :query_length, :]
 
  if self.has_cross_attention:
- assert (encoder_hidden_states
- is not None), 'encoder_hidden_states must be given for cross-attention layers'
+ assert (
+ encoder_hidden_states is not None
+ ), 'encoder_hidden_states must be given for cross-attention layers'
  cross_attention_outputs = self.crossattention(
  query_attention_output,
  attention_mask,
@@ -377,8 +390,9 @@ class BertLayer(nn.Module):
  output_attentions=output_attentions,
  )
  query_attention_output = cross_attention_outputs[0]
- outputs = (outputs + cross_attention_outputs[1:-1]
- ) # add cross attentions if we output attention weights
+ outputs = (
+ outputs + cross_attention_outputs[1:-1]
+ ) # add cross attentions if we output attention weights
 
  layer_output = apply_chunking_to_forward(
  self.feed_forward_chunk_query,
@@ -457,7 +471,8 @@ class BertEncoder(nn.Module):
 
  if use_cache:
  logger.warn(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+ )
  use_cache = False
 
  def create_custom_forward(module):
@@ -498,13 +513,15 @@ class BertEncoder(nn.Module):
  all_hidden_states = all_hidden_states + (hidden_states, )
 
  if not return_dict:
- return tuple(v for v in [
- hidden_states,
- next_decoder_cache,
- all_hidden_states,
- all_self_attentions,
- all_cross_attentions,
- ] if v is not None)
+ return tuple(
+ v for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ] if v is not None
+ )
  return BaseModelOutputWithPastAndCrossAttentions(
  last_hidden_state=hidden_states,
  past_key_values=next_decoder_cache,
@@ -708,8 +725,11 @@ class BertModel(BertPreTrainedModel):
  else:
  extended_attention_mask = attention_mask[:, None, None, :]
  else:
- raise ValueError('Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
- input_shape, attention_mask.shape))
+ raise ValueError(
+ 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
+ input_shape, attention_mask.shape
+ )
+ )
 
  # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  # masked positions, this operation will create a tensor which is 0.0 for
@@ -756,7 +776,8 @@ class BertModel(BertPreTrainedModel):
  """
  output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
  output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
  return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
 
  # use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -766,7 +787,8 @@ class BertModel(BertPreTrainedModel):
 
  # past_key_values_length
  past_key_values_length = (
- past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0)
+ past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
+ )
 
  query_length = query_embeds.shape[1] if query_embeds is not None else 0
 
@@ -54,16 +54,18 @@ class Blip2Qformer(Blip2Base):
 
  self.tokenizer = self.init_tokenizer()
 
- self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size, drop_path_rate,
- use_grad_checkpoint, vit_precision)
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+ )
  if freeze_vit:
  for name, param in self.visual_encoder.named_parameters():
  param.requires_grad = False
  self.visual_encoder = self.visual_encoder.eval()
  self.visual_encoder.train = disabled_train
  logging.info('freeze vision encoder')
- self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, self.visual_encoder.num_features,
- cross_attention_freq)
+ self.Qformer, self.query_tokens = self.init_Qformer(
+ num_query_token, self.visual_encoder.num_features, cross_attention_freq
+ )
  self.Qformer.resize_token_embeddings(len(self.tokenizer))
  state_dict = self.Qformer.state_dict()
  for name, param in self.Qformer.named_parameters():
@@ -135,8 +137,10 @@ class Blip2Qformer(Blip2Base):
  bs = image.size(0)
  targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(image.device)
 
- loss_itc = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
- + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2
+ loss_itc = (
+ F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
+ + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
+ ) / 2
 
  ###============== Image-text Matching ===================###
  text_input_ids_world = concat_all_gather(text_tokens.input_ids)
@@ -274,7 +278,8 @@ class Blip2Qformer(Blip2Base):
  top_p=top_p,
  eos_token_id=self.tokenizer.sep_token_id,
  pad_token_id=self.tokenizer.pad_token_id,
- **model_kwargs)
+ **model_kwargs
+ )
  captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
  return captions
 
@@ -66,8 +66,9 @@ class Blip2T5(Blip2Base):
 
  self.tokenizer = self.init_tokenizer()
 
- self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size, drop_path_rate,
- use_grad_checkpoint, vit_precision)
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+ )
  if freeze_vit:
  for name, param in self.visual_encoder.named_parameters():
  param.requires_grad = False
@@ -136,8 +137,9 @@ class Blip2T5(Blip2Base):
 
  encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
 
- targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id,
- -100)
+ targets = output_tokens.input_ids.masked_fill(
+ output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100
+ )
 
  inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
  inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
@@ -234,17 +236,19 @@ class Blip2T5(Blip2Base):
 
  return output_text
 
- def predict_answers(self,
- samples,
- num_beams=5,
- inference_method='generate',
- max_len=10,
- min_len=1,
- num_ans_candidates=128,
- answer_list=None,
- prompt='',
- length_penalty=-1,
- **kwargs):
+ def predict_answers(
+ self,
+ samples,
+ num_beams=5,
+ inference_method='generate',
+ max_len=10,
+ min_len=1,
+ num_ans_candidates=128,
+ answer_list=None,
+ prompt='',
+ length_penalty=-1,
+ **kwargs
+ ):
  image = samples['image']
  with self.maybe_autocast():
  image_embeds = self.ln_vision(self.visual_encoder(image))
@@ -318,13 +322,15 @@ class Blip2T5(Blip2Base):
 
  self._lemmatizer = spacy.load('en_core_web_sm')
  except ImportError:
- logging.error("""
+ logging.error(
+ """
  Please install spacy and en_core_web_sm model to apply lemmatization.
  python -m spacy download en_core_web_sm
  OR
  import spacy.cli
  spacy.cli.download("en_core_web_sm")
- """)
+ """
+ )
  exit(1)
 
  return self._lemmatizer