evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (302)
  1. evalscope/__init__.py +4 -1
  2. evalscope/api/benchmark/__init__.py +3 -0
  3. evalscope/api/benchmark/adapters/__init__.py +5 -0
  4. evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
  5. evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
  6. evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
  7. evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
  8. evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
  9. evalscope/api/benchmark/benchmark.py +356 -0
  10. evalscope/api/benchmark/meta.py +121 -0
  11. evalscope/api/dataset/__init__.py +2 -0
  12. evalscope/api/dataset/dataset.py +349 -0
  13. evalscope/api/dataset/loader.py +262 -0
  14. evalscope/api/dataset/utils.py +143 -0
  15. evalscope/api/evaluator/__init__.py +3 -0
  16. evalscope/api/evaluator/cache.py +378 -0
  17. evalscope/api/evaluator/evaluator.py +56 -0
  18. evalscope/api/evaluator/state.py +275 -0
  19. evalscope/api/filter/__init__.py +1 -0
  20. evalscope/api/filter/filter.py +72 -0
  21. evalscope/api/messages/__init__.py +12 -0
  22. evalscope/api/messages/chat_message.py +243 -0
  23. evalscope/api/messages/content.py +102 -0
  24. evalscope/api/messages/utils.py +35 -0
  25. evalscope/api/metric/__init__.py +2 -0
  26. evalscope/api/metric/metric.py +55 -0
  27. evalscope/api/metric/scorer.py +113 -0
  28. evalscope/api/mixin/__init__.py +1 -0
  29. evalscope/api/mixin/llm_judge_mixin.py +168 -0
  30. evalscope/api/model/__init__.py +12 -0
  31. evalscope/api/model/generate_config.py +155 -0
  32. evalscope/api/model/model.py +386 -0
  33. evalscope/api/model/model_output.py +285 -0
  34. evalscope/api/registry.py +182 -0
  35. evalscope/api/tool/__init__.py +3 -0
  36. evalscope/api/tool/tool_call.py +101 -0
  37. evalscope/api/tool/tool_info.py +173 -0
  38. evalscope/api/tool/utils.py +64 -0
  39. evalscope/app/app.py +3 -0
  40. evalscope/app/ui/app_ui.py +2 -1
  41. evalscope/app/ui/multi_model.py +50 -25
  42. evalscope/app/ui/single_model.py +26 -14
  43. evalscope/app/utils/data_utils.py +43 -27
  44. evalscope/app/utils/env_utils.py +12 -0
  45. evalscope/app/utils/text_utils.py +14 -14
  46. evalscope/app/utils/visualization.py +9 -4
  47. evalscope/arguments.py +7 -10
  48. evalscope/backend/opencompass/api_meta_template.py +2 -1
  49. evalscope/backend/opencompass/backend_manager.py +6 -5
  50. evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
  51. evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
  52. evalscope/backend/rag_eval/ragas/task_template.py +2 -1
  53. evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
  54. evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
  55. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
  56. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
  57. evalscope/backend/rag_eval/utils/embedding.py +10 -1
  58. evalscope/backend/rag_eval/utils/llm.py +13 -12
  59. evalscope/benchmarks/__init__.py +0 -2
  60. evalscope/benchmarks/aime/aime24_adapter.py +38 -40
  61. evalscope/benchmarks/aime/aime25_adapter.py +34 -40
  62. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
  63. evalscope/benchmarks/arc/arc_adapter.py +34 -147
  64. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
  65. evalscope/benchmarks/arena_hard/utils.py +37 -1
  66. evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
  67. evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
  68. evalscope/benchmarks/bfcl/generation.py +222 -0
  69. evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
  70. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
  71. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
  72. evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
  73. evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
  74. evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
  75. evalscope/benchmarks/docmath/utils.py +4 -5
  76. evalscope/benchmarks/drop/drop_adapter.py +88 -40
  77. evalscope/benchmarks/frames/frames_adapter.py +136 -52
  78. evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
  79. evalscope/benchmarks/general_arena/utils.py +23 -27
  80. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
  81. evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
  82. evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
  83. evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
  84. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
  85. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
  86. evalscope/benchmarks/hle/hle_adapter.py +127 -93
  87. evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
  88. evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
  89. evalscope/benchmarks/ifeval/instructions.py +109 -64
  90. evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
  91. evalscope/benchmarks/ifeval/instructions_util.py +2 -3
  92. evalscope/benchmarks/ifeval/utils.py +6 -7
  93. evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
  94. evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
  95. evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
  96. evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
  97. evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
  98. evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
  99. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
  100. evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
  101. evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
  102. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
  103. evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
  104. evalscope/benchmarks/math_vista/__init__.py +0 -0
  105. evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
  106. evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
  107. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
  108. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
  109. evalscope/benchmarks/mmmu/__init__.py +0 -0
  110. evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
  111. evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
  112. evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
  113. evalscope/benchmarks/musr/musr_adapter.py +33 -64
  114. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
  115. evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
  116. evalscope/benchmarks/race/race_adapter.py +33 -119
  117. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
  118. evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
  119. evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
  120. evalscope/benchmarks/super_gpqa/utils.py +2 -1
  121. evalscope/benchmarks/tau_bench/generation.py +147 -0
  122. evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
  123. evalscope/benchmarks/text2image/__init__.py +0 -0
  124. evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
  125. evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
  126. evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
  127. evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
  128. evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
  129. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
  130. evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
  131. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
  132. evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
  133. evalscope/cli/cli.py +2 -0
  134. evalscope/cli/start_app.py +7 -1
  135. evalscope/cli/start_perf.py +7 -1
  136. evalscope/cli/start_server.py +6 -3
  137. evalscope/collections/__init__.py +2 -10
  138. evalscope/collections/sampler.py +10 -10
  139. evalscope/collections/schema.py +13 -11
  140. evalscope/config.py +157 -57
  141. evalscope/constants.py +37 -61
  142. evalscope/evaluator/__init__.py +1 -1
  143. evalscope/evaluator/evaluator.py +275 -419
  144. evalscope/filters/__init__.py +2 -0
  145. evalscope/filters/extraction.py +126 -0
  146. evalscope/filters/selection.py +57 -0
  147. evalscope/metrics/__init__.py +13 -13
  148. evalscope/metrics/llm_judge.py +47 -33
  149. evalscope/metrics/math_parser.py +27 -22
  150. evalscope/metrics/metric.py +307 -0
  151. evalscope/metrics/metrics.py +22 -18
  152. evalscope/metrics/t2v_metrics/__init__.py +0 -52
  153. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
  154. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
  155. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
  156. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
  157. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
  158. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
  159. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
  160. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
  161. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
  162. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
  163. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
  164. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
  165. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
  166. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
  167. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
  168. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
  169. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
  170. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
  171. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
  172. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
  173. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
  174. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
  175. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
  176. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
  177. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
  178. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
  179. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
  180. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
  181. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
  182. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
  183. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
  184. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
  185. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
  186. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
  187. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
  188. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
  189. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
  190. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
  191. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
  192. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
  193. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
  194. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
  195. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
  196. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
  197. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
  198. evalscope/models/__init__.py +6 -29
  199. evalscope/models/image_edit_model.py +125 -0
  200. evalscope/models/mockllm.py +65 -0
  201. evalscope/models/model_apis.py +67 -0
  202. evalscope/models/modelscope.py +455 -0
  203. evalscope/models/openai_compatible.py +126 -0
  204. evalscope/models/text2image_model.py +124 -0
  205. evalscope/models/utils/openai.py +701 -0
  206. evalscope/perf/benchmark.py +4 -1
  207. evalscope/perf/http_client.py +4 -2
  208. evalscope/perf/plugin/api/custom_api.py +5 -4
  209. evalscope/perf/plugin/api/openai_api.py +11 -9
  210. evalscope/perf/plugin/datasets/custom.py +2 -1
  211. evalscope/perf/plugin/datasets/flickr8k.py +1 -1
  212. evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
  213. evalscope/perf/plugin/datasets/line_by_line.py +2 -1
  214. evalscope/perf/plugin/datasets/longalpaca.py +2 -1
  215. evalscope/perf/plugin/datasets/openqa.py +4 -2
  216. evalscope/perf/utils/benchmark_util.py +15 -10
  217. evalscope/perf/utils/db_util.py +9 -6
  218. evalscope/perf/utils/local_server.py +11 -3
  219. evalscope/perf/utils/rich_display.py +16 -10
  220. evalscope/report/__init__.py +2 -3
  221. evalscope/report/combinator.py +18 -12
  222. evalscope/report/generator.py +51 -35
  223. evalscope/report/{utils.py → report.py} +8 -6
  224. evalscope/run.py +33 -47
  225. evalscope/summarizer.py +1 -1
  226. evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
  227. evalscope/utils/__init__.py +21 -2
  228. evalscope/utils/chat_service.py +3 -2
  229. evalscope/utils/deprecation_utils.py +12 -1
  230. evalscope/utils/function_utils.py +29 -0
  231. evalscope/utils/import_utils.py +23 -1
  232. evalscope/utils/io_utils.py +142 -6
  233. evalscope/utils/json_schema.py +208 -0
  234. evalscope/utils/logger.py +51 -12
  235. evalscope/utils/model_utils.py +11 -7
  236. evalscope/utils/multi_choices.py +288 -0
  237. evalscope/utils/url_utils.py +65 -0
  238. evalscope/version.py +2 -2
  239. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
  240. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
  241. tests/benchmark/test_eval.py +385 -0
  242. tests/benchmark/test_image_edit.py +65 -0
  243. tests/{aigc → benchmark}/test_t2i.py +22 -4
  244. tests/benchmark/test_vlm.py +80 -0
  245. tests/cli/test_all.py +85 -47
  246. tests/cli/test_collection.py +20 -8
  247. tests/cli/test_custom.py +22 -15
  248. tests/cli/test_reasoning.py +81 -0
  249. tests/common.py +73 -0
  250. tests/perf/test_perf.py +4 -2
  251. tests/rag/test_clip_benchmark.py +0 -2
  252. evalscope/benchmarks/aigc/t2i/base.py +0 -56
  253. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
  254. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
  255. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
  256. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
  257. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
  258. evalscope/benchmarks/arc/ai2_arc.py +0 -151
  259. evalscope/benchmarks/benchmark.py +0 -81
  260. evalscope/benchmarks/ceval/ceval_exam.py +0 -146
  261. evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
  262. evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
  263. evalscope/benchmarks/competition_math/competition_math.py +0 -79
  264. evalscope/benchmarks/data_adapter.py +0 -528
  265. evalscope/benchmarks/filters.py +0 -59
  266. evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
  267. evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
  268. evalscope/benchmarks/humaneval/humaneval.py +0 -79
  269. evalscope/benchmarks/mmlu/mmlu.py +0 -160
  270. evalscope/benchmarks/mmlu/samples.jsonl +0 -5
  271. evalscope/benchmarks/process_bench/critique_template.txt +0 -13
  272. evalscope/benchmarks/race/race.py +0 -104
  273. evalscope/benchmarks/race/samples.jsonl +0 -5
  274. evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
  275. evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
  276. evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
  277. evalscope/benchmarks/utils.py +0 -60
  278. evalscope/collections/evaluator.py +0 -375
  279. evalscope/metrics/completion_parsers.py +0 -227
  280. evalscope/metrics/named_metrics.py +0 -55
  281. evalscope/models/adapters/__init__.py +0 -14
  282. evalscope/models/adapters/base_adapter.py +0 -84
  283. evalscope/models/adapters/bfcl_adapter.py +0 -246
  284. evalscope/models/adapters/chat_adapter.py +0 -207
  285. evalscope/models/adapters/choice_adapter.py +0 -222
  286. evalscope/models/adapters/custom_adapter.py +0 -71
  287. evalscope/models/adapters/server_adapter.py +0 -236
  288. evalscope/models/adapters/t2i_adapter.py +0 -79
  289. evalscope/models/adapters/tau_bench_adapter.py +0 -189
  290. evalscope/models/custom/__init__.py +0 -4
  291. evalscope/models/custom/custom_model.py +0 -50
  292. evalscope/models/custom/dummy_model.py +0 -99
  293. evalscope/models/local_model.py +0 -128
  294. evalscope/models/register.py +0 -41
  295. tests/cli/test_run.py +0 -489
  296. /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
  297. /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
  298. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
  299. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
  300. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
  301. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
  302. /tests/{aigc → benchmark}/__init__.py +0 -0
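
The code hunks that follow appear to come from the vendored t2v_metrics model files listed above (BLIP, BERT and ViT variants) and are formatting-only refactors: long function signatures and call argument lists are rewrapped so that arguments sit on their own lines and the closing parenthesis moves to its own line, rather than being aligned under the opening parenthesis. A minimal, hypothetical sketch of that wrapping pattern is shown below; the function and parameter names are illustrative only, not taken from the package.

# Hypothetical before/after illustration of the rewrapping applied in these hunks.

# Style seen in 0.17.1: continuation arguments aligned under the opening parenthesis.
def predict_answers_old(samples,
                        num_beams=3,
                        answer_list=None,
                        **kwargs):
    return samples, num_beams, answer_list, kwargs

# Style seen in 1.0.1: one argument per line, closing parenthesis on its own line.
def predict_answers_new(
    samples,
    num_beams=3,
    answer_list=None,
    **kwargs
):
    return samples, num_beams, answer_list, kwargs
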
@@ -111,7 +111,8 @@ class BlipVQA(BlipBase):

  image_embeds = self.visual_encoder.forward_features(samples['image'])
  encoder_output = self.text_encoder.forward_automask(
- tokenized_text=samples['tokenized_text'], visual_embeds=image_embeds)
+ tokenized_text=samples['tokenized_text'], visual_embeds=image_embeds
+ )

  return encoder_output, image_embeds

@@ -150,15 +151,17 @@ class BlipVQA(BlipBase):

  return loss, answer_output, answer_targets

- def predict_answers(self,
- samples,
- num_beams=3,
- inference_method='rank',
- max_len=10,
- min_len=1,
- num_ans_candidates=128,
- answer_list=None,
- **kwargs):
+ def predict_answers(
+ self,
+ samples,
+ num_beams=3,
+ inference_method='rank',
+ max_len=10,
+ min_len=1,
+ num_ans_candidates=128,
+ answer_list=None,
+ **kwargs
+ ):
  """
  Args:
  samples (dict): A dictionary containing the following keys:
@@ -204,8 +207,8 @@ class BlipVQA(BlipBase):
  if isinstance(samples['text_input'], str):
  samples['text_input'] = [samples['text_input']]

- assert len(samples['text_input']) == samples['image'].size(
- 0), 'The number of questions must be equal to the batch size.'
+ assert len(samples['text_input']
+ ) == samples['image'].size(0), 'The number of questions must be equal to the batch size.'

  if inference_method == 'generate':
  return self._generate_answers(samples, num_beams=num_beams, max_length=max_len, min_length=min_len)
@@ -239,7 +242,8 @@ class BlipVQA(BlipBase):
  num_beams=num_beams,
  eos_token_id=self.tokenizer.sep_token_id,
  pad_token_id=self.tokenizer.pad_token_id,
- **model_kwargs)
+ **model_kwargs
+ )

  # collect answers
  answers = []
@@ -10,10 +10,16 @@ import torch
  import torch.utils.checkpoint
  from torch import Tensor, device, nn
  from transformers.activations import ACT2FN
- from transformers.modeling_outputs import (BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPoolingAndCrossAttentions)
- from transformers.modeling_utils import (PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices,
- prune_linear_layer)
+ from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ )
+ from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+ )
  from transformers.models.bert.configuration_bert import BertConfig
  from transformers.utils import logging
  from typing import Tuple
@@ -76,8 +82,10 @@ class BertSelfAttention(nn.Module):
  super().__init__()
  self.config = config
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, 'embedding_size'):
- raise ValueError('The hidden size (%d) is not a multiple of the number of attention '
- 'heads (%d)' % (config.hidden_size, config.num_attention_heads))
+ raise ValueError(
+ 'The hidden size (%d) is not a multiple of the number of attention '
+ 'heads (%d)' % (config.hidden_size, config.num_attention_heads)
+ )

  self.num_attention_heads = config.num_attention_heads
  self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -465,7 +473,8 @@ class BertEncoder(nn.Module):

  if use_cache:
  logger.warn(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+ )
  use_cache = False

  def create_custom_forward(module):
@@ -506,13 +515,15 @@ class BertEncoder(nn.Module):
  all_hidden_states = all_hidden_states + (hidden_states, )

  if not return_dict:
- return tuple(v for v in [
- hidden_states,
- next_decoder_cache,
- all_hidden_states,
- all_self_attentions,
- all_cross_attentions,
- ] if v is not None)
+ return tuple(
+ v for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ] if v is not None
+ )
  return BaseModelOutputWithPastAndCrossAttentions(
  last_hidden_state=hidden_states,
  past_key_values=next_decoder_cache,
@@ -703,8 +714,11 @@ class BertModel(BertPreTrainedModel):
  else:
  extended_attention_mask = attention_mask[:, None, None, :]
  else:
- raise ValueError('Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
- input_shape, attention_mask.shape))
+ raise ValueError(
+ 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
+ input_shape, attention_mask.shape
+ )
+ )

  # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  # masked positions, this operation will create a tensor which is 0.0 for
@@ -753,7 +767,8 @@ class BertModel(BertPreTrainedModel):
  """
  output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
  output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
  return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

  if is_decoder:
@@ -786,8 +801,9 @@ class BertModel(BertPreTrainedModel):

  # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  # ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device,
- is_decoder)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, input_shape, device, is_decoder
+ )

  # If a 2D or 3D attention mask is provided for the cross-attention
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -39,7 +39,8 @@ class Bottleneck(nn.Module):
  self.downsample = nn.Sequential(
  OrderedDict([('-1', nn.AvgPool2d(stride)),
  ('0', nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
- ('1', nn.BatchNorm2d(planes * self.expansion))]))
+ ('1', nn.BatchNorm2d(planes * self.expansion))])
+ )

  def forward(self, x: torch.Tensor):
  identity = x
@@ -91,7 +92,8 @@ class AttentionPool2d(nn.Module):
  out_proj_bias=self.c_proj.bias,
  use_separate_proj_weight=True,
  training=self.training,
- need_weights=False)
+ need_weights=False
+ )

  return x[0]

@@ -120,7 +122,8 @@ class ResidualAttentionBlock(nn.Module):
  self.ln_1 = LayerNorm(d_model)
  self.mlp = nn.Sequential(
  OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), ('gelu', QuickGELU()),
- ('c_proj', nn.Linear(d_model * 4, d_model))]))
+ ('c_proj', nn.Linear(d_model * 4, d_model))])
+ )
  self.ln_2 = LayerNorm(d_model)
  self.attn_mask = attn_mask

@@ -141,18 +144,16 @@ class ResidualAttentionBlock(nn.Module):

  class Transformer(nn.Module):

- def __init__(self,
- width: int,
- layers: int,
- heads: int,
- attn_mask: torch.Tensor = None,
- use_grad_checkpointing=False):
+ def __init__(
+ self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False
+ ):
  super().__init__()
  self.width = width
  self.layers = layers
  self.resblocks = nn.Sequential(
  *
- [ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)])
+ [ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)]
+ )

  def forward(self, x: torch.Tensor):
  return self.resblocks(x)
@@ -160,8 +161,9 @@ class Transformer(nn.Module):

  class VisionTransformer(nn.Module):

- def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int,
- use_grad_checkpointing: bool):
+ def __init__(
+ self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool
+ ):
  super().__init__()
  self.input_resolution = input_resolution
  self.num_features = width
@@ -72,15 +72,17 @@ class Mlp(nn.Module):

  class Attention(nn.Module):

- def __init__(self,
- dim,
- num_heads=8,
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.,
- window_size=None,
- attn_head_dim=None):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.,
+ window_size=None,
+ attn_head_dim=None
+ ):
  super().__init__()
  self.num_heads = num_heads
  head_dim = dim // num_heads
@@ -100,8 +102,9 @@ class Attention(nn.Module):
  if window_size:
  self.window_size = window_size
  self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
- self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
- num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
  # cls to token & token 2 cls & cls to cls

  # get pair-wise relative position index for each token inside the window
@@ -166,20 +169,22 @@ class Attention(nn.Module):

  class Block(nn.Module):

- def __init__(self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- init_values=None,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- window_size=None,
- attn_head_dim=None):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ init_values=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ window_size=None,
+ attn_head_dim=None
+ ):
  super().__init__()
  self.norm1 = norm_layer(dim)
  self.attn = Attention(
@@ -190,7 +195,8 @@ class Block(nn.Module):
  attn_drop=attn_drop,
  proj_drop=drop,
  window_size=window_size,
- attn_head_dim=attn_head_dim)
+ attn_head_dim=attn_head_dim
+ )
  # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  self.norm2 = norm_layer(dim)
@@ -244,8 +250,9 @@ class RelativePositionBias(nn.Module):
  super().__init__()
  self.window_size = window_size
  self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
- self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
- num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
  # cls to token & token 2 cls & cls to cls

  # get pair-wise relative position index for each token inside the window
@@ -281,28 +288,30 @@ class VisionTransformer(nn.Module):
  """ Vision Transformer with support for patch or hybrid CNN input stage
  """

- def __init__(self,
- img_size=224,
- patch_size=16,
- in_chans=3,
- num_classes=1000,
- embed_dim=768,
- depth=12,
- num_heads=12,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- norm_layer=nn.LayerNorm,
- init_values=None,
- use_abs_pos_emb=True,
- use_rel_pos_bias=False,
- use_shared_rel_pos_bias=False,
- use_mean_pooling=True,
- init_scale=0.001,
- use_checkpoint=False):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=nn.LayerNorm,
+ init_values=None,
+ use_abs_pos_emb=True,
+ use_rel_pos_bias=False,
+ use_shared_rel_pos_bias=False,
+ use_mean_pooling=True,
+ init_scale=0.001,
+ use_checkpoint=False
+ ):
  super().__init__()
  self.image_size = img_size
  self.num_classes = num_classes
@@ -338,7 +347,8 @@ class VisionTransformer(nn.Module):
  drop_path=dpr[i],
  norm_layer=norm_layer,
  init_values=init_values,
- window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None
+ ) for i in range(depth)
  ])
  # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
  # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
@@ -450,7 +460,8 @@ def interpolate_pos_embed(model, checkpoint_model):
  pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
  pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
  pos_tokens = torch.nn.functional.interpolate(
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False
+ )
  pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
  new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
  checkpoint_model['pos_embed'] = new_pos_embed
@@ -20,13 +20,23 @@ from torch.nn import CrossEntropyLoss
  from transformers import BatchEncoding, PreTrainedTokenizer
  from transformers.activations import ACT2FN
  from transformers.file_utils import ModelOutput
- from transformers.modeling_outputs import (BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPoolingAndCrossAttentions,
- CausalLMOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput,
- NextSentencePredictorOutput, QuestionAnsweringModelOutput,
- SequenceClassifierOutput, TokenClassifierOutput)
- from transformers.modeling_utils import (PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices,
- prune_linear_layer)
+ from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+ )
  from transformers.models.bert.configuration_bert import BertConfig
  from transformers.utils import logging
  from typing import Optional, Tuple
@@ -102,8 +112,10 @@ class BertSelfAttention(nn.Module):
  super().__init__()
  self.config = config
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, 'embedding_size'):
- raise ValueError('The hidden size (%d) is not a multiple of the number of attention '
- 'heads (%d)' % (config.hidden_size, config.num_attention_heads))
+ raise ValueError(
+ 'The hidden size (%d) is not a multiple of the number of attention '
+ 'heads (%d)' % (config.hidden_size, config.num_attention_heads)
+ )

  self.num_attention_heads = config.num_attention_heads
  self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -408,8 +420,9 @@ class BertLayer(nn.Module):
  output_attentions=output_attentions,
  )
  attention_output = cross_attention_outputs[0]
- outputs = (outputs + cross_attention_outputs[1:-1]
- ) # add cross attentions if we output attention weights
+ outputs = (
+ outputs + cross_attention_outputs[1:-1]
+ ) # add cross attentions if we output attention weights
  layer_output = apply_chunking_to_forward(
  self.feed_forward_chunk,
  self.chunk_size_feed_forward,
@@ -492,7 +505,8 @@ class BertEncoder(nn.Module):

  if use_cache:
  logger.warn(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+ )
  use_cache = False

  def create_custom_forward(module):
@@ -533,13 +547,15 @@ class BertEncoder(nn.Module):
  all_hidden_states = all_hidden_states + (hidden_states, )

  if not return_dict:
- return tuple(v for v in [
- hidden_states,
- next_decoder_cache,
- all_hidden_states,
- all_self_attentions,
- all_cross_attentions,
- ] if v is not None)
+ return tuple(
+ v for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ] if v is not None
+ )
  return BaseModelOutputWithPastAndCrossAttentions(
  last_hidden_state=hidden_states,
  past_key_values=next_decoder_cache,
@@ -730,8 +746,11 @@ class BertModel(BertPreTrainedModel):
  else:
  extended_attention_mask = attention_mask[:, None, None, :]
  else:
- raise ValueError('Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
- input_shape, attention_mask.shape))
+ raise ValueError(
+ 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
+ input_shape, attention_mask.shape
+ )
+ )

  # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  # masked positions, this operation will create a tensor which is 0.0 for
@@ -781,7 +800,8 @@ class BertModel(BertPreTrainedModel):
  """
  output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
  output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
  return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

  if is_decoder:
@@ -814,8 +834,9 @@ class BertModel(BertPreTrainedModel):

  # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  # ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device,
- is_decoder)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, input_shape, device, is_decoder
+ )

  # If a 2D or 3D attention mask is provided for the cross-attention
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1176,18 +1197,20 @@ class XBertLMHeadDecoder(BertLMHeadModel):
  else:
  return cls(config=med_config)

- def generate_from_encoder(self,
- tokenized_prompt,
- visual_embeds,
- sep_token_id,
- pad_token_id,
- use_nucleus_sampling=False,
- num_beams=3,
- max_length=30,
- min_length=10,
- top_p=0.9,
- repetition_penalty=1.0,
- **kwargs):
+ def generate_from_encoder(
+ self,
+ tokenized_prompt,
+ visual_embeds,
+ sep_token_id,
+ pad_token_id,
+ use_nucleus_sampling=False,
+ num_beams=3,
+ max_length=30,
+ min_length=10,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ **kwargs
+ ):

  if not use_nucleus_sampling:
  num_beams = num_beams
@@ -1212,7 +1235,8 @@ class XBertLMHeadDecoder(BertLMHeadModel):
  eos_token_id=sep_token_id,
  pad_token_id=pad_token_id,
  repetition_penalty=1.1,
- **model_kwargs)
+ **model_kwargs
+ )
  else:
  # beam search
  outputs = self.generate(
@@ -1223,7 +1247,8 @@ class XBertLMHeadDecoder(BertLMHeadModel):
  eos_token_id=sep_token_id,
  pad_token_id=pad_token_id,
  repetition_penalty=repetition_penalty,
- **model_kwargs)
+ **model_kwargs
+ )

  return outputs

@@ -343,9 +343,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
  block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
  block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
  block.attn.qkv.weight.copy_(
- torch.cat([_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+ torch.cat([_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])
+ )
  block.attn.qkv.bias.copy_(
- torch.cat([_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+ torch.cat([_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])
+ )
  block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
  block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
  for r in range(2):
@@ -394,7 +396,8 @@ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
  pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
  pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
  pos_tokens = torch.nn.functional.interpolate(
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False
+ )
  pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
  new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
  print('reshape position embedding from %d to %d' % (orig_size**2, new_size**2))
@@ -7,8 +7,12 @@

  from ..common.registry import registry
  from .base_processor import BaseProcessor
- from .blip_processors import (Blip2ImageTrainProcessor, BlipCaptionProcessor, BlipImageEvalProcessor,
- BlipImageTrainProcessor)
+ from .blip_processors import (
+ Blip2ImageTrainProcessor,
+ BlipCaptionProcessor,
+ BlipImageEvalProcessor,
+ BlipImageTrainProcessor,
+ )

  __all__ = [
  'BaseProcessor',
@@ -107,8 +107,9 @@ def color_func(img, factor):
  # np.eye(3) * factor
  # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
  # )[np.newaxis, np.newaxis, :]
- M = np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]) * factor + np.float32(
- [[0.114], [0.587], [0.299]])
+ M = np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]) * factor + np.float32([[
+ 0.114
+ ], [0.587], [0.299]])
  out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
  return out