evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (302)
  1. evalscope/__init__.py +4 -1
  2. evalscope/api/benchmark/__init__.py +3 -0
  3. evalscope/api/benchmark/adapters/__init__.py +5 -0
  4. evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
  5. evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
  6. evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
  7. evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
  8. evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
  9. evalscope/api/benchmark/benchmark.py +356 -0
  10. evalscope/api/benchmark/meta.py +121 -0
  11. evalscope/api/dataset/__init__.py +2 -0
  12. evalscope/api/dataset/dataset.py +349 -0
  13. evalscope/api/dataset/loader.py +262 -0
  14. evalscope/api/dataset/utils.py +143 -0
  15. evalscope/api/evaluator/__init__.py +3 -0
  16. evalscope/api/evaluator/cache.py +378 -0
  17. evalscope/api/evaluator/evaluator.py +56 -0
  18. evalscope/api/evaluator/state.py +275 -0
  19. evalscope/api/filter/__init__.py +1 -0
  20. evalscope/api/filter/filter.py +72 -0
  21. evalscope/api/messages/__init__.py +12 -0
  22. evalscope/api/messages/chat_message.py +243 -0
  23. evalscope/api/messages/content.py +102 -0
  24. evalscope/api/messages/utils.py +35 -0
  25. evalscope/api/metric/__init__.py +2 -0
  26. evalscope/api/metric/metric.py +55 -0
  27. evalscope/api/metric/scorer.py +113 -0
  28. evalscope/api/mixin/__init__.py +1 -0
  29. evalscope/api/mixin/llm_judge_mixin.py +168 -0
  30. evalscope/api/model/__init__.py +12 -0
  31. evalscope/api/model/generate_config.py +155 -0
  32. evalscope/api/model/model.py +386 -0
  33. evalscope/api/model/model_output.py +285 -0
  34. evalscope/api/registry.py +182 -0
  35. evalscope/api/tool/__init__.py +3 -0
  36. evalscope/api/tool/tool_call.py +101 -0
  37. evalscope/api/tool/tool_info.py +173 -0
  38. evalscope/api/tool/utils.py +64 -0
  39. evalscope/app/app.py +3 -0
  40. evalscope/app/ui/app_ui.py +2 -1
  41. evalscope/app/ui/multi_model.py +50 -25
  42. evalscope/app/ui/single_model.py +26 -14
  43. evalscope/app/utils/data_utils.py +43 -27
  44. evalscope/app/utils/env_utils.py +12 -0
  45. evalscope/app/utils/text_utils.py +14 -14
  46. evalscope/app/utils/visualization.py +9 -4
  47. evalscope/arguments.py +7 -10
  48. evalscope/backend/opencompass/api_meta_template.py +2 -1
  49. evalscope/backend/opencompass/backend_manager.py +6 -5
  50. evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
  51. evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
  52. evalscope/backend/rag_eval/ragas/task_template.py +2 -1
  53. evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
  54. evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
  55. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
  56. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
  57. evalscope/backend/rag_eval/utils/embedding.py +10 -1
  58. evalscope/backend/rag_eval/utils/llm.py +13 -12
  59. evalscope/benchmarks/__init__.py +0 -2
  60. evalscope/benchmarks/aime/aime24_adapter.py +38 -40
  61. evalscope/benchmarks/aime/aime25_adapter.py +34 -40
  62. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
  63. evalscope/benchmarks/arc/arc_adapter.py +34 -147
  64. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
  65. evalscope/benchmarks/arena_hard/utils.py +37 -1
  66. evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
  67. evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
  68. evalscope/benchmarks/bfcl/generation.py +222 -0
  69. evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
  70. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
  71. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
  72. evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
  73. evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
  74. evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
  75. evalscope/benchmarks/docmath/utils.py +4 -5
  76. evalscope/benchmarks/drop/drop_adapter.py +88 -40
  77. evalscope/benchmarks/frames/frames_adapter.py +136 -52
  78. evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
  79. evalscope/benchmarks/general_arena/utils.py +23 -27
  80. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
  81. evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
  82. evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
  83. evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
  84. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
  85. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
  86. evalscope/benchmarks/hle/hle_adapter.py +127 -93
  87. evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
  88. evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
  89. evalscope/benchmarks/ifeval/instructions.py +109 -64
  90. evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
  91. evalscope/benchmarks/ifeval/instructions_util.py +2 -3
  92. evalscope/benchmarks/ifeval/utils.py +6 -7
  93. evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
  94. evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
  95. evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
  96. evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
  97. evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
  98. evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
  99. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
  100. evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
  101. evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
  102. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
  103. evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
  104. evalscope/benchmarks/math_vista/__init__.py +0 -0
  105. evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
  106. evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
  107. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
  108. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
  109. evalscope/benchmarks/mmmu/__init__.py +0 -0
  110. evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
  111. evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
  112. evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
  113. evalscope/benchmarks/musr/musr_adapter.py +33 -64
  114. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
  115. evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
  116. evalscope/benchmarks/race/race_adapter.py +33 -119
  117. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
  118. evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
  119. evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
  120. evalscope/benchmarks/super_gpqa/utils.py +2 -1
  121. evalscope/benchmarks/tau_bench/generation.py +147 -0
  122. evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
  123. evalscope/benchmarks/text2image/__init__.py +0 -0
  124. evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
  125. evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
  126. evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
  127. evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
  128. evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
  129. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
  130. evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
  131. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
  132. evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
  133. evalscope/cli/cli.py +2 -0
  134. evalscope/cli/start_app.py +7 -1
  135. evalscope/cli/start_perf.py +7 -1
  136. evalscope/cli/start_server.py +6 -3
  137. evalscope/collections/__init__.py +2 -10
  138. evalscope/collections/sampler.py +10 -10
  139. evalscope/collections/schema.py +13 -11
  140. evalscope/config.py +157 -57
  141. evalscope/constants.py +37 -61
  142. evalscope/evaluator/__init__.py +1 -1
  143. evalscope/evaluator/evaluator.py +275 -419
  144. evalscope/filters/__init__.py +2 -0
  145. evalscope/filters/extraction.py +126 -0
  146. evalscope/filters/selection.py +57 -0
  147. evalscope/metrics/__init__.py +13 -13
  148. evalscope/metrics/llm_judge.py +47 -33
  149. evalscope/metrics/math_parser.py +27 -22
  150. evalscope/metrics/metric.py +307 -0
  151. evalscope/metrics/metrics.py +22 -18
  152. evalscope/metrics/t2v_metrics/__init__.py +0 -52
  153. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
  154. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
  155. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
  156. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
  157. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
  158. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
  159. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
  160. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
  161. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
  162. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
  163. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
  164. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
  165. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
  166. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
  167. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
  168. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
  169. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
  170. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
  171. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
  172. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
  173. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
  174. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
  175. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
  176. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
  177. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
  178. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
  179. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
  180. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
  181. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
  182. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
  183. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
  184. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
  185. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
  186. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
  187. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
  188. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
  189. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
  190. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
  191. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
  192. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
  193. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
  194. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
  195. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
  196. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
  197. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
  198. evalscope/models/__init__.py +6 -29
  199. evalscope/models/image_edit_model.py +125 -0
  200. evalscope/models/mockllm.py +65 -0
  201. evalscope/models/model_apis.py +67 -0
  202. evalscope/models/modelscope.py +455 -0
  203. evalscope/models/openai_compatible.py +126 -0
  204. evalscope/models/text2image_model.py +124 -0
  205. evalscope/models/utils/openai.py +701 -0
  206. evalscope/perf/benchmark.py +4 -1
  207. evalscope/perf/http_client.py +4 -2
  208. evalscope/perf/plugin/api/custom_api.py +5 -4
  209. evalscope/perf/plugin/api/openai_api.py +11 -9
  210. evalscope/perf/plugin/datasets/custom.py +2 -1
  211. evalscope/perf/plugin/datasets/flickr8k.py +1 -1
  212. evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
  213. evalscope/perf/plugin/datasets/line_by_line.py +2 -1
  214. evalscope/perf/plugin/datasets/longalpaca.py +2 -1
  215. evalscope/perf/plugin/datasets/openqa.py +4 -2
  216. evalscope/perf/utils/benchmark_util.py +15 -10
  217. evalscope/perf/utils/db_util.py +9 -6
  218. evalscope/perf/utils/local_server.py +11 -3
  219. evalscope/perf/utils/rich_display.py +16 -10
  220. evalscope/report/__init__.py +2 -3
  221. evalscope/report/combinator.py +18 -12
  222. evalscope/report/generator.py +51 -35
  223. evalscope/report/{utils.py → report.py} +8 -6
  224. evalscope/run.py +33 -47
  225. evalscope/summarizer.py +1 -1
  226. evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
  227. evalscope/utils/__init__.py +21 -2
  228. evalscope/utils/chat_service.py +3 -2
  229. evalscope/utils/deprecation_utils.py +12 -1
  230. evalscope/utils/function_utils.py +29 -0
  231. evalscope/utils/import_utils.py +23 -1
  232. evalscope/utils/io_utils.py +142 -6
  233. evalscope/utils/json_schema.py +208 -0
  234. evalscope/utils/logger.py +51 -12
  235. evalscope/utils/model_utils.py +11 -7
  236. evalscope/utils/multi_choices.py +288 -0
  237. evalscope/utils/url_utils.py +65 -0
  238. evalscope/version.py +2 -2
  239. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
  240. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
  241. tests/benchmark/test_eval.py +385 -0
  242. tests/benchmark/test_image_edit.py +65 -0
  243. tests/{aigc → benchmark}/test_t2i.py +22 -4
  244. tests/benchmark/test_vlm.py +80 -0
  245. tests/cli/test_all.py +85 -47
  246. tests/cli/test_collection.py +20 -8
  247. tests/cli/test_custom.py +22 -15
  248. tests/cli/test_reasoning.py +81 -0
  249. tests/common.py +73 -0
  250. tests/perf/test_perf.py +4 -2
  251. tests/rag/test_clip_benchmark.py +0 -2
  252. evalscope/benchmarks/aigc/t2i/base.py +0 -56
  253. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
  254. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
  255. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
  256. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
  257. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
  258. evalscope/benchmarks/arc/ai2_arc.py +0 -151
  259. evalscope/benchmarks/benchmark.py +0 -81
  260. evalscope/benchmarks/ceval/ceval_exam.py +0 -146
  261. evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
  262. evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
  263. evalscope/benchmarks/competition_math/competition_math.py +0 -79
  264. evalscope/benchmarks/data_adapter.py +0 -528
  265. evalscope/benchmarks/filters.py +0 -59
  266. evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
  267. evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
  268. evalscope/benchmarks/humaneval/humaneval.py +0 -79
  269. evalscope/benchmarks/mmlu/mmlu.py +0 -160
  270. evalscope/benchmarks/mmlu/samples.jsonl +0 -5
  271. evalscope/benchmarks/process_bench/critique_template.txt +0 -13
  272. evalscope/benchmarks/race/race.py +0 -104
  273. evalscope/benchmarks/race/samples.jsonl +0 -5
  274. evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
  275. evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
  276. evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
  277. evalscope/benchmarks/utils.py +0 -60
  278. evalscope/collections/evaluator.py +0 -375
  279. evalscope/metrics/completion_parsers.py +0 -227
  280. evalscope/metrics/named_metrics.py +0 -55
  281. evalscope/models/adapters/__init__.py +0 -14
  282. evalscope/models/adapters/base_adapter.py +0 -84
  283. evalscope/models/adapters/bfcl_adapter.py +0 -246
  284. evalscope/models/adapters/chat_adapter.py +0 -207
  285. evalscope/models/adapters/choice_adapter.py +0 -222
  286. evalscope/models/adapters/custom_adapter.py +0 -71
  287. evalscope/models/adapters/server_adapter.py +0 -236
  288. evalscope/models/adapters/t2i_adapter.py +0 -79
  289. evalscope/models/adapters/tau_bench_adapter.py +0 -189
  290. evalscope/models/custom/__init__.py +0 -4
  291. evalscope/models/custom/custom_model.py +0 -50
  292. evalscope/models/custom/dummy_model.py +0 -99
  293. evalscope/models/local_model.py +0 -128
  294. evalscope/models/register.py +0 -41
  295. tests/cli/test_run.py +0 -489
  296. /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
  297. /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
  298. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
  299. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
  300. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
  301. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
  302. /tests/{aigc → benchmark}/__init__.py +0 -0
@@ -61,8 +61,9 @@ class Blip2T5Instruct(Blip2Base):
 
  self.tokenizer = self.init_tokenizer(truncation_side='left')
 
- self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size, drop_path_rate,
- use_grad_checkpoint, vit_precision)
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+ )
  if freeze_vit:
  for name, param in self.visual_encoder.named_parameters():
  param.requires_grad = False
@@ -171,8 +172,9 @@ class Blip2T5Instruct(Blip2Base):
 
  encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
 
- targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id,
- -100)
+ targets = output_tokens.input_ids.masked_fill(
+ output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100
+ )
 
  inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
  inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
@@ -196,7 +198,8 @@ class Blip2T5Instruct(Blip2Base):
  this_n_fs = random.choices(
  list(range(self.num_few_shot_examples + 1)),
  weights=[1 - self.few_shot_prob]
- + [self.few_shot_prob / self.num_few_shot_examples] * self.num_few_shot_examples)[0]
+ + [self.few_shot_prob / self.num_few_shot_examples] * self.num_few_shot_examples
+ )[0]
 
  if this_n_fs == 0:
  return None, None
@@ -263,7 +266,8 @@ class Blip2T5Instruct(Blip2Base):
  encoder_atts = encoder_atts.reshape(encoder_atts.size(0) // this_n_fs, encoder_atts.size(1) * this_n_fs)
  inputs_embeds = inputs_embeds.reshape(
  inputs_embeds.size(0) // this_n_fs,
- inputs_embeds.size(1) * this_n_fs, inputs_embeds.size(2))
+ inputs_embeds.size(1) * this_n_fs, inputs_embeds.size(2)
+ )
 
  return inputs_embeds, encoder_atts
 
@@ -397,17 +401,19 @@ class Blip2T5Instruct(Blip2Base):
 
  return output_text
 
- def predict_answers(self,
- samples,
- num_beams=5,
- inference_method='generate',
- max_len=10,
- min_len=1,
- num_ans_candidates=128,
- answer_list=None,
- prompt='',
- length_penalty=-1,
- **kwargs):
+ def predict_answers(
+ self,
+ samples,
+ num_beams=5,
+ inference_method='generate',
+ max_len=10,
+ min_len=1,
+ num_ans_candidates=128,
+ answer_list=None,
+ prompt='',
+ length_penalty=-1,
+ **kwargs
+ ):
  if isinstance(samples['text_input'], str):
  samples['text_input'] = [samples['text_input']]
 
@@ -434,7 +440,8 @@ class Blip2T5Instruct(Blip2Base):
  samples['prompt'] = text_input
 
  output_text = self.generate(
- samples, num_beams=num_beams, max_length=max_len, min_length=min_len, length_penalty=length_penalty)
+ samples, num_beams=num_beams, max_length=max_len, min_length=min_len, length_penalty=length_penalty
+ )
 
  if self._apply_lemmatizer or ('apply_lemmatizer' in samples.keys() and samples['apply_lemmatizer']):
  output_text = self._lemmatize(output_text)
@@ -530,8 +537,8 @@ class Blip2T5Instruct(Blip2Base):
  query_tokens = self.query_tokens.expand(bs, -1, -1)
  if self.qformer_text_input:
  text_Qformer = self.tokenizer(
- prompt, padding='longest', truncation=True, max_length=self.max_txt_len,
- return_tensors='pt').to(image.device)
+ prompt, padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors='pt'
+ ).to(image.device)
  query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
  Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
 
@@ -625,7 +632,8 @@ class Blip2T5Instruct(Blip2Base):
  this_output_tokens_atts = output_tokens.attention_mask[start_i:end_i].repeat(bs, 1)
 
  this_targets = this_output_tokens_ids.masked_fill(
- this_output_tokens_ids == self.t5_tokenizer.pad_token_id, -100)
+ this_output_tokens_ids == self.t5_tokenizer.pad_token_id, -100
+ )
 
  outputs = self.t5_model(
  encoder_outputs=this_encoder_outputs,
@@ -692,13 +700,15 @@ class Blip2T5Instruct(Blip2Base):
 
  self._lemmatizer = spacy.load('en_core_web_sm')
  except ImportError:
- logging.error("""
+ logging.error(
+ """
  Please install spacy and en_core_web_sm model to apply lemmatization.
  python -m spacy download en_core_web_sm
  OR
  import spacy.cli
  spacy.cli.download("en_core_web_sm")
- """)
+ """
+ )
  exit(1)
 
  return self._lemmatizer
@@ -32,7 +32,8 @@ class MLP(nn.Module):
  # nn.Dropout(0.1),
  nn.Linear(64, 16),
  nn.ReLU(),
- nn.Linear(16, 1))
+ nn.Linear(16, 1)
+ )
 
  # initial MLP param
  for name, param in self.layers.named_parameters():
@@ -23,12 +23,19 @@ import torch.utils.checkpoint
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  from transformers.activations import ACT2FN
- from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast,
- SequenceClassifierOutputWithPast)
+ from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ )
  from transformers.modeling_utils import PreTrainedModel
  from transformers.models.llama.configuration_llama import LlamaConfig
- from transformers.utils import (add_start_docstrings, add_start_docstrings_to_model_forward, logging,
- replace_return_docstrings)
+ from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+ )
  from typing import List, Optional, Tuple, Union
 
  logger = logging.get_logger(__name__)
@@ -37,10 +44,9 @@ _CONFIG_FOR_DOC = 'LlamaConfig'
 
 
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
- def _make_causal_mask(input_ids_shape: torch.Size,
- dtype: torch.dtype,
- device: torch.device,
- past_key_values_length: int = 0):
+ def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+ ):
  """
  Make causal mask used for bi-directional self-attention.
  """
@@ -171,8 +177,10 @@ class LlamaAttention(nn.Module):
  self.max_position_embeddings = config.max_position_embeddings
 
  if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
- f' and `num_heads`: {self.num_heads}).')
+ raise ValueError(
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
+ f' and `num_heads`: {self.num_heads}).'
+ )
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
@@ -214,13 +222,16 @@ class LlamaAttention(nn.Module):
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
- raise ValueError(f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
- f' {attn_weights.size()}')
+ raise ValueError(
+ f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
+ f' {attn_weights.size()}'
+ )
 
  if attention_mask is not None:
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  raise ValueError(
- f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}')
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
+ )
  attn_weights = attn_weights + attention_mask
  attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
@@ -229,8 +240,10 @@ class LlamaAttention(nn.Module):
  attn_output = torch.matmul(attn_weights, value_states)
 
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
- f' {attn_output.size()}')
+ raise ValueError(
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+ f' {attn_output.size()}'
+ )
 
  attn_output = attn_output.transpose(1, 2)
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -466,10 +479,11 @@ class LlamaModel(LlamaPreTrainedModel):
 
  if attention_mask is not None:
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(
- attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
+ tgt_len=input_shape[-1]).to(inputs_embeds.device)
  combined_attention_mask = (
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask)
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
 
  return combined_attention_mask
 
@@ -488,7 +502,8 @@ class LlamaModel(LlamaPreTrainedModel):
  ) -> Union[Tuple, BaseModelOutputWithPast]:
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -513,7 +528,8 @@ class LlamaModel(LlamaPreTrainedModel):
  if position_ids is None:
  device = input_ids.device if input_ids is not None else inputs_embeds.device
  position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
  position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
  else:
  position_ids = position_ids.view(-1, seq_length).long()
@@ -525,15 +541,17 @@ class LlamaModel(LlamaPreTrainedModel):
  attention_mask = torch.ones((batch_size, seq_length_with_past),
  dtype=torch.bool,
  device=inputs_embeds.device)
- attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
- past_key_values_length)
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
 
  hidden_states = inputs_embeds
 
  if self.gradient_checkpointing and self.training:
  if use_cache:
  logger.warning_once(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+ )
  use_cache = False
 
  # decoder layers
@@ -672,7 +690,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
 
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -719,12 +738,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
  attentions=outputs.attentions,
  )
 
- def prepare_inputs_for_generation(self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- **kwargs):
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
  if past_key_values:
  input_ids = input_ids[:, -1:]
 
@@ -22,13 +22,24 @@ from torch import nn
  from torch.nn import CrossEntropyLoss
  from torch.utils.checkpoint import checkpoint
  from transformers.activations import ACT2FN
- from transformers.modeling_outputs import (BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput,
- Seq2SeqModelOutput)
+ from transformers.modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+ )
  from transformers.modeling_utils import PreTrainedModel
  from transformers.models.t5.configuration_t5 import T5Config
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
- from transformers.utils import (DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_model_forward,
- is_torch_fx_proxy, logging, replace_return_docstrings)
+ from transformers.utils import (
+ DUMMY_INPUTS,
+ DUMMY_MASK,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_torch_fx_proxy,
+ logging,
+ replace_return_docstrings,
+ )
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
  from typing import Optional, Tuple, Union
 
@@ -63,8 +74,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
  import re
  import tensorflow as tf
  except ImportError:
- logger.error('Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see '
- 'https://www.tensorflow.org/install/ for installation instructions.')
+ logger.error(
+ 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see '
+ 'https://www.tensorflow.org/install/ for installation instructions.'
+ )
  raise
  tf_path = os.path.abspath(tf_checkpoint_path)
  logger.info(f'Converting TensorFlow checkpoint from {tf_path}')
@@ -82,13 +95,15 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
  name = txt_name.split('/')
  # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  # which are not required for using pretrained model
- if any(n in [
+ if any(
+ n in [
  'adam_v',
  'adam_m',
  'AdamWeightDecayOptimizer',
  'AdamWeightDecayOptimizer_1',
  'global_step',
- ] for n in name):
+ ] for n in name
+ ):
  logger.info(f"Skipping {'/'.join(name)}")
  tf_weights.pop(txt_name, None)
  continue
@@ -149,7 +164,8 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
  array = np.transpose(array)
  try:
  assert (
- pointer.shape == array.shape), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
+ pointer.shape == array.shape
+ ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
  except AssertionError as e:
  e.args += (pointer.shape, array.shape)
  raise
@@ -392,9 +408,10 @@ class T5Attention(nn.Module):
  is_small = relative_position < max_exact
 
  # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
- relative_position_if_large = max_exact + (torch.log(relative_position.float() / max_exact)
- / math.log(max_distance / max_exact) *
- (num_buckets - max_exact)).to(torch.long)
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) *
+ (num_buckets - max_exact)
+ ).to(torch.long)
  relative_position_if_large = torch.min(
  relative_position_if_large,
  torch.full_like(relative_position_if_large, num_buckets - 1),
@@ -497,8 +514,9 @@ class T5Attention(nn.Module):
  )
 
  # compute scores
- scores = torch.matmul(query_states, key_states.transpose(
- 3, 2)) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+ scores = torch.matmul(
+ query_states, key_states.transpose(3, 2)
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
 
  if position_bias is None:
  if not self.has_relative_attention_bias:
@@ -528,10 +546,11 @@ class T5Attention(nn.Module):
  position_bias_masked = position_bias
 
  scores += position_bias_masked
- attn_weights = nn.functional.softmax(
- scores.float(), dim=-1).type_as(scores) # (batch_size, n_heads, seq_length, key_length)
+ attn_weights = nn.functional.softmax(scores.float(),
+ dim=-1).type_as(scores) # (batch_size, n_heads, seq_length, key_length)
  attn_weights = nn.functional.dropout(
- attn_weights, p=self.dropout, training=self.training) # (batch_size, n_heads, seq_length, key_length)
+ attn_weights, p=self.dropout, training=self.training
+ ) # (batch_size, n_heads, seq_length, key_length)
 
  # Mask heads if we want to
  if layer_head_mask is not None:
@@ -655,7 +674,8 @@ class T5Block(nn.Module):
  raise ValueError(
  f'There should be {expected_num_past_key_values} past states. '
  f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
- f'Got {len(past_key_value)} past key / value states')
+ f'Got {len(past_key_value)} past key / value states'
+ )
 
  self_attn_past_key_value = past_key_value[:2]
  cross_attn_past_key_value = past_key_value[2:]
@@ -809,7 +829,8 @@ class T5PreTrainedModel(PreTrainedModel):
 
  assert decoder_start_token_id is not None, (
  'self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id.'
- ' See T5 docs for more information')
+ ' See T5 docs for more information'
+ )
 
  # shift inputs to the right
  if is_torch_fx_proxy(input_ids):
@@ -836,8 +857,9 @@ class T5Stack(T5PreTrainedModel):
  self.embed_tokens = embed_tokens
  self.is_decoder = config.is_decoder
 
- self.block = nn.ModuleList(
- [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)])
+ self.block = nn.ModuleList([
+ T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)
+ ])
  self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  self.dropout = nn.Dropout(config.dropout_rate)
 
@@ -852,7 +874,8 @@ class T5Stack(T5PreTrainedModel):
  def parallelize(self, device_map=None):
  # Check validity of device_map
  self.device_map = (
- get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map)
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
+ )
  assert_device_map(self.device_map, len(self.block))
  self.model_parallel = True
  self.first_device = ('cpu' if 'cpu' in self.device_map.keys() else 'cuda:' + str(min(self.device_map.keys())))
@@ -908,13 +931,15 @@ class T5Stack(T5PreTrainedModel):
  use_cache = use_cache if use_cache is not None else self.config.use_cache
  output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
  output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
  return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
 
  if input_ids is not None and inputs_embeds is not None:
  err_msg_prefix = 'decoder_' if self.is_decoder else ''
  raise ValueError(
- f'You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time')
+ f'You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time'
+ )
  elif input_ids is not None:
  input_shape = input_ids.size()
  input_ids = input_ids.view(-1, input_shape[-1])
@@ -1009,7 +1034,8 @@ class T5Stack(T5PreTrainedModel):
  if self.gradient_checkpointing and self.training:
  if use_cache:
  logger.warning(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+ )
  use_cache = False
 
  def create_custom_forward(module):
@@ -1082,13 +1108,15 @@ class T5Stack(T5PreTrainedModel):
  all_hidden_states = all_hidden_states + (hidden_states, )
 
  if not return_dict:
- return tuple(v for v in [
- hidden_states,
- present_key_value_states,
- all_hidden_states,
- all_attentions,
- all_cross_attentions,
- ] if v is not None)
+ return tuple(
+ v for v in [
+ hidden_states,
+ present_key_value_states,
+ all_hidden_states,
+ all_attentions,
+ all_cross_attentions,
+ ] if v is not None
+ )
  return BaseModelOutputWithPastAndCrossAttentions(
  last_hidden_state=hidden_states,
  past_key_values=present_key_value_states,
@@ -1298,7 +1326,8 @@ class T5Model(T5PreTrainedModel):
  def parallelize(self, device_map=None):
  self.device_map = (
  get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
- if device_map is None else device_map)
+ if device_map is None else device_map
+ )
  assert_device_map(self.device_map, len(self.encoder.block))
  self.encoder.parallelize(self.device_map)
  self.decoder.parallelize(self.device_map)
@@ -1493,7 +1522,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
  def parallelize(self, device_map=None):
  self.device_map = (
  get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
- if device_map is None else device_map)
+ if device_map is None else device_map
+ )
  assert_device_map(self.device_map, len(self.encoder.block))
  self.encoder.parallelize(self.device_map)
  self.decoder.parallelize(self.device_map)
@@ -1731,8 +1761,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
  reordered_layer_past_states = ()
  for layer_past_state in layer_past_states:
  # need to set correct `past` for each of the four key / value states
- reordered_layer_past_states = reordered_layer_past_states + (layer_past_state.index_select(
- 0, beam_idx.to(layer_past_state.device)), )
+ reordered_layer_past_states = reordered_layer_past_states + (
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
+ )
 
  assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
  assert len(reordered_layer_past_states) == len(layer_past_states)
@@ -1770,7 +1801,8 @@ class T5EncoderModel(T5PreTrainedModel):
  def parallelize(self, device_map=None):
  self.device_map = (
  get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
- if device_map is None else device_map)
+ if device_map is None else device_map
+ )
  assert_device_map(self.device_map, len(self.encoder.block))
  self.encoder.parallelize(self.device_map)
  self.model_parallel = True
@@ -26,7 +26,8 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
  depth=0,
  ):
  assert isinstance(decoder_pointer, nn.Module) and isinstance(
- encoder_pointer, nn.Module), f'{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module'
+ encoder_pointer, nn.Module
+ ), f'{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module'
  if hasattr(decoder_pointer, 'weight') and skip_key not in module_name:
  assert hasattr(encoder_pointer, 'weight')
  encoder_pointer.weight = decoder_pointer.weight
@@ -39,8 +40,9 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
  encoder_modules = encoder_pointer._modules
  decoder_modules = decoder_pointer._modules
  if len(decoder_modules) > 0:
- assert (len(encoder_modules) >
- 0), f'Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}'
+ assert (
+ len(encoder_modules) > 0
+ ), f'Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}'
 
  all_encoder_weights = set([module_name + '/' + sub_name for sub_name in encoder_modules.keys()])
  encoder_layer_pos = 0
@@ -49,8 +51,8 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
  encoder_name = str(int(name) + encoder_layer_pos)
  decoder_name = name
  if not isinstance(
- decoder_modules[decoder_name],
- type(encoder_modules[encoder_name]),
+ decoder_modules[decoder_name],
+ type(encoder_modules[encoder_name]),
  ) and len(encoder_modules) != len(decoder_modules):
  # this can happen if the name corresponds to the position in a list module list of layers
  # in this case the decoder has added a cross-attention that the encoder does not have
@@ -37,11 +37,13 @@ class BlipBase(BaseModel):
 
  state_dict = checkpoint['model']
 
- state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
- self.visual_encoder)
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
+ state_dict['visual_encoder.pos_embed'], self.visual_encoder
+ )
  if 'visual_encoder_m.pos_embed' in self.state_dict().keys():
- state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
- self.visual_encoder_m)
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
+ state_dict['visual_encoder_m.pos_embed'], self.visual_encoder_m
+ )
 
  for key in self.state_dict().keys():
  if key in state_dict.keys():
@@ -119,7 +119,8 @@ class BlipITM(BlipBase):
  elif match_head == 'itc':
  encoder_input_ids[:, 0] = self.tokenizer.cls_token_id
  text_output = self.text_encoder(
- encoder_input_ids, attention_mask=text_attention_mask, return_dict=True, mode='text')
+ encoder_input_ids, attention_mask=text_attention_mask, return_dict=True, mode='text'
+ )
  image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
  text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
 
@@ -155,13 +156,14 @@ def compute_gradcam(model, visual_input, text_input, tokenized_text, block_num=6
  model.zero_grad()
  loss.backward()
  with torch.no_grad():
- mask = tokenized_text.attention_mask.view(tokenized_text.attention_mask.size(0), 1, -1, 1,
- 1) # (bsz,1,token_len, 1,1)
+ mask = tokenized_text.attention_mask.view(
+ tokenized_text.attention_mask.size(0), 1, -1, 1, 1
+ ) # (bsz,1,token_len, 1,1)
  token_length = tokenized_text.attention_mask.sum(dim=-1) - 2
  token_length = token_length.cpu()
  # grads and cams [bsz, num_head, seq_len, image_patch]
- grads = model.text_encoder.base_model.base_model.encoder.layer[
- block_num].crossattention.self.get_attn_gradients()
+ grads = model.text_encoder.base_model.base_model.encoder.layer[block_num
+ ].crossattention.self.get_attn_gradients()
  cams = model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attention_map()
 
  # assume using vit with 576 num image patch
@@ -157,8 +157,9 @@ class BlipNLVR(BlipBase, MomentumDistilationMixin):
  raise RuntimeError('checkpoint url or path is invalid')
  state_dict = checkpoint['model']
 
- state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
- self.visual_encoder)
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
+ state_dict['visual_encoder.pos_embed'], self.visual_encoder
+ )
 
  for key in list(state_dict.keys()):
  if 'crossattention.self.' in key:
@@ -7,8 +7,11 @@
 
  import torch
  from dataclasses import dataclass
- from transformers.modeling_outputs import (BaseModelOutputWithPoolingAndCrossAttentions,
- CausalLMOutputWithCrossAttentions, ModelOutput)
+ from transformers.modeling_outputs import (
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ ModelOutput,
+ )
  from typing import Optional
 