evalscope 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (324)
  1. evalscope/api/benchmark/__init__.py +9 -1
  2. evalscope/api/benchmark/adapters/__init__.py +4 -0
  3. evalscope/api/benchmark/adapters/agent_adapter.py +8 -0
  4. evalscope/api/benchmark/adapters/default_data_adapter.py +75 -4
  5. evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
  6. evalscope/api/benchmark/adapters/multi_choice_adapter.py +5 -2
  7. evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
  8. evalscope/api/benchmark/adapters/text2image_adapter.py +12 -10
  9. evalscope/api/benchmark/adapters/vision_language_adapter.py +8 -0
  10. evalscope/api/benchmark/benchmark.py +85 -2
  11. evalscope/api/benchmark/meta.py +10 -1
  12. evalscope/api/dataset/dataset.py +27 -6
  13. evalscope/api/dataset/loader.py +8 -3
  14. evalscope/api/evaluator/cache.py +31 -4
  15. evalscope/api/evaluator/evaluator.py +5 -0
  16. evalscope/api/evaluator/state.py +17 -1
  17. evalscope/api/messages/__init__.py +1 -0
  18. evalscope/api/messages/chat_message.py +52 -2
  19. evalscope/api/metric/__init__.py +1 -1
  20. evalscope/api/metric/metric.py +6 -1
  21. evalscope/api/metric/scorer.py +15 -7
  22. evalscope/api/mixin/__init__.py +1 -1
  23. evalscope/api/mixin/llm_judge_mixin.py +2 -0
  24. evalscope/api/mixin/sandbox_mixin.py +182 -0
  25. evalscope/api/model/generate_config.py +10 -6
  26. evalscope/api/model/model.py +5 -2
  27. evalscope/api/tool/tool_info.py +1 -1
  28. evalscope/app/app.py +3 -0
  29. evalscope/app/ui/multi_model.py +6 -1
  30. evalscope/app/ui/single_model.py +11 -5
  31. evalscope/app/utils/data_utils.py +8 -7
  32. evalscope/app/utils/env_utils.py +12 -0
  33. evalscope/app/utils/text_utils.py +14 -12
  34. evalscope/app/utils/visualization.py +2 -2
  35. evalscope/arguments.py +8 -4
  36. evalscope/backend/opencompass/backend_manager.py +0 -2
  37. evalscope/backend/rag_eval/utils/embedding.py +9 -1
  38. evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
  39. evalscope/benchmarks/ai2d/ai2d_adapter.py +54 -0
  40. evalscope/benchmarks/aime/aime24_adapter.py +5 -0
  41. evalscope/benchmarks/aime/aime25_adapter.py +136 -1
  42. evalscope/benchmarks/aime/grader.py +307 -0
  43. evalscope/benchmarks/aime/math_normalize.py +189 -0
  44. evalscope/benchmarks/amc/amc_adapter.py +51 -0
  45. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -0
  46. evalscope/benchmarks/bbh/bbh_adapter.py +43 -17
  47. evalscope/benchmarks/bfcl/{bfcl_adapter.py → v3/bfcl_v3_adapter.py} +131 -19
  48. evalscope/benchmarks/bfcl/{generation.py → v3/generation.py} +9 -9
  49. evalscope/benchmarks/bfcl/v3/utils.py +23 -0
  50. evalscope/benchmarks/bfcl/v4/__init__.py +0 -0
  51. evalscope/benchmarks/bfcl/v4/bfcl_v4_adapter.py +229 -0
  52. evalscope/benchmarks/bfcl/v4/utils.py +410 -0
  53. evalscope/benchmarks/biomix_qa/__init__.py +0 -0
  54. evalscope/benchmarks/biomix_qa/biomix_qa_adapter.py +36 -0
  55. evalscope/benchmarks/blink/__init__.py +0 -0
  56. evalscope/benchmarks/blink/blink_adapter.py +61 -0
  57. evalscope/benchmarks/ceval/ceval_adapter.py +1 -2
  58. evalscope/benchmarks/chartqa/__init__.py +0 -0
  59. evalscope/benchmarks/chartqa/chartqa_adapter.py +80 -0
  60. evalscope/benchmarks/chartqa/utils.py +38 -0
  61. evalscope/benchmarks/coin_flip/__init__.py +0 -0
  62. evalscope/benchmarks/coin_flip/coin_flip_adapter.py +128 -0
  63. evalscope/benchmarks/commonsense_qa/__init__.py +0 -0
  64. evalscope/benchmarks/commonsense_qa/commonsense_qa_adapter.py +32 -0
  65. evalscope/benchmarks/competition_math/competition_math_adapter.py +5 -0
  66. evalscope/benchmarks/data_collection/data_collection_adapter.py +24 -19
  67. evalscope/benchmarks/docvqa/__init__.py +0 -0
  68. evalscope/benchmarks/docvqa/docvqa_adapter.py +67 -0
  69. evalscope/benchmarks/drivelology/__init__.py +0 -0
  70. evalscope/benchmarks/drivelology/drivelology_binary_adapter.py +170 -0
  71. evalscope/benchmarks/drivelology/drivelology_multilabel_adapter.py +254 -0
  72. evalscope/benchmarks/drivelology/drivelology_selection_adapter.py +49 -0
  73. evalscope/benchmarks/drivelology/drivelology_writing_adapter.py +218 -0
  74. evalscope/benchmarks/drop/drop_adapter.py +15 -44
  75. evalscope/benchmarks/drop/utils.py +97 -0
  76. evalscope/benchmarks/frames/frames_adapter.py +2 -1
  77. evalscope/benchmarks/general_arena/general_arena_adapter.py +7 -2
  78. evalscope/benchmarks/general_arena/utils.py +2 -1
  79. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +1 -1
  80. evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
  81. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +25 -9
  82. evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
  83. evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +159 -0
  84. evalscope/benchmarks/halu_eval/__init__.py +0 -0
  85. evalscope/benchmarks/halu_eval/halu_eval_adapter.py +128 -0
  86. evalscope/benchmarks/halu_eval/halu_eval_instructions.py +84 -0
  87. evalscope/benchmarks/healthbench/__init__.py +0 -0
  88. evalscope/benchmarks/healthbench/healthbench_adapter.py +282 -0
  89. evalscope/benchmarks/healthbench/utils.py +102 -0
  90. evalscope/benchmarks/hle/hle_adapter.py +3 -2
  91. evalscope/benchmarks/humaneval/humaneval_adapter.py +24 -52
  92. evalscope/benchmarks/humaneval/utils.py +235 -0
  93. evalscope/benchmarks/ifeval/instructions_util.py +2 -3
  94. evalscope/benchmarks/image_edit/__init__.py +0 -0
  95. evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
  96. evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
  97. evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
  98. evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
  99. evalscope/benchmarks/infovqa/__init__.py +0 -0
  100. evalscope/benchmarks/infovqa/infovqa_adapter.py +66 -0
  101. evalscope/benchmarks/live_code_bench/evaluate_utils.py +13 -6
  102. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +66 -54
  103. evalscope/benchmarks/live_code_bench/sandbox_evaluate_utils.py +220 -0
  104. evalscope/benchmarks/logi_qa/__int__.py +0 -0
  105. evalscope/benchmarks/logi_qa/logi_qa_adapter.py +41 -0
  106. evalscope/benchmarks/math_500/math_500_adapter.py +5 -1
  107. evalscope/benchmarks/math_qa/__init__.py +0 -0
  108. evalscope/benchmarks/math_qa/math_qa_adapter.py +35 -0
  109. evalscope/benchmarks/math_verse/__init__.py +0 -0
  110. evalscope/benchmarks/math_verse/math_verse_adapter.py +105 -0
  111. evalscope/benchmarks/math_vision/__init__.py +0 -0
  112. evalscope/benchmarks/math_vision/math_vision_adapter.py +116 -0
  113. evalscope/benchmarks/math_vista/__init__.py +0 -0
  114. evalscope/benchmarks/math_vista/math_vista_adapter.py +114 -0
  115. evalscope/benchmarks/med_mcqa/__init__.py +0 -0
  116. evalscope/benchmarks/med_mcqa/med_mcqa_adapter.py +32 -0
  117. evalscope/benchmarks/minerva_math/__init__.py +0 -0
  118. evalscope/benchmarks/minerva_math/minerva_math_adapter.py +53 -0
  119. evalscope/benchmarks/mm_bench/__init__.py +0 -0
  120. evalscope/benchmarks/mm_bench/mm_bench_adapter.py +99 -0
  121. evalscope/benchmarks/mm_star/__init__.py +0 -0
  122. evalscope/benchmarks/mm_star/mm_star_adapter.py +73 -0
  123. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +1 -1
  124. evalscope/benchmarks/mmmu/__init__.py +0 -0
  125. evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
  126. evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
  127. evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +124 -0
  128. evalscope/benchmarks/mri_mcqa/__init__.py +0 -0
  129. evalscope/benchmarks/mri_mcqa/mri_mcqa_adapter.py +34 -0
  130. evalscope/benchmarks/multi_if/__init__.py +0 -0
  131. evalscope/benchmarks/multi_if/ifeval.py +3354 -0
  132. evalscope/benchmarks/multi_if/metrics.py +120 -0
  133. evalscope/benchmarks/multi_if/multi_if_adapter.py +161 -0
  134. evalscope/benchmarks/music_trivia/__init__.py +0 -0
  135. evalscope/benchmarks/music_trivia/music_trivia_adapter.py +36 -0
  136. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +7 -6
  137. evalscope/benchmarks/ner/__init__.py +0 -0
  138. evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
  139. evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
  140. evalscope/benchmarks/ner/copious_adapter.py +85 -0
  141. evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
  142. evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
  143. evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
  144. evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
  145. evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
  146. evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
  147. evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
  148. evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
  149. evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
  150. evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
  151. evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
  152. evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
  153. evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
  154. evalscope/benchmarks/ocr_bench/__init__.py +0 -0
  155. evalscope/benchmarks/ocr_bench/ocr_bench/__init__.py +0 -0
  156. evalscope/benchmarks/ocr_bench/ocr_bench/ocr_bench_adapter.py +101 -0
  157. evalscope/benchmarks/ocr_bench/ocr_bench_v2/IoUscore_metric.py +87 -0
  158. evalscope/benchmarks/ocr_bench/ocr_bench_v2/TEDS_metric.py +963 -0
  159. evalscope/benchmarks/ocr_bench/ocr_bench_v2/__init__.py +0 -0
  160. evalscope/benchmarks/ocr_bench/ocr_bench_v2/ocr_bench_v2_adapter.py +161 -0
  161. evalscope/benchmarks/ocr_bench/ocr_bench_v2/page_ocr_metric.py +50 -0
  162. evalscope/benchmarks/ocr_bench/ocr_bench_v2/parallel.py +46 -0
  163. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/__init__.py +0 -0
  164. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/readme.txt +26 -0
  165. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/rrc_evaluation_funcs_1_1.py +537 -0
  166. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/script.py +481 -0
  167. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_metric.py +179 -0
  168. evalscope/benchmarks/ocr_bench/ocr_bench_v2/utils.py +433 -0
  169. evalscope/benchmarks/ocr_bench/ocr_bench_v2/vqa_metric.py +254 -0
  170. evalscope/benchmarks/olympiad_bench/__init__.py +0 -0
  171. evalscope/benchmarks/olympiad_bench/olympiad_bench_adapter.py +163 -0
  172. evalscope/benchmarks/olympiad_bench/utils.py +565 -0
  173. evalscope/benchmarks/omni_bench/__init__.py +0 -0
  174. evalscope/benchmarks/omni_bench/omni_bench_adapter.py +86 -0
  175. evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
  176. evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
  177. evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
  178. evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
  179. evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
  180. evalscope/benchmarks/piqa/__init__.py +0 -0
  181. evalscope/benchmarks/piqa/piqa_adapter.py +32 -0
  182. evalscope/benchmarks/poly_math/__init__.py +0 -0
  183. evalscope/benchmarks/poly_math/poly_math_adapter.py +132 -0
  184. evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
  185. evalscope/benchmarks/pope/__init__.py +0 -0
  186. evalscope/benchmarks/pope/pope_adapter.py +112 -0
  187. evalscope/benchmarks/process_bench/process_bench_adapter.py +1 -0
  188. evalscope/benchmarks/pumed_qa/__init__.py +0 -0
  189. evalscope/benchmarks/pumed_qa/pubmed_qa_adapter.py +175 -0
  190. evalscope/benchmarks/qasc/__init__.py +0 -0
  191. evalscope/benchmarks/qasc/qasc_adapter.py +35 -0
  192. evalscope/benchmarks/real_world_qa/__init__.py +0 -0
  193. evalscope/benchmarks/real_world_qa/real_world_qa_adapter.py +64 -0
  194. evalscope/benchmarks/sciq/__init__.py +0 -0
  195. evalscope/benchmarks/sciq/sciq_adapter.py +36 -0
  196. evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
  197. evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
  198. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -1
  199. evalscope/benchmarks/simple_vqa/__init__.py +0 -0
  200. evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
  201. evalscope/benchmarks/siqa/__init__.py +0 -0
  202. evalscope/benchmarks/siqa/siqa_adapter.py +39 -0
  203. evalscope/benchmarks/tau_bench/tau2_bench/__init__.py +0 -0
  204. evalscope/benchmarks/tau_bench/tau2_bench/generation.py +158 -0
  205. evalscope/benchmarks/tau_bench/tau2_bench/tau2_bench_adapter.py +146 -0
  206. evalscope/benchmarks/tau_bench/tau_bench/__init__.py +0 -0
  207. evalscope/benchmarks/tau_bench/{generation.py → tau_bench/generation.py} +1 -1
  208. evalscope/benchmarks/tau_bench/{tau_bench_adapter.py → tau_bench/tau_bench_adapter.py} +29 -29
  209. evalscope/benchmarks/text2image/__init__.py +0 -0
  210. evalscope/benchmarks/{aigc/t2i → text2image}/evalmuse_adapter.py +3 -1
  211. evalscope/benchmarks/{aigc/t2i → text2image}/genai_bench_adapter.py +2 -2
  212. evalscope/benchmarks/{aigc/t2i → text2image}/general_t2i_adapter.py +1 -1
  213. evalscope/benchmarks/{aigc/t2i → text2image}/hpdv2_adapter.py +7 -2
  214. evalscope/benchmarks/{aigc/t2i → text2image}/tifa_adapter.py +1 -0
  215. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +3 -3
  216. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +1 -2
  217. evalscope/benchmarks/visu_logic/__init__.py +0 -0
  218. evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
  219. evalscope/benchmarks/wmt/__init__.py +0 -0
  220. evalscope/benchmarks/wmt/wmt24_adapter.py +294 -0
  221. evalscope/benchmarks/zerobench/__init__.py +0 -0
  222. evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
  223. evalscope/cli/start_app.py +7 -1
  224. evalscope/cli/start_perf.py +7 -1
  225. evalscope/config.py +103 -18
  226. evalscope/constants.py +18 -0
  227. evalscope/evaluator/evaluator.py +138 -82
  228. evalscope/metrics/bert_score/__init__.py +0 -0
  229. evalscope/metrics/bert_score/scorer.py +338 -0
  230. evalscope/metrics/bert_score/utils.py +697 -0
  231. evalscope/metrics/llm_judge.py +19 -7
  232. evalscope/metrics/math_parser.py +14 -0
  233. evalscope/metrics/metric.py +317 -13
  234. evalscope/metrics/metrics.py +37 -0
  235. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +0 -0
  236. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +0 -0
  237. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +0 -0
  238. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +0 -0
  239. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +0 -0
  240. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +0 -0
  241. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +0 -0
  242. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +0 -0
  243. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +0 -0
  244. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +0 -0
  245. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +2 -6
  246. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +2 -6
  247. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +2 -6
  248. evalscope/models/image_edit_model.py +125 -0
  249. evalscope/models/model_apis.py +22 -0
  250. evalscope/models/openai_compatible.py +21 -0
  251. evalscope/models/text2image_model.py +2 -2
  252. evalscope/models/utils/openai.py +16 -6
  253. evalscope/perf/arguments.py +26 -4
  254. evalscope/perf/benchmark.py +76 -89
  255. evalscope/perf/http_client.py +31 -16
  256. evalscope/perf/main.py +15 -2
  257. evalscope/perf/plugin/api/base.py +9 -7
  258. evalscope/perf/plugin/api/custom_api.py +13 -58
  259. evalscope/perf/plugin/api/default_api.py +188 -79
  260. evalscope/perf/plugin/api/openai_api.py +85 -20
  261. evalscope/perf/plugin/datasets/base.py +21 -0
  262. evalscope/perf/plugin/datasets/custom.py +2 -3
  263. evalscope/perf/plugin/datasets/flickr8k.py +2 -2
  264. evalscope/perf/plugin/datasets/kontext_bench.py +2 -2
  265. evalscope/perf/plugin/datasets/line_by_line.py +2 -3
  266. evalscope/perf/plugin/datasets/longalpaca.py +2 -3
  267. evalscope/perf/plugin/datasets/openqa.py +2 -4
  268. evalscope/perf/plugin/datasets/random_dataset.py +1 -3
  269. evalscope/perf/plugin/datasets/random_vl_dataset.py +2 -2
  270. evalscope/perf/utils/benchmark_util.py +43 -27
  271. evalscope/perf/utils/db_util.py +14 -19
  272. evalscope/perf/utils/local_server.py +3 -44
  273. evalscope/perf/utils/log_utils.py +21 -6
  274. evalscope/report/__init__.py +13 -3
  275. evalscope/report/combinator.py +91 -20
  276. evalscope/report/generator.py +8 -87
  277. evalscope/report/report.py +8 -4
  278. evalscope/run.py +13 -5
  279. evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
  280. evalscope/utils/argument_utils.py +1 -1
  281. evalscope/utils/chat_service.py +1 -1
  282. evalscope/utils/function_utils.py +249 -12
  283. evalscope/utils/import_utils.py +73 -1
  284. evalscope/utils/io_utils.py +132 -7
  285. evalscope/utils/json_schema.py +25 -2
  286. evalscope/utils/logger.py +69 -18
  287. evalscope/utils/model_utils.py +4 -3
  288. evalscope/utils/multi_choices.py +39 -7
  289. evalscope/utils/ner.py +377 -0
  290. evalscope/version.py +2 -2
  291. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/METADATA +252 -408
  292. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/RECORD +290 -154
  293. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/WHEEL +1 -1
  294. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/top_level.txt +0 -1
  295. evalscope/api/mixin/dataset_mixin.py +0 -105
  296. evalscope/benchmarks/aigc/i2i/general_i2i_adapter.py +0 -44
  297. tests/__init__.py +0 -1
  298. tests/aigc/__init__.py +0 -1
  299. tests/aigc/test_t2i.py +0 -142
  300. tests/benchmark/__init__.py +0 -1
  301. tests/benchmark/test_eval.py +0 -386
  302. tests/cli/__init__.py +0 -1
  303. tests/cli/test_all.py +0 -229
  304. tests/cli/test_collection.py +0 -96
  305. tests/cli/test_custom.py +0 -268
  306. tests/perf/__init__.py +0 -1
  307. tests/perf/test_perf.py +0 -176
  308. tests/rag/test_clip_benchmark.py +0 -90
  309. tests/rag/test_mteb.py +0 -213
  310. tests/rag/test_ragas.py +0 -128
  311. tests/swift/__init__.py +0 -1
  312. tests/swift/test_run_swift_eval.py +0 -146
  313. tests/swift/test_run_swift_vlm_eval.py +0 -128
  314. tests/swift/test_run_swift_vlm_jugde_eval.py +0 -157
  315. tests/test_run_all.py +0 -12
  316. tests/utils.py +0 -13
  317. tests/vlm/__init__.py +0 -1
  318. tests/vlm/test_vlmeval.py +0 -102
  319. /evalscope/benchmarks/{aigc → aa_lcr}/__init__.py +0 -0
  320. /evalscope/benchmarks/{aigc/i2i → ai2d}/__init__.py +0 -0
  321. /evalscope/benchmarks/{aigc/t2i → amc}/__init__.py +0 -0
  322. {tests/rag → evalscope/benchmarks/bfcl/v3}/__init__.py +0 -0
  323. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/entry_points.txt +0 -0
  324. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info/licenses}/LICENSE +0 -0
evalscope/evaluator/evaluator.py
@@ -8,15 +8,18 @@ and report generation.
  """

  import os
+ import traceback
  from collections import defaultdict
- from concurrent.futures import ThreadPoolExecutor, as_completed
  from tqdm import tqdm
- from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+ from typing import TYPE_CHECKING, Callable, Dict, List

  from evalscope.api.dataset import Dataset, DatasetDict, Sample
  from evalscope.api.evaluator import CacheManager, Evaluator, TaskState
  from evalscope.api.metric import AggScore, SampleScore
+ from evalscope.constants import HEARTBEAT_INTERVAL_SEC
  from evalscope.report import Report, gen_table
+ from evalscope.utils.function_utils import run_in_threads_with_progress
+ from evalscope.utils.logger import get_logger

  if TYPE_CHECKING:
      from evalscope.api.benchmark import DataAdapter
@@ -24,8 +27,6 @@ if TYPE_CHECKING:
      from evalscope.config import TaskConfig
      from evalscope.utils.io_utils import OutputsStructure

- from evalscope.utils.logger import get_logger
-
  logger = get_logger()


@@ -91,17 +92,27 @@ class DefaultEvaluator(Evaluator):
              Report: The complete evaluation report containing all metrics and results.
          """
          # Load the dataset and evaluate each subset
+         logger.info(f'Start evaluating benchmark: {self.benchmark_name}')
          dataset_dict = self.benchmark.load_dataset()
          agg_score_dict = defaultdict(list)

          # Process each subset (e.g., test, validation) independently
+         logger.info('Evaluating all subsets of the dataset...')
          for subset, dataset in dataset_dict.items():
-             assert len(dataset) > 0, f'No samples found in subset: {subset}'
+             if len(dataset) == 0:
+                 logger.info(f'No samples found in subset: {subset}, skipping.')
+                 continue
+             logger.info(f'Evaluating subset: {subset}')
              subset_score = self.evaluate_subset(subset, dataset)
              agg_score_dict[subset] = subset_score

          # Generate the report based on aggregated scores
+         logger.info('Generating report...')
          report = self.get_report(agg_score_dict)
+
+         # Finalize the evaluation process
+         self.finalize()
+         logger.info(f'Benchmark {self.benchmark_name} evaluation finished.')
          return report

      def evaluate_subset(self, subset: str, dataset: Dataset) -> List[AggScore]:
@@ -121,12 +132,15 @@ class DefaultEvaluator(Evaluator):
              List[AggScore]: Aggregated scores for this subset.
          """
          # Get model predictions for all samples in the subset
+         logger.info(f'Getting predictions for subset: {subset}')
          task_states = self.get_answers(subset, dataset)

          # Calculate evaluation metrics for each prediction
+         logger.info(f'Getting reviews for subset: {subset}')
          sample_scores = self.get_reviews(subset, task_states)

          # Aggregate individual sample scores into subset-level metrics
+         logger.info(f'Aggregating scores for subset: {subset}')
          agg_scores = self.benchmark.aggregate_scores(sample_scores=sample_scores)
          return agg_scores

@@ -148,51 +162,48 @@ class DefaultEvaluator(Evaluator):
          """
          # Initialize task state list and filter cached predictions if caching is enabled
          if self.use_cache:
-             task_state_list, dataset = self.cache_manager.filter_prediction_cache(subset, dataset)
+             cached_task_state_list, dataset = self.cache_manager.filter_prediction_cache(subset, dataset)
          else:
-             task_state_list = []
+             cached_task_state_list = []

          # Get output directory for storing model predictions
          model_prediction_dir = os.path.dirname(self.cache_manager.get_prediction_cache_path(subset))

          # Convert dataset to list for parallel processing
          dataset_list = list(dataset)
-
          if not dataset_list:
-             return task_state_list
-
-         # Process samples in parallel using ThreadPoolExecutor
-         with ThreadPoolExecutor(max_workers=min(len(dataset_list), self.task_config.eval_batch_size)) as executor:
-             # Submit all prediction tasks
-             future_to_sample = {
-                 executor.submit(self._predict_sample, sample, model_prediction_dir): sample
-                 for sample in dataset_list
-             }
-
-             # Process completed tasks with progress bar
-             with tqdm(total=len(dataset_list), desc=f'Predicting[{self.benchmark_name}@{subset}]: ') as pbar:
-                 for future in as_completed(future_to_sample):
-                     sample = future_to_sample[future]
-                     try:
-                         task_state = future.result()
-                         task_state_list.append(task_state)
-
-                         # Save the prediction result to cache for future use
-                         model_result = self.cache_manager.save_prediction_cache(
-                             subset, task_state, self.benchmark.save_metadata
-                         )
-                         logger.debug(f'Model result: \n{model_result.model_dump_json(indent=2)}')
-
-                     except Exception as exc:
-                         logger.error(f'{sample.model_dump_json(indent=2)} prediction failed: due to {exc}')
-                         if self.task_config.ignore_errors:
-                             logger.warning('Error ignored, continuing with next sample.')
-                         else:
-                             raise exc
-                     finally:
-                         pbar.update(1)
-
-         return task_state_list
+             return cached_task_state_list
+
+         logger.info(f'Processing {len(dataset_list)} samples, if data is large, it may take a while.')
+
+         def worker(sample: Sample) -> TaskState:
+             return self._predict_sample(sample, model_prediction_dir)
+
+         def on_result(sample: Sample, task_state: TaskState) -> None:
+             model_result = self.cache_manager.save_prediction_cache(subset, task_state, self.benchmark.save_metadata)
+             logger.debug(f'Model result: \n{model_result.pretty_print()}')
+
+         def on_error(sample: Sample, exc: Exception) -> None:
+             tb_str = traceback.format_exc()
+             logger.error(f'{sample.model_dump_json(indent=2)} prediction failed: due to {exc}\nTraceback:\n{tb_str}')
+             if self.task_config.ignore_errors:
+                 logger.warning('Error ignored, continuing with next sample.')
+                 return
+             raise exc
+
+         finished_task_states = run_in_threads_with_progress(
+             dataset_list,
+             worker,
+             desc=f'Predicting[{self.benchmark_name}@{subset}]: ',
+             max_workers=self.task_config.eval_batch_size,
+             heartbeat_sec=HEARTBEAT_INTERVAL_SEC,
+             on_result=on_result,
+             on_error=on_error,
+             filter_none_results=True,
+         )
+
+         logger.info(f'Finished getting predictions for subset: {subset}.')
+         return cached_task_state_list + finished_task_states

      def _predict_sample(self, sample: Sample, model_prediction_dir: str) -> TaskState:
          """
@@ -229,50 +240,58 @@ class DefaultEvaluator(Evaluator):
          """
          # Initialize sample score list and filter cached reviews if caching is enabled
          if self.use_cache and not self.task_config.rerun_review:
-             sample_score_list, task_states = self.cache_manager.filter_review_cache(subset, task_states)
+             cached_score_list, task_states = self.cache_manager.filter_review_cache(subset, task_states)
          else:
              # Init a clean sample score list
-             sample_score_list = []
+             cached_score_list = []
              self.cache_manager.delete_review_cache(subset)

          if not task_states:
-             return sample_score_list
-
-         # Process task states in parallel using ThreadPoolExecutor
-         with ThreadPoolExecutor(max_workers=min(len(task_states), self.task_config.judge_worker_num)) as executor:
-             # Submit all review tasks
-             future_to_task_state = {
-                 executor.submit(self._review_task_state, task_state): task_state
-                 for task_state in task_states
-             }
-
-             # Process completed tasks with progress bar
-             with tqdm(total=len(task_states), desc=f'Reviewing[{self.benchmark_name}@{subset}]: ') as pbar:
-                 for future in as_completed(future_to_task_state):
-                     task_state = future_to_task_state[future]
-                     try:
-                         sample_score = future.result()
-                         sample_score_list.append(sample_score)
-
-                         # Save the review result to cache for future use
-                         review_result = self.cache_manager.save_review_cache(
-                             subset=subset,
-                             task_state=task_state,
-                             sample_score=sample_score,
-                             save_metadata=self.benchmark.save_metadata
-                         )
-                         logger.debug(f'Review result: \n{review_result.model_dump_json(indent=2)}')
-
-                     except Exception as exc:
-                         logger.error(f'Error when review sample {task_state.sample_id}: {exc}')
-                         if self.task_config.ignore_errors:
-                             logger.warning('Error ignored, continuing with next sample.')
-                         else:
-                             raise exc
-                     finally:
-                         pbar.update(1)
-
-         return sample_score_list
+             return cached_score_list
+
+         logger.info(f'Reviewing {len(task_states)} samples, if data is large, it may take a while.')
+
+         def worker(task_state: TaskState) -> SampleScore:
+             return self._review_task_state(task_state)
+
+         def on_result(task_state: TaskState, sample_score: SampleScore) -> None:
+             review_result = self.cache_manager.save_review_cache(
+                 subset=subset,
+                 task_state=task_state,
+                 sample_score=sample_score,
+                 save_metadata=self.benchmark.save_metadata
+             )
+             logger.debug(f'Review result: \n{review_result.pretty_print()}')
+
+         def on_error(task_state: TaskState, exc: Exception) -> None:
+             tb_str = traceback.format_exc()
+             logger.error(f'Error when review sample {task_state.sample_id}: due to {exc}\nTraceback:\n{tb_str}')
+             if self.task_config.ignore_errors:
+                 logger.warning('Error ignored, continuing with next sample.')
+                 return
+             raise exc
+
+         # Run reviews in parallel
+         reviewed_scores = run_in_threads_with_progress(
+             task_states,
+             worker,
+             desc=f'Reviewing[{self.benchmark_name}@{subset}]: ',
+             max_workers=self.task_config.judge_worker_num,
+             heartbeat_sec=HEARTBEAT_INTERVAL_SEC,
+             on_error=on_error,
+             # Do not persist interim results when batch scoring is enabled
+             on_result=None if self.benchmark.use_batch_scoring else on_result,
+             filter_none_results=False,
+         )
+
+         # Batch calculate metrics if supported by the benchmark
+         if self.benchmark.use_batch_scoring:
+             reviewed_scores = self._batch_review_task_states(
+                 task_states=task_states, reviewed_scores=reviewed_scores, on_result=on_result
+             )
+
+         logger.info(f'Finished reviewing subset: {subset}. Total reviewed: {len(reviewed_scores)}')
+         return cached_score_list + reviewed_scores

      def _review_task_state(self, task_state: TaskState) -> SampleScore:
          """
@@ -288,6 +307,40 @@ class DefaultEvaluator(Evaluator):
          sample_score = self.benchmark.calculate_metrics(task_state=task_state)
          return sample_score

+     def _batch_review_task_states(
+         self, task_states: List[TaskState], reviewed_scores: List[SampleScore],
+         on_result: Callable[[TaskState, SampleScore], None]
+     ) -> List[SampleScore]:
+         valid_indices = [i for i, score in enumerate(reviewed_scores) if score is not None]
+         if not valid_indices:
+             return reviewed_scores
+
+         task_states = [task_states[i] for i in valid_indices]
+         reviewed_scores = [reviewed_scores[i] for i in valid_indices]
+
+         # Iterate in batches with progress bar
+         all_reviewed_scores = []
+         total = len(task_states)
+         batch_size = self.task_config.judge_worker_num
+         with tqdm(total=total, desc='Scoring (batch)', unit='sample') as pbar:
+             for start in range(0, total, batch_size):
+                 # Process batch
+                 end = min(start + batch_size, total)
+                 batch_task_states = task_states[start:end]
+                 batch_scores = reviewed_scores[start:end]
+                 # Batch calculate metrics
+                 updated_reviewed_scores = self.benchmark.batch_calculate_metrics(
+                     task_states=batch_task_states, sample_scores=batch_scores
+                 )
+                 # Append results
+                 all_reviewed_scores.extend(updated_reviewed_scores)
+                 # Save each result to cache
+                 for task_state, sample_score in zip(batch_task_states, updated_reviewed_scores):
+                     on_result(task_state, sample_score)
+
+                 pbar.update(len(batch_task_states))
+         return all_reviewed_scores
+
      def get_report(self, agg_score_dict: Dict[str, List[AggScore]]) -> Report:
          """
          Generate a comprehensive evaluation report from aggregated scores.
@@ -317,7 +370,7 @@ class DefaultEvaluator(Evaluator):

          # Generate and display a summary table of results
          try:
-             report_table = gen_table(report_list=[report], add_overall_metric=True)
+             report_table = gen_table(report_list=[report], add_overall_metric=self.benchmark.add_overall_metric)
              logger.info(f'\n{self.benchmark_name} report table:'
                          f'\n{report_table} \n')
          except Exception:
@@ -335,3 +388,6 @@
          report.to_json(report_file)
          logger.info(f'Dump report to: {report_file} \n')
          return report
+
+     def finalize(self, *args, **kwargs):
+         self.benchmark.finalize(*args, **kwargs)
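Note on the refactor above: get_answers and get_reviews now delegate their thread-pool plumbing to run_in_threads_with_progress, imported from evalscope.utils.function_utils (that module gains +249 lines in this release, but its body is not part of this excerpt). The sketch below is only an illustration of what such a helper could look like, inferred from the call sites above; everything beyond the keyword arguments visible there (desc, max_workers, heartbeat_sec, on_result, on_error, filter_none_results) is an assumption, not the packaged implementation.

# Hypothetical sketch of the helper imported above; the real implementation in
# evalscope/utils/function_utils.py is not shown in this diff and may differ.
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, TypeVar

from tqdm import tqdm

T = TypeVar('T')
R = TypeVar('R')

logger = logging.getLogger(__name__)


def run_in_threads_with_progress(
    items: List[T],
    worker: Callable[[T], R],
    desc: str = '',
    max_workers: int = 8,
    heartbeat_sec: float = 60.0,
    on_result: Optional[Callable[[T, R], None]] = None,
    on_error: Optional[Callable[[T, Exception], None]] = None,
    filter_none_results: bool = False,
) -> List[R]:
    """Run `worker` over `items` in a thread pool, with a tqdm bar and a heartbeat log."""
    results: List[R] = []
    last_progress = time.monotonic()
    with ThreadPoolExecutor(max_workers=min(max_workers, max(len(items), 1))) as executor:
        pending = {executor.submit(worker, item): item for item in items}
        with tqdm(total=len(items), desc=desc) as pbar:
            while pending:
                done = {f for f in pending if f.done()}
                if not done:
                    # Emit a heartbeat so long-running batches do not look stuck.
                    if time.monotonic() - last_progress > heartbeat_sec:
                        logger.info(f'{desc}{len(pending)} tasks still running...')
                        last_progress = time.monotonic()
                    time.sleep(1)
                    continue
                for future in done:
                    item = pending.pop(future)
                    try:
                        result = future.result()
                        if on_result is not None:
                            on_result(item, result)
                        results.append(result)
                    except Exception as exc:
                        if on_error is not None:
                            on_error(item, exc)  # the callback may re-raise to abort the run
                        results.append(None)  # placeholder for a failed item
                    finally:
                        pbar.update(1)
                        last_progress = time.monotonic()
    if filter_none_results:
        results = [r for r in results if r is not None]
    return results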
evalscope/metrics/bert_score/__init__.py
File without changes
evalscope/metrics/bert_score/scorer.py (new file)
@@ -0,0 +1,338 @@
+ # flake8: noqa
+ import numpy as np
+ import os
+ import pandas as pd
+ import time
+ import torch
+ import warnings
+ from collections import defaultdict
+
+ from .utils import (
+     bert_cos_score_idf,
+     get_bert_embedding,
+     get_hash,
+     get_idf_dict,
+     get_model,
+     get_tokenizer,
+     lang2model,
+     model2layers,
+     sent_encode,
+ )
+
+
+ class BERTScorer:
+     """
+     BERTScore Scorer Object.
+     """
+
+     def __init__(
+         self,
+         model_id_or_path=None,
+         model_type=None,
+         num_layers=None,
+         batch_size=64,
+         nthreads=4,
+         all_layers=False,
+         idf=False,
+         idf_sents=None,
+         device=None,
+         lang=None,
+         rescale_with_baseline=False,
+         baseline_path=None,
+         use_fast_tokenizer=False,
+     ):
+         """
+         Args:
+             - :param: `model_type` (str): contexual embedding model specification, default using the suggested
+                       model for the target langauge; has to specify at least one of
+                       `model_type` or `lang`
+             - :param: `num_layers` (int): the layer of representation to use.
+                       default using the number of layer tuned on WMT16 correlation data
+             - :param: `verbose` (bool): turn on intermediate status update
+             - :param: `idf` (bool): a booling to specify whether to use idf or not (this should be True even if `idf_sents` is given)
+             - :param: `idf_sents` (List of str): list of sentences used to compute the idf weights
+             - :param: `device` (str): on which the contextual embedding model will be allocated on.
+                       If this argument is None, the model lives on cuda:0 if cuda is available.
+             - :param: `batch_size` (int): bert score processing batch size
+             - :param: `nthreads` (int): number of threads
+             - :param: `lang` (str): language of the sentences; has to specify
+                       at least one of `model_type` or `lang`. `lang` needs to be
+                       specified when `rescale_with_baseline` is True.
+             - :param: `return_hash` (bool): return hash code of the setting
+             - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
+             - :param: `baseline_path` (str): customized baseline file
+             - :param: `use_fast_tokenizer` (bool): `use_fast` parameter passed to HF tokenizer
+         """
+
+         assert (lang is not None or model_type is not None), 'Either lang or model_type should be specified'
+
+         if rescale_with_baseline:
+             assert (lang is not None), 'Need to specify Language when rescaling with baseline'
+
+         if device is None:
+             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         else:
+             self.device = device
+
+         self._lang = lang
+         self._rescale_with_baseline = rescale_with_baseline
+         self._idf = idf
+         self.batch_size = batch_size
+         self.nthreads = nthreads
+         self.all_layers = all_layers
+         self.model_id_or_path = model_id_or_path
+
+         if model_type is None:
+             lang = lang.lower()
+             self._model_type = lang2model[lang]
+         else:
+             self._model_type = model_type
+
+         if num_layers is None:
+             self._num_layers = model2layers[self.model_type]
+         else:
+             self._num_layers = num_layers
+
+         # Building model and tokenizer
+         self._use_fast_tokenizer = use_fast_tokenizer
+         self._tokenizer = get_tokenizer(self.model_id_or_path, self._use_fast_tokenizer)
+         self._model = get_model(self.model_id_or_path, self.num_layers, self.all_layers)
+         self._model.to(self.device)
+
+         self._idf_dict = None
+         if idf_sents is not None:
+             self.compute_idf(idf_sents)
+
+         self._baseline_vals = None
+         self.baseline_path = baseline_path
+         self.use_custom_baseline = self.baseline_path is not None
+         if self.baseline_path is None:
+             self.baseline_path = os.path.join(
+                 os.path.dirname(__file__),
+                 f'rescale_baseline/{self.lang}/{self.model_type}.tsv',
+             )
+
+     @property
+     def lang(self):
+         return self._lang
+
+     @property
+     def idf(self):
+         return self._idf
+
+     @property
+     def model_type(self):
+         return self._model_type
+
+     @property
+     def num_layers(self):
+         return self._num_layers
+
+     @property
+     def rescale_with_baseline(self):
+         return self._rescale_with_baseline
+
+     @property
+     def baseline_vals(self):
+         if self._baseline_vals is None:
+             if os.path.isfile(self.baseline_path):
+                 if not self.all_layers:
+                     self._baseline_vals = torch.from_numpy(
+                         pd.read_csv(self.baseline_path).iloc[self.num_layers].to_numpy()
+                     )[1:].float()
+                 else:
+                     self._baseline_vals = (
+                         torch.from_numpy(pd.read_csv(self.baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
+                     )
+             else:
+                 raise ValueError(f'Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}')
+
+         return self._baseline_vals
+
+     @property
+     def use_fast_tokenizer(self):
+         return self._use_fast_tokenizer
+
+     @property
+     def hash(self):
+         return get_hash(
+             self.model_type,
+             self.num_layers,
+             self.idf,
+             self.rescale_with_baseline,
+             self.use_custom_baseline,
+             self.use_fast_tokenizer,
+         )
+
+     def compute_idf(self, sents):
+         """
+         Args:
+
+         """
+         if self._idf_dict is not None:
+             warnings.warn('Overwriting the previous importance weights.')
+
+         self._idf_dict = get_idf_dict(sents, self._tokenizer, nthreads=self.nthreads)
+
+     def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
+         """
+         Args:
+             - :param: `cands` (list of str): candidate sentences
+             - :param: `refs` (list of str or list of list of str): reference sentences
+
+         Return:
+             - :param: `(P, R, F)`: each is of shape (N); N = number of input
+                       candidate reference pairs. if returning hashcode, the
+                       output will be ((P, R, F), hashcode). If a candidate have
+                       multiple references, the returned score of this candidate is
+                       the *best* score among all references.
+         """
+
+         ref_group_boundaries = None
+         if not isinstance(refs[0], str):
+             ref_group_boundaries = []
+             ori_cands, ori_refs = cands, refs
+             cands, refs = [], []
+             count = 0
+             for cand, ref_group in zip(ori_cands, ori_refs):
+                 cands += [cand] * len(ref_group)
+                 refs += ref_group
+                 ref_group_boundaries.append((count, count + len(ref_group)))
+                 count += len(ref_group)
+
+         if verbose:
+             print('calculating scores...')
+             start = time.perf_counter()
+
+         if self.idf:
+             assert self._idf_dict, 'IDF weights are not computed'
+             idf_dict = self._idf_dict
+         else:
+             idf_dict = defaultdict(lambda: 1.0)
+             idf_dict[self._tokenizer.sep_token_id] = 0
+             idf_dict[self._tokenizer.cls_token_id] = 0
+
+         all_preds = bert_cos_score_idf(
+             self._model,
+             refs,
+             cands,
+             self._tokenizer,
+             idf_dict,
+             verbose=verbose,
+             device=self.device,
+             batch_size=batch_size,
+             all_layers=self.all_layers,
+         ).cpu()
+
+         if ref_group_boundaries is not None:
+             max_preds = []
+             for start, end in ref_group_boundaries:
+                 max_preds.append(all_preds[start:end].max(dim=0)[0])
+             all_preds = torch.stack(max_preds, dim=0)
+
+         if self.rescale_with_baseline:
+             all_preds = (all_preds - self.baseline_vals) / (1 - self.baseline_vals)
+
+         out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2]  # P, R, F
+
+         if verbose:
+             time_diff = time.perf_counter() - start
+             print(f'done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec')
+
+         if return_hash:
+             out = tuple([out, self.hash])
+
+         return out
+
+     def plot_example(self, candidate, reference, fname=''):
+         """
+         Args:
+             - :param: `candidate` (str): a candidate sentence
+             - :param: `reference` (str): a reference sentence
+             - :param: `fname` (str): path to save the output plot
+         """
+         import matplotlib.pyplot as plt
+         from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+         assert isinstance(candidate, str)
+         assert isinstance(reference, str)
+
+         idf_dict = defaultdict(lambda: 1.0)
+         idf_dict[self._tokenizer.sep_token_id] = 0
+         idf_dict[self._tokenizer.cls_token_id] = 0
+
+         hyp_embedding, masks, padded_idf = get_bert_embedding(
+             [candidate],
+             self._model,
+             self._tokenizer,
+             idf_dict,
+             device=self.device,
+             all_layers=False,
+         )
+         ref_embedding, masks, padded_idf = get_bert_embedding(
+             [reference],
+             self._model,
+             self._tokenizer,
+             idf_dict,
+             device=self.device,
+             all_layers=False,
+         )
+         ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
+         hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
+         sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
+         sim = sim.squeeze(0).cpu()
+
+         r_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, reference)][1:-1]
+         h_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, candidate)][1:-1]
+         sim = sim[1:-1, 1:-1]
+
+         if self.rescale_with_baseline:
+             sim = (sim - self.baseline_vals[2].item()) / (1 - self.baseline_vals[2].item())
+
+         fig, ax = plt.subplots(figsize=(len(r_tokens), len(h_tokens)))
+         im = ax.imshow(sim, cmap='Blues', vmin=0, vmax=1)
+
+         # We want to show all ticks...
+         ax.set_xticks(np.arange(len(r_tokens)))
+         ax.set_yticks(np.arange(len(h_tokens)))
+         # ... and label them with the respective list entries
+         ax.set_xticklabels(r_tokens, fontsize=10)
+         ax.set_yticklabels(h_tokens, fontsize=10)
+         ax.grid(False)
+         plt.xlabel('Reference (tokenized)', fontsize=14)
+         plt.ylabel('Candidate (tokenized)', fontsize=14)
+         title = 'Similarity Matrix'
+         if self.rescale_with_baseline:
+             title += ' (after Rescaling)'
+         plt.title(title, fontsize=14)
+
+         divider = make_axes_locatable(ax)
+         cax = divider.append_axes('right', size='2%', pad=0.2)
+         fig.colorbar(im, cax=cax)
+
+         # Rotate the tick labels and set their alignment.
+         plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
+
+         # Loop over data dimensions and create text annotations.
+         for i in range(len(h_tokens)):
+             for j in range(len(r_tokens)):
+                 text = ax.text(
+                     j,
+                     i,
+                     '{:.3f}'.format(sim[i, j].item()),
+                     ha='center',
+                     va='center',
+                     color='k' if sim[i, j].item() < 0.5 else 'w',
+                 )
+
+         fig.tight_layout()
+         if fname != '':
+             plt.savefig(fname, dpi=100)
+             print('Saved figure to file: ', fname)
+         plt.show()
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(hash={self.hash}, batch_size={self.batch_size}, nthreads={self.nthreads})'
+
+     def __str__(self):
+         return self.__repr__()
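For orientation, the BERTScorer added above (a vendored adaptation of the bert_score package that loads weights from model_id_or_path) would be driven roughly as follows. This is a usage sketch based only on the constructor and score() signature shown in this file; the model identifier is a placeholder, not something this release prescribes.

from evalscope.metrics.bert_score.scorer import BERTScorer

# 'bert-base-uncased' is a placeholder model path/ID; any encoder accepted by the
# package's get_model/get_tokenizer helpers (not shown in this diff) should work.
scorer = BERTScorer(model_id_or_path='bert-base-uncased', lang='en')

cands = ['The cat sat on the mat.']
refs = [['A cat was sitting on the mat.', 'The cat is on the mat.']]  # per candidate, the best reference score is kept

P, R, F = scorer.score(cands, refs, verbose=True)
print(f'BERTScore F1: {F.mean().item():.4f}')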