evalscope 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (324)
  1. evalscope/api/benchmark/__init__.py +9 -1
  2. evalscope/api/benchmark/adapters/__init__.py +4 -0
  3. evalscope/api/benchmark/adapters/agent_adapter.py +8 -0
  4. evalscope/api/benchmark/adapters/default_data_adapter.py +75 -4
  5. evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
  6. evalscope/api/benchmark/adapters/multi_choice_adapter.py +5 -2
  7. evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
  8. evalscope/api/benchmark/adapters/text2image_adapter.py +12 -10
  9. evalscope/api/benchmark/adapters/vision_language_adapter.py +8 -0
  10. evalscope/api/benchmark/benchmark.py +85 -2
  11. evalscope/api/benchmark/meta.py +10 -1
  12. evalscope/api/dataset/dataset.py +27 -6
  13. evalscope/api/dataset/loader.py +8 -3
  14. evalscope/api/evaluator/cache.py +31 -4
  15. evalscope/api/evaluator/evaluator.py +5 -0
  16. evalscope/api/evaluator/state.py +17 -1
  17. evalscope/api/messages/__init__.py +1 -0
  18. evalscope/api/messages/chat_message.py +52 -2
  19. evalscope/api/metric/__init__.py +1 -1
  20. evalscope/api/metric/metric.py +6 -1
  21. evalscope/api/metric/scorer.py +15 -7
  22. evalscope/api/mixin/__init__.py +1 -1
  23. evalscope/api/mixin/llm_judge_mixin.py +2 -0
  24. evalscope/api/mixin/sandbox_mixin.py +182 -0
  25. evalscope/api/model/generate_config.py +10 -6
  26. evalscope/api/model/model.py +5 -2
  27. evalscope/api/tool/tool_info.py +1 -1
  28. evalscope/app/app.py +3 -0
  29. evalscope/app/ui/multi_model.py +6 -1
  30. evalscope/app/ui/single_model.py +11 -5
  31. evalscope/app/utils/data_utils.py +8 -7
  32. evalscope/app/utils/env_utils.py +12 -0
  33. evalscope/app/utils/text_utils.py +14 -12
  34. evalscope/app/utils/visualization.py +2 -2
  35. evalscope/arguments.py +8 -4
  36. evalscope/backend/opencompass/backend_manager.py +0 -2
  37. evalscope/backend/rag_eval/utils/embedding.py +9 -1
  38. evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
  39. evalscope/benchmarks/ai2d/ai2d_adapter.py +54 -0
  40. evalscope/benchmarks/aime/aime24_adapter.py +5 -0
  41. evalscope/benchmarks/aime/aime25_adapter.py +136 -1
  42. evalscope/benchmarks/aime/grader.py +307 -0
  43. evalscope/benchmarks/aime/math_normalize.py +189 -0
  44. evalscope/benchmarks/amc/amc_adapter.py +51 -0
  45. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -0
  46. evalscope/benchmarks/bbh/bbh_adapter.py +43 -17
  47. evalscope/benchmarks/bfcl/{bfcl_adapter.py → v3/bfcl_v3_adapter.py} +131 -19
  48. evalscope/benchmarks/bfcl/{generation.py → v3/generation.py} +9 -9
  49. evalscope/benchmarks/bfcl/v3/utils.py +23 -0
  50. evalscope/benchmarks/bfcl/v4/__init__.py +0 -0
  51. evalscope/benchmarks/bfcl/v4/bfcl_v4_adapter.py +229 -0
  52. evalscope/benchmarks/bfcl/v4/utils.py +410 -0
  53. evalscope/benchmarks/biomix_qa/__init__.py +0 -0
  54. evalscope/benchmarks/biomix_qa/biomix_qa_adapter.py +36 -0
  55. evalscope/benchmarks/blink/__init__.py +0 -0
  56. evalscope/benchmarks/blink/blink_adapter.py +61 -0
  57. evalscope/benchmarks/ceval/ceval_adapter.py +1 -2
  58. evalscope/benchmarks/chartqa/__init__.py +0 -0
  59. evalscope/benchmarks/chartqa/chartqa_adapter.py +80 -0
  60. evalscope/benchmarks/chartqa/utils.py +38 -0
  61. evalscope/benchmarks/coin_flip/__init__.py +0 -0
  62. evalscope/benchmarks/coin_flip/coin_flip_adapter.py +128 -0
  63. evalscope/benchmarks/commonsense_qa/__init__.py +0 -0
  64. evalscope/benchmarks/commonsense_qa/commonsense_qa_adapter.py +32 -0
  65. evalscope/benchmarks/competition_math/competition_math_adapter.py +5 -0
  66. evalscope/benchmarks/data_collection/data_collection_adapter.py +24 -19
  67. evalscope/benchmarks/docvqa/__init__.py +0 -0
  68. evalscope/benchmarks/docvqa/docvqa_adapter.py +67 -0
  69. evalscope/benchmarks/drivelology/__init__.py +0 -0
  70. evalscope/benchmarks/drivelology/drivelology_binary_adapter.py +170 -0
  71. evalscope/benchmarks/drivelology/drivelology_multilabel_adapter.py +254 -0
  72. evalscope/benchmarks/drivelology/drivelology_selection_adapter.py +49 -0
  73. evalscope/benchmarks/drivelology/drivelology_writing_adapter.py +218 -0
  74. evalscope/benchmarks/drop/drop_adapter.py +15 -44
  75. evalscope/benchmarks/drop/utils.py +97 -0
  76. evalscope/benchmarks/frames/frames_adapter.py +2 -1
  77. evalscope/benchmarks/general_arena/general_arena_adapter.py +7 -2
  78. evalscope/benchmarks/general_arena/utils.py +2 -1
  79. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +1 -1
  80. evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
  81. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +25 -9
  82. evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
  83. evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +159 -0
  84. evalscope/benchmarks/halu_eval/__init__.py +0 -0
  85. evalscope/benchmarks/halu_eval/halu_eval_adapter.py +128 -0
  86. evalscope/benchmarks/halu_eval/halu_eval_instructions.py +84 -0
  87. evalscope/benchmarks/healthbench/__init__.py +0 -0
  88. evalscope/benchmarks/healthbench/healthbench_adapter.py +282 -0
  89. evalscope/benchmarks/healthbench/utils.py +102 -0
  90. evalscope/benchmarks/hle/hle_adapter.py +3 -2
  91. evalscope/benchmarks/humaneval/humaneval_adapter.py +24 -52
  92. evalscope/benchmarks/humaneval/utils.py +235 -0
  93. evalscope/benchmarks/ifeval/instructions_util.py +2 -3
  94. evalscope/benchmarks/image_edit/__init__.py +0 -0
  95. evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
  96. evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
  97. evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
  98. evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
  99. evalscope/benchmarks/infovqa/__init__.py +0 -0
  100. evalscope/benchmarks/infovqa/infovqa_adapter.py +66 -0
  101. evalscope/benchmarks/live_code_bench/evaluate_utils.py +13 -6
  102. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +66 -54
  103. evalscope/benchmarks/live_code_bench/sandbox_evaluate_utils.py +220 -0
  104. evalscope/benchmarks/logi_qa/__int__.py +0 -0
  105. evalscope/benchmarks/logi_qa/logi_qa_adapter.py +41 -0
  106. evalscope/benchmarks/math_500/math_500_adapter.py +5 -1
  107. evalscope/benchmarks/math_qa/__init__.py +0 -0
  108. evalscope/benchmarks/math_qa/math_qa_adapter.py +35 -0
  109. evalscope/benchmarks/math_verse/__init__.py +0 -0
  110. evalscope/benchmarks/math_verse/math_verse_adapter.py +105 -0
  111. evalscope/benchmarks/math_vision/__init__.py +0 -0
  112. evalscope/benchmarks/math_vision/math_vision_adapter.py +116 -0
  113. evalscope/benchmarks/math_vista/__init__.py +0 -0
  114. evalscope/benchmarks/math_vista/math_vista_adapter.py +114 -0
  115. evalscope/benchmarks/med_mcqa/__init__.py +0 -0
  116. evalscope/benchmarks/med_mcqa/med_mcqa_adapter.py +32 -0
  117. evalscope/benchmarks/minerva_math/__init__.py +0 -0
  118. evalscope/benchmarks/minerva_math/minerva_math_adapter.py +53 -0
  119. evalscope/benchmarks/mm_bench/__init__.py +0 -0
  120. evalscope/benchmarks/mm_bench/mm_bench_adapter.py +99 -0
  121. evalscope/benchmarks/mm_star/__init__.py +0 -0
  122. evalscope/benchmarks/mm_star/mm_star_adapter.py +73 -0
  123. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +1 -1
  124. evalscope/benchmarks/mmmu/__init__.py +0 -0
  125. evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
  126. evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
  127. evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +124 -0
  128. evalscope/benchmarks/mri_mcqa/__init__.py +0 -0
  129. evalscope/benchmarks/mri_mcqa/mri_mcqa_adapter.py +34 -0
  130. evalscope/benchmarks/multi_if/__init__.py +0 -0
  131. evalscope/benchmarks/multi_if/ifeval.py +3354 -0
  132. evalscope/benchmarks/multi_if/metrics.py +120 -0
  133. evalscope/benchmarks/multi_if/multi_if_adapter.py +161 -0
  134. evalscope/benchmarks/music_trivia/__init__.py +0 -0
  135. evalscope/benchmarks/music_trivia/music_trivia_adapter.py +36 -0
  136. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +7 -6
  137. evalscope/benchmarks/ner/__init__.py +0 -0
  138. evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
  139. evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
  140. evalscope/benchmarks/ner/copious_adapter.py +85 -0
  141. evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
  142. evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
  143. evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
  144. evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
  145. evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
  146. evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
  147. evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
  148. evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
  149. evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
  150. evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
  151. evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
  152. evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
  153. evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
  154. evalscope/benchmarks/ocr_bench/__init__.py +0 -0
  155. evalscope/benchmarks/ocr_bench/ocr_bench/__init__.py +0 -0
  156. evalscope/benchmarks/ocr_bench/ocr_bench/ocr_bench_adapter.py +101 -0
  157. evalscope/benchmarks/ocr_bench/ocr_bench_v2/IoUscore_metric.py +87 -0
  158. evalscope/benchmarks/ocr_bench/ocr_bench_v2/TEDS_metric.py +963 -0
  159. evalscope/benchmarks/ocr_bench/ocr_bench_v2/__init__.py +0 -0
  160. evalscope/benchmarks/ocr_bench/ocr_bench_v2/ocr_bench_v2_adapter.py +161 -0
  161. evalscope/benchmarks/ocr_bench/ocr_bench_v2/page_ocr_metric.py +50 -0
  162. evalscope/benchmarks/ocr_bench/ocr_bench_v2/parallel.py +46 -0
  163. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/__init__.py +0 -0
  164. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/readme.txt +26 -0
  165. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/rrc_evaluation_funcs_1_1.py +537 -0
  166. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_eval/script.py +481 -0
  167. evalscope/benchmarks/ocr_bench/ocr_bench_v2/spotting_metric.py +179 -0
  168. evalscope/benchmarks/ocr_bench/ocr_bench_v2/utils.py +433 -0
  169. evalscope/benchmarks/ocr_bench/ocr_bench_v2/vqa_metric.py +254 -0
  170. evalscope/benchmarks/olympiad_bench/__init__.py +0 -0
  171. evalscope/benchmarks/olympiad_bench/olympiad_bench_adapter.py +163 -0
  172. evalscope/benchmarks/olympiad_bench/utils.py +565 -0
  173. evalscope/benchmarks/omni_bench/__init__.py +0 -0
  174. evalscope/benchmarks/omni_bench/omni_bench_adapter.py +86 -0
  175. evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
  176. evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
  177. evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
  178. evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
  179. evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
  180. evalscope/benchmarks/piqa/__init__.py +0 -0
  181. evalscope/benchmarks/piqa/piqa_adapter.py +32 -0
  182. evalscope/benchmarks/poly_math/__init__.py +0 -0
  183. evalscope/benchmarks/poly_math/poly_math_adapter.py +132 -0
  184. evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
  185. evalscope/benchmarks/pope/__init__.py +0 -0
  186. evalscope/benchmarks/pope/pope_adapter.py +112 -0
  187. evalscope/benchmarks/process_bench/process_bench_adapter.py +1 -0
  188. evalscope/benchmarks/pumed_qa/__init__.py +0 -0
  189. evalscope/benchmarks/pumed_qa/pubmed_qa_adapter.py +175 -0
  190. evalscope/benchmarks/qasc/__init__.py +0 -0
  191. evalscope/benchmarks/qasc/qasc_adapter.py +35 -0
  192. evalscope/benchmarks/real_world_qa/__init__.py +0 -0
  193. evalscope/benchmarks/real_world_qa/real_world_qa_adapter.py +64 -0
  194. evalscope/benchmarks/sciq/__init__.py +0 -0
  195. evalscope/benchmarks/sciq/sciq_adapter.py +36 -0
  196. evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
  197. evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
  198. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -1
  199. evalscope/benchmarks/simple_vqa/__init__.py +0 -0
  200. evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
  201. evalscope/benchmarks/siqa/__init__.py +0 -0
  202. evalscope/benchmarks/siqa/siqa_adapter.py +39 -0
  203. evalscope/benchmarks/tau_bench/tau2_bench/__init__.py +0 -0
  204. evalscope/benchmarks/tau_bench/tau2_bench/generation.py +158 -0
  205. evalscope/benchmarks/tau_bench/tau2_bench/tau2_bench_adapter.py +146 -0
  206. evalscope/benchmarks/tau_bench/tau_bench/__init__.py +0 -0
  207. evalscope/benchmarks/tau_bench/{generation.py → tau_bench/generation.py} +1 -1
  208. evalscope/benchmarks/tau_bench/{tau_bench_adapter.py → tau_bench/tau_bench_adapter.py} +29 -29
  209. evalscope/benchmarks/text2image/__init__.py +0 -0
  210. evalscope/benchmarks/{aigc/t2i → text2image}/evalmuse_adapter.py +3 -1
  211. evalscope/benchmarks/{aigc/t2i → text2image}/genai_bench_adapter.py +2 -2
  212. evalscope/benchmarks/{aigc/t2i → text2image}/general_t2i_adapter.py +1 -1
  213. evalscope/benchmarks/{aigc/t2i → text2image}/hpdv2_adapter.py +7 -2
  214. evalscope/benchmarks/{aigc/t2i → text2image}/tifa_adapter.py +1 -0
  215. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +3 -3
  216. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +1 -2
  217. evalscope/benchmarks/visu_logic/__init__.py +0 -0
  218. evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
  219. evalscope/benchmarks/wmt/__init__.py +0 -0
  220. evalscope/benchmarks/wmt/wmt24_adapter.py +294 -0
  221. evalscope/benchmarks/zerobench/__init__.py +0 -0
  222. evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
  223. evalscope/cli/start_app.py +7 -1
  224. evalscope/cli/start_perf.py +7 -1
  225. evalscope/config.py +103 -18
  226. evalscope/constants.py +18 -0
  227. evalscope/evaluator/evaluator.py +138 -82
  228. evalscope/metrics/bert_score/__init__.py +0 -0
  229. evalscope/metrics/bert_score/scorer.py +338 -0
  230. evalscope/metrics/bert_score/utils.py +697 -0
  231. evalscope/metrics/llm_judge.py +19 -7
  232. evalscope/metrics/math_parser.py +14 -0
  233. evalscope/metrics/metric.py +317 -13
  234. evalscope/metrics/metrics.py +37 -0
  235. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +0 -0
  236. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +0 -0
  237. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +0 -0
  238. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +0 -0
  239. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +0 -0
  240. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +0 -0
  241. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +0 -0
  242. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +0 -0
  243. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +0 -0
  244. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +0 -0
  245. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +2 -6
  246. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +2 -6
  247. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +2 -6
  248. evalscope/models/image_edit_model.py +125 -0
  249. evalscope/models/model_apis.py +22 -0
  250. evalscope/models/openai_compatible.py +21 -0
  251. evalscope/models/text2image_model.py +2 -2
  252. evalscope/models/utils/openai.py +16 -6
  253. evalscope/perf/arguments.py +26 -4
  254. evalscope/perf/benchmark.py +76 -89
  255. evalscope/perf/http_client.py +31 -16
  256. evalscope/perf/main.py +15 -2
  257. evalscope/perf/plugin/api/base.py +9 -7
  258. evalscope/perf/plugin/api/custom_api.py +13 -58
  259. evalscope/perf/plugin/api/default_api.py +188 -79
  260. evalscope/perf/plugin/api/openai_api.py +85 -20
  261. evalscope/perf/plugin/datasets/base.py +21 -0
  262. evalscope/perf/plugin/datasets/custom.py +2 -3
  263. evalscope/perf/plugin/datasets/flickr8k.py +2 -2
  264. evalscope/perf/plugin/datasets/kontext_bench.py +2 -2
  265. evalscope/perf/plugin/datasets/line_by_line.py +2 -3
  266. evalscope/perf/plugin/datasets/longalpaca.py +2 -3
  267. evalscope/perf/plugin/datasets/openqa.py +2 -4
  268. evalscope/perf/plugin/datasets/random_dataset.py +1 -3
  269. evalscope/perf/plugin/datasets/random_vl_dataset.py +2 -2
  270. evalscope/perf/utils/benchmark_util.py +43 -27
  271. evalscope/perf/utils/db_util.py +14 -19
  272. evalscope/perf/utils/local_server.py +3 -44
  273. evalscope/perf/utils/log_utils.py +21 -6
  274. evalscope/report/__init__.py +13 -3
  275. evalscope/report/combinator.py +91 -20
  276. evalscope/report/generator.py +8 -87
  277. evalscope/report/report.py +8 -4
  278. evalscope/run.py +13 -5
  279. evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
  280. evalscope/utils/argument_utils.py +1 -1
  281. evalscope/utils/chat_service.py +1 -1
  282. evalscope/utils/function_utils.py +249 -12
  283. evalscope/utils/import_utils.py +73 -1
  284. evalscope/utils/io_utils.py +132 -7
  285. evalscope/utils/json_schema.py +25 -2
  286. evalscope/utils/logger.py +69 -18
  287. evalscope/utils/model_utils.py +4 -3
  288. evalscope/utils/multi_choices.py +39 -7
  289. evalscope/utils/ner.py +377 -0
  290. evalscope/version.py +2 -2
  291. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/METADATA +252 -408
  292. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/RECORD +290 -154
  293. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/WHEEL +1 -1
  294. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/top_level.txt +0 -1
  295. evalscope/api/mixin/dataset_mixin.py +0 -105
  296. evalscope/benchmarks/aigc/i2i/general_i2i_adapter.py +0 -44
  297. tests/__init__.py +0 -1
  298. tests/aigc/__init__.py +0 -1
  299. tests/aigc/test_t2i.py +0 -142
  300. tests/benchmark/__init__.py +0 -1
  301. tests/benchmark/test_eval.py +0 -386
  302. tests/cli/__init__.py +0 -1
  303. tests/cli/test_all.py +0 -229
  304. tests/cli/test_collection.py +0 -96
  305. tests/cli/test_custom.py +0 -268
  306. tests/perf/__init__.py +0 -1
  307. tests/perf/test_perf.py +0 -176
  308. tests/rag/test_clip_benchmark.py +0 -90
  309. tests/rag/test_mteb.py +0 -213
  310. tests/rag/test_ragas.py +0 -128
  311. tests/swift/__init__.py +0 -1
  312. tests/swift/test_run_swift_eval.py +0 -146
  313. tests/swift/test_run_swift_vlm_eval.py +0 -128
  314. tests/swift/test_run_swift_vlm_jugde_eval.py +0 -157
  315. tests/test_run_all.py +0 -12
  316. tests/utils.py +0 -13
  317. tests/vlm/__init__.py +0 -1
  318. tests/vlm/test_vlmeval.py +0 -102
  319. /evalscope/benchmarks/{aigc → aa_lcr}/__init__.py +0 -0
  320. /evalscope/benchmarks/{aigc/i2i → ai2d}/__init__.py +0 -0
  321. /evalscope/benchmarks/{aigc/t2i → amc}/__init__.py +0 -0
  322. {tests/rag → evalscope/benchmarks/bfcl/v3}/__init__.py +0 -0
  323. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info}/entry_points.txt +0 -0
  324. {evalscope-1.0.0.dist-info → evalscope-1.2.0.dist-info/licenses}/LICENSE +0 -0
@@ -2,6 +2,7 @@ import os
 import re
 from typing import Any, Dict, List, Optional
 
+from evalscope.api.messages import ChatMessage, ChatMessageSystem, ChatMessageUser
 from evalscope.constants import JudgeScoreType
 from evalscope.utils.logger import get_logger
 
@@ -109,20 +110,31 @@ class LLMJudge:
             config=GenerateConfig(**self.generation_config),
         )
 
-    def judge(self, prompt: str, system_prompt: Optional[str] = None) -> str:
+    def judge(
+        self,
+        prompt: str = '',
+        system_prompt: Optional[str] = None,
+        messages: Optional[List[ChatMessage]] = None
+    ) -> str:
         """
+        Generate a response from the LLM based on the provided prompt and context.
+        If messages is provided, it will be used as the input context.
+
         Args:
             prompt (str): The prompt to evaluate
             system_prompt (str, optional): The system prompt to use for the evaluation
+            messages (List[ChatMessage], optional): A list of chat messages to include in the evaluation
        Returns:
            str: The response from the LLM
        """
-        from evalscope.api.messages import ChatMessageSystem, ChatMessageUser
-
-        system_content = system_prompt or self.system_prompt
-        input_messages = [ChatMessageUser(content=prompt)]
-        if system_content:
-            input_messages.insert(0, ChatMessageSystem(content=system_content))
+        # parse messages
+        if messages is not None:
+            input_messages = messages
+        else:
+            system_content = system_prompt or self.system_prompt
+            input_messages = [ChatMessageUser(content=prompt)]
+            if system_content:
+                input_messages.insert(0, ChatMessageSystem(content=system_content))
         try:
             # Send request using ServerModelAdapter
             response = self.model.generate(input_messages)
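
A note on the judge() change above (this hunk appears to belong to evalscope/metrics/llm_judge.py from the file list): existing callers that pass prompt/system_prompt keep working because prompt now defaults to '', and a full conversation can be supplied via messages instead. A minimal usage sketch, assuming an already-constructed LLMJudge instance named judge (its constructor arguments are not shown in this hunk):

from evalscope.api.messages import ChatMessageSystem, ChatMessageUser

# Old-style call: behaviour unchanged.
verdict = judge.judge(prompt='Is 2 + 2 = 4? Answer yes or no.', system_prompt='You are a strict grader.')

# New-style call: pass the conversation directly; prompt/system_prompt are ignored when messages is given.
verdict = judge.judge(messages=[
    ChatMessageSystem(content='You are a strict grader.'),
    ChatMessageUser(content='Is 2 + 2 = 4? Answer yes or no.'),
])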
@@ -211,6 +211,11 @@ def strip_answer_string(string):
     # Remove grade level (e.g., 12th grade) and just maintain the integer
     string = re.sub(r'thgrade$', '', string)
 
+    # Normalize thousands-formatted numbers (e.g., 70,000 or -1,234,567.89) by removing commas
+    # This must run before the "list of integers" sorting to avoid misclassifying numbers with thousand separators.
+    if re.fullmatch(r'\s*-?\d{1,3}(?:,\d{3})+(?:\.\d+)?\s*', string):
+        string = string.replace(',', '')
+
     # If the answer is a list of integers (without parenthesis), sort them
     if re.fullmatch(r'(\s*-?\d+\s*,)*\s*-?\d+\s*', string):
         # Split the string into a list of integers
@@ -262,6 +267,8 @@ def extract_answer(pred_str, use_last_number=True):
     elif '答案是' in pred_str:
         # Handle Chinese few-shot multiple choice problem answer extraction
         pred = pred_str.split('答案是')[1].strip().split('\n\n')[0].strip()
+    elif 'ANSWER:' in pred_str:
+        pred = pred_str.split('ANSWER:')[-1].strip()
     else:  # use the last number
         if use_last_number:
             pattern = '-?\d*\.?\d+'
@@ -529,3 +536,10 @@ def symbolic_equal(a, b):
         pass
 
     return False
+
+
+if __name__ == '__main__':
+    print(math_equal('\n\\boxed{70,\\!000}\n', '70000'))
+    print(extract_answer('The answer is \\boxed{70,\\!000}'))
+    print(strip_answer_string(extract_answer('The answer is \\boxed{70,\\!000}')))
+    print(math_equal(extract_answer('The answer is \\boxed{70,\\!000}'), '70000'))
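
The thousands-separator pattern added to strip_answer_string only fires on strings that look like a single comma-grouped number, so other comma-containing answers still reach the list-of-integers branch below it. A standalone check of the regex (an illustration, not part of the package):

import re

PATTERN = r'\s*-?\d{1,3}(?:,\d{3})+(?:\.\d+)?\s*'

for s in ['70,000', '-1,234,567.89', '1, 2, 3', '12,34']:
    print(s, '->', s.replace(',', '') if re.fullmatch(PATTERN, s) else s)
# 70,000 -> 70000
# -1,234,567.89 -> -1234567.89
# 1, 2, 3 -> 1, 2, 3   (left for the list-of-integers branch)
# 12,34 -> 12,34       (not a thousands-grouped number)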
@@ -1,16 +1,27 @@
+import json
+import numpy as np
+import os
 from collections import defaultdict
-from typing import List
+from typing import Dict, List
 
-from evalscope.api.metric import Aggregator, AggScore, Metric, SampleScore, T2IMetric
+from evalscope.api.metric import Aggregator, AggScore, Metric, SampleScore, SingletonMetric, T2IMetric
 from evalscope.api.registry import register_aggregation, register_metric
-from .metrics import mean
+from evalscope.utils.import_utils import check_import
+from .metrics import calculate_pass_at_k, calculate_pass_hat_k, mean, normalize_text
+
+# ##################
+# NLP Metrics ######
+# ##################
 
 
 @register_metric(name='exact_match')
 class ExactMatch(Metric):
 
     def apply(self, predictions, references):
-        return [float(prediction == reference) for prediction, reference in zip(predictions, references)]
+        return [
+            float(normalize_text(prediction) == normalize_text(reference))
+            for prediction, reference in zip(predictions, references)
+        ]
 
 
 @register_metric(name='acc')
@@ -30,13 +41,12 @@ class Accuracy(ExactMatch):
                 results.append(0.0)
             return results
         elif self.numeric:
-            from .math_parser import extract_answer, math_equal, strip_answer_string
+            from .math_parser import math_equal, strip_answer_string
 
             results = []
             for prediction, reference in zip(predictions, references):
-                pred_answer = strip_answer_string(extract_answer(prediction))
                 ref_answer = strip_answer_string(reference)
-                results.append(float(math_equal(pred_answer, ref_answer)))
+                results.append(float(math_equal(prediction, ref_answer)))
 
             return results
         else:
@@ -92,9 +102,114 @@ class MultiChoiceAcc(Metric):
         return res
 
 
+@register_metric(name='anls')
+class ANLS(Metric):
+
+    def __init__(self, thresh_hold=0.5):
+        self.thresh_hold = thresh_hold
+
+    def apply(self, predictions, references):
+        """
+        Calculate ANLS (Average Normalized Levenshtein Similarity) for a list of predictions and references.
+        This implementation is adapted from
+        https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/infographicsvqa_eval.py
+
+        Args:
+            references (List[str]): List of correct answers. Each answer can be a string of json.
+            predictions (List[str]): List of predicted answers.
+        """
+        from .metrics import levenshtein_distance
+
+        res = []
+        # Unwrap predictions if it's a nested list
+        for prediction, reference in zip(predictions, references):
+            # Parse the reference which is a json string
+            try:
+                answer = json.loads(reference)
+            except json.JSONDecodeError:
+                answer = reference
+            if isinstance(answer, str):
+                answer = [answer]
+            assert isinstance(answer, list), 'The reference answer should be a list of answers.'
+
+            # Calculate ANLS for each reference answer
+            values = []
+            for ans in answer:
+                # preprocess both the answers - gt and prediction
+                gt_answer = ' '.join(ans.strip().lower().split())
+                det_answer = ' '.join(prediction.strip().lower().split())
+
+                dist = levenshtein_distance(gt_answer, det_answer)
+                length = max(len(ans.upper()), len(prediction.upper()))
+                values.append(0.0 if length == 0 else float(dist) / float(length))
+
+            question_result = 0.0
+            if values:
+                question_result = 1 - min(values)
+                if question_result < self.thresh_hold:
+                    question_result = 0.0
+            res.append(question_result)
+        return res
+
+
+@register_metric(name='bertscore')
+class BertScore(SingletonMetric):
+
+    def _init_once(self, model_id_or_path: str = 'google-bert/bert-base-chinese', **kwargs):
+        """BertScore metric.
+
+        Args:
+            model_id_or_path (str, optional): The model ID on modelscope or path to the pre-trained model.
+                Defaults to 'google-bert/bert-base-chinese'.
+        """
+        check_import('torch', 'torch', raise_error=True, feature_name='BertScore Metric')
+
+        from .bert_score.scorer import BERTScorer
+        self.scorer = BERTScorer(model_id_or_path=model_id_or_path, batch_size=1024, **kwargs)
+
+    def apply(self, predictions: List[str], references: List[str]) -> List[float]:
+        _, _, F1 = self.scorer.score(predictions, references)
+        return [round(f1.item(), 6) for f1 in F1]
+
+
+@register_metric(name='comet')
+class COMETScore(SingletonMetric):
+
+    def _init_once(self, model_id_or_path: str = 'evalscope/wmt22-comet-da'):
+        """COMETScore metric.
+
+        Args:
+            model_name (str, optional): The model name on huggingface.
+                Defaults to 'evalscope/wmt22-comet-da'.
+        """
+        check_import('comet', 'unbabel-comet', raise_error=True, feature_name='COMETScore Metric')
+
+        from comet import load_from_checkpoint
+        from modelscope import snapshot_download
+
+        self.model_name = model_id_or_path
+        model_path = snapshot_download(model_id_or_path)
+        checkpoint_path = os.path.join(model_path, 'checkpoints', 'model.ckpt')
+        self.comet_scorer = load_from_checkpoint(checkpoint_path)
+
+    def apply(self, samples: List[Dict[str, str]]) -> List[float]:
+        """Apply COMET scoring."""
+        import torch
+
+        model_output = self.comet_scorer.predict(
+            samples=samples,
+            batch_size=1024,
+            gpus=1 if torch.cuda.is_available() else 0,
+            progress_bar=False,
+        )
+        scores = model_output.scores if hasattr(model_output, 'scores') else [model_output.system_score] * len(samples)
+
+        return [round(score, 6) for score in scores]
+
+
 # ##################
 # T2I Metrics ######
-####################
+# ##################
 @register_metric(name='VQAScore')
 class VQAScore(T2IMetric):
 
@@ -202,6 +317,9 @@ class Mean(Aggregator):
 
     name = 'mean'
 
+    def agg_func(self, values: List[float]) -> float:
+        return mean(values)
+
     def __call__(self, scores: List[SampleScore]) -> List[AggScore]:
         """Aggregate scores by computing the mean for each metric.
 
@@ -230,7 +348,7 @@ class Mean(Aggregator):
            if values:  # Only process non-empty value lists
                aggregated_scores.append(
                    AggScore(
-                        score=mean(values),
+                        score=self.agg_func(values),
                        metric_name=metric_name,
                        aggregation_name=self.name,
                        num=len(values),
@@ -241,6 +359,20 @@ class Mean(Aggregator):
         return aggregated_scores
 
 
+@register_aggregation(name='clipped_mean')
+class ClippedMean(Mean):
+
+    name = 'clipped_mean'
+
+    def __init__(self, clip_min: float = 0.0, clip_max: float = 1.0):
+        self.clip_min = clip_min
+        self.clip_max = clip_max
+
+    def agg_func(self, values: List[float]) -> float:
+        clipped_values = min(max(mean(values), self.clip_min), self.clip_max)
+        return clipped_values
+
+
 @register_aggregation(name='pass_at_k')
 class PassAtK(Aggregator):
 
@@ -260,10 +392,6 @@ class PassAtK(Aggregator):
         if not scores:
             return []
 
-        import numpy as np
-
-        from .metrics import calculate_pass_at_k
-
         # Group scores by metric name and group_id
         metric_groups = defaultdict(lambda: defaultdict(list))
 
@@ -305,3 +433,179 @@ class PassAtK(Aggregator):
                 )
 
        return aggregated_scores
+
+
+@register_aggregation(name='mean_and_pass_at_k')
+class MeanPassAtK(Aggregator):
+
+    def __init__(self):
+        self.name = 'mean_and_pass_at_k'
+
+    def __call__(self, scores: List[SampleScore]) -> List[AggScore]:
+        """Add per-metric pass@k (computed via calculate_pass_at_k) to each sample, then mean-aggregate.
+
+        For each metric:
+        - Group scores by group_id
+        - Collect binary correctness values
+        - Infer k as (total samples / number of groups) assuming uniform repetitions
+        - Compute per-group pass@k via calculate_pass_at_k
+        - Annotate each sample with metric_pass@k for its group
+        Finally run Mean() over the augmented metric set.
+        """
+        if not scores:
+            return []
+
+        # Extract metric names present in score values
+        metrics = list(scores[0].score.value.keys())
+
+        for metric_name in metrics:
+            # group_id -> list[float] (0/1 correctness values)
+            group_values: Dict[str, List[float]] = defaultdict(list)
+            for s in scores:
+                group_id = getattr(s, 'group_id', s.sample_id)
+                value = float(s.score.value[metric_name])
+                group_values[group_id].append(value)
+
+            if not group_values:
+                continue
+
+            # Infer k (assumes roughly uniform repeats)
+            k = int(len(scores) / len(group_values)) if len(group_values) > 0 else 1
+            if k <= 0:
+                k = 1
+
+            # Prepare inputs for calculate_pass_at_k
+            num_samples: List[int] = []
+            num_correct: List[int] = []
+            group_order: List[str] = []
+            for gid, vals in group_values.items():
+                group_order.append(gid)
+                num_samples.append(len(vals))
+                num_correct.append(int(sum(vals)))
+
+            # Compute per-group pass@k
+            pass_at_k_list = calculate_pass_at_k(num_samples, num_correct, k)
+            # Map back: group_id -> pass@k value
+            pass_at_k_map = {gid: float(v) for gid, v in zip(group_order, pass_at_k_list)}
+
+            # Annotate each sample with its group's pass@k
+            for s in scores:
+                group_id = getattr(s, 'group_id', s.sample_id)
+                s.score.value[f'{metric_name}_pass@{k}'] = pass_at_k_map[group_id]
+
+        # Delegate mean aggregation over original + injected pass@k metrics
+        m = Mean()
+        return m(scores)
+
+
+@register_aggregation(name='mean_and_vote_at_k')
+class MeanVoteAtK(Aggregator):
+
+    def __init__(self):
+
+        self.name = 'mean_and_vote_at_k'
+
+    def __call__(self, scores: List[SampleScore]) -> List[AggScore]:
+        """Aggregate scores by computing the vote@k for each metric using group_id.
+
+        Args:
+            scores: List of sample scores to aggregate
+
+        Returns:
+            List of aggregated scores with vote@k values
+        """
+        if not scores:
+            return []
+
+        metrics = list(scores[0].score.value.keys())
+
+        # Calculate vote@k for all metrics
+        for metric_name in metrics:
+
+            # Count of occurrences for each answer in each group_id
+            answer_groups = defaultdict(lambda: defaultdict(int))
+            # Score for each answer in each group_id
+            scores_groups = defaultdict(lambda: defaultdict(float))
+            # Score of the most frequently occurring answer
+            final_scores_groups = defaultdict(float)
+            # Count different answers for this metric
+            for score in scores:
+                group_id = getattr(score, 'group_id', score.sample_id)  # fallback to sample_id if no group_id
+                answer_prediction = getattr(score.score, 'extracted_prediction', None)
+                answer_groups[group_id][answer_prediction] += 1
+                scores_groups[group_id][answer_prediction] = score.score.value[metric_name]
+            # Calculate the repetition count k for each problem
+            k = int(len(scores) / len(answer_groups))
+
+            # Use the score of the most frequently occurring answer as the group's score
+            for group_id in answer_groups:
+                final_scores_groups[group_id] = scores_groups[group_id][
+                    max(answer_groups[group_id], key=answer_groups[group_id].get)]
+
+            # Add the corresponding vote@k for the metric to each score's value
+            for score in scores:
+                group_id = getattr(score, 'group_id', score.sample_id)
+                score.score.value.update({f'{metric_name}_vote@{k}': final_scores_groups[group_id]})
+
+        # Calculate the mean value for all metrics and their corresponding vote@k
+        m = Mean()
+        return m(scores)
+
+
+@register_aggregation(name='mean_and_pass_hat_k')
+class MeanPassHatK(Aggregator):
+
+    def __init__(self):
+        self.name = 'mean_and_pass_hat_k'
+
+    def __call__(self, scores: List[SampleScore]) -> List[AggScore]:
+        """Add per-metric pass^k using calculate_pass_hat_k, then mean-aggregate.
+
+        For each metric:
+        - Group scores by group_id
+        - Collect binary correctness values
+        - Infer k as approximate repeats and clamp to min attempts across groups
+        - Compute per-group pass^k via calculate_pass_hat_k
+        - Annotate each sample with metric_pass^{k} for its group
+        Finally run Mean() over the augmented metric set.
+        """
+        if not scores:
+            return []
+
+        # Freeze metric names before augmenting values to avoid iterating injected keys
+        metrics = list(scores[0].score.value.keys())
+
+        for metric_name in metrics:
+            # group_id -> list[float] (0/1 correctness values)
+            group_values: Dict[str, List[float]] = defaultdict(list)
+            for s in scores:
+                group_id = getattr(s, 'group_id', s.sample_id)
+                value = float(s.score.value[metric_name])
+                group_values[group_id].append(value)
+
+            if not group_values:
+                continue
+
+            # Infer repeats and clamp to the smallest group size to satisfy k <= n
+            approx_k = int(len(scores) / len(group_values)) if len(group_values) > 0 else 1
+            min_n = min(len(vals) for vals in group_values.values())
+            k = max(1, min(approx_k, min_n))
+
+            # Compute per-group pass^k
+            pass_hat_k_map: Dict[str, float] = {}
+            for gid, vals in group_values.items():
+                n = len(vals)
+                c = int(sum(vals))
+                # calculate_pass_hat_k requires k <= n; ensured by clamping above
+                pass_hat_k_map[gid] = float(calculate_pass_hat_k(n, c, k))
+
+            # Annotate each sample with its group's pass^k
+            suffix = f'pass^{k}'
+            injected_key = f'{metric_name}_{suffix}'
+            for s in scores:
+                group_id = getattr(s, 'group_id', s.sample_id)
+                s.score.value[injected_key] = pass_hat_k_map[group_id]
+
+        # Mean aggregate over original + injected pass^k metrics
+        m = Mean()
+        return m(scores)
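
To make the new aggregators concrete: with one question answered n = 4 times, of which c = 2 answers are correct, and k = 2, 'mean' reports 0.5, pass@2 rewards getting at least one of the two sampled attempts right, and pass^2 (MeanPassHatK, via calculate_pass_hat_k) requires both sampled attempts to be right. A standalone arithmetic sketch; the exact calculate_pass_at_k implementation is not shown in this diff, and the pass@k formula below is the standard HumanEval-style estimator it is assumed to match:

import math

n, c, k = 4, 2, 2  # attempts per question, correct attempts, subset size

mean_acc = c / n                                       # 0.5
pass_at_k = 1 - math.comb(n - c, k) / math.comb(n, k)  # 5/6 ≈ 0.833: at least one of k correct
pass_hat_k = math.comb(c, k) / math.comb(n, k)         # 1/6 ≈ 0.167: all k correct (pass^k)

print(mean_acc, pass_at_k, pass_hat_k)

MeanVoteAtK instead keeps, per group, the score of the most frequent extracted_prediction (majority vote) and reports it as metric_vote@k.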
@@ -12,6 +12,11 @@ from collections.abc import Iterable
 from typing import Dict, List, Union
 
 
+def normalize_text(text: str) -> str:
+    """Normalize text by lowering case and stripping whitespace."""
+    return text.strip().lower()
+
+
 def mean(arr: list):
     if not arr:
         return 0.0
@@ -467,3 +472,35 @@ def calculate_pass_at_k(
     num_samples_it = iter(num_samples)
 
     return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
+
+
+def calculate_pass_hat_k(num_trials: int, success_count: int, k: int) -> float:
+    """
+    Compute the pass^k metric for the given number of trials, success count, and k.
+    from https://arxiv.org/pdf/2406.12045
+    Args:
+        num_trials: The number of trials.
+        success_count: The number of successful trials.
+        k: The number of trials to consider.
+    Returns:
+        The pass^k metric.
+    """
+    if num_trials < k:
+        raise ValueError(f'Number of trials {num_trials} is less than k {k}.')
+    return math.comb(success_count, k) / math.comb(num_trials, k)
+
+
+def levenshtein_distance(s1, s2):
+    if len(s1) > len(s2):
+        s1, s2 = s2, s1
+
+    distances = range(len(s1) + 1)
+    for i2, c2 in enumerate(s2):
+        distances_ = [i2 + 1]
+        for i1, c1 in enumerate(s1):
+            if c1 == c2:
+                distances_.append(distances[i1])
+            else:
+                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+        distances = distances_
+    return distances[-1]
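
A quick sanity check of the helpers added above, assuming evalscope 1.2.0 is installed so the new functions are importable from evalscope.metrics.metrics (the module shown in the file list):

from evalscope.metrics.metrics import calculate_pass_hat_k, levenshtein_distance, normalize_text

print(normalize_text('  Paris '))               # 'paris'
print(levenshtein_distance('paris', 'pariss'))  # 1 (one insertion)
print(calculate_pass_hat_k(4, 2, 2))            # comb(2, 2) / comb(4, 2) = 1/6 ≈ 0.1667

The ANLS metric registered earlier turns such an edit distance into 1 - dist / max_len, e.g. 1 - 1/6 ≈ 0.833 for 'pariss' vs. 'Paris', and zeroes out anything below its 0.5 threshold.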
@@ -30,13 +30,9 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from transformers.modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bert.configuration_bert import BertConfig
+from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from transformers.utils import logging
 from typing import Any, Dict, Optional, Tuple
 
@@ -14,13 +14,9 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
 )
-from transformers.modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bert.configuration_bert import BertConfig
+from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from transformers.utils import logging
 from typing import Tuple
 
@@ -31,13 +31,9 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from transformers.modeling_utils import (
-    PreTrainedModel,
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    prune_linear_layer,
-)
+from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bert.configuration_bert import BertConfig
+from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from transformers.utils import logging
 from typing import Optional, Tuple
 
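
The three hunks above all make the same fix: newer transformers releases expose apply_chunking_to_forward, find_pruneable_heads_and_indices and prune_linear_layer in transformers.pytorch_utils, and the old transformers.modeling_utils imports can fail there, so the code switches to the new location while keeping PreTrainedModel where it has always lived. A minimal compatibility sketch (not part of the diff) for code that must run on both old and new transformers versions:

try:
    # Newer transformers: the pruning/chunking helpers live in pytorch_utils.
    from transformers.pytorch_utils import (
        apply_chunking_to_forward,
        find_pruneable_heads_and_indices,
        prune_linear_layer,
    )
except ImportError:
    # Older transformers re-exported them from modeling_utils.
    from transformers.modeling_utils import (
        apply_chunking_to_forward,
        find_pruneable_heads_and_indices,
        prune_linear_layer,
    )
from transformers.modeling_utils import PreTrainedModel  # available in both layouts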