crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546)
  1. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +37 -2
  5. helm/benchmark/adaptation/adapters/adapter.py +4 -42
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
  9. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
  10. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
  11. helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
  12. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  13. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  14. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
  15. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
  16. helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
  17. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  18. helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
  19. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
  20. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
  21. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  22. helm/benchmark/adaptation/prompt.py +7 -1
  23. helm/benchmark/adaptation/request_state.py +6 -1
  24. helm/benchmark/adaptation/scenario_state.py +6 -2
  25. helm/benchmark/annotation/annotator.py +43 -0
  26. helm/benchmark/annotation/annotator_factory.py +61 -0
  27. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  28. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  29. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  30. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  31. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  32. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  33. helm/benchmark/annotation_executor.py +124 -0
  34. helm/benchmark/augmentations/cleva_perturbation.py +7 -14
  35. helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
  36. helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
  37. helm/benchmark/augmentations/data_augmenter.py +0 -2
  38. helm/benchmark/augmentations/dialect_perturbation.py +2 -2
  39. helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
  40. helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
  41. helm/benchmark/augmentations/gender_perturbation.py +3 -3
  42. helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
  43. helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
  44. helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
  45. helm/benchmark/augmentations/person_name_perturbation.py +0 -7
  46. helm/benchmark/augmentations/perturbation.py +20 -7
  47. helm/benchmark/augmentations/perturbation_description.py +1 -1
  48. helm/benchmark/augmentations/space_perturbation.py +2 -2
  49. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  50. helm/benchmark/augmentations/synonym_perturbation.py +2 -2
  51. helm/benchmark/augmentations/test_perturbation.py +11 -7
  52. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  53. helm/benchmark/augmentations/typos_perturbation.py +2 -2
  54. helm/benchmark/config_registry.py +38 -0
  55. helm/benchmark/executor.py +46 -16
  56. helm/benchmark/huggingface_registration.py +37 -7
  57. helm/benchmark/metrics/basic_metrics.py +172 -641
  58. helm/benchmark/metrics/bbq_metrics.py +3 -4
  59. helm/benchmark/metrics/bias_metrics.py +6 -6
  60. helm/benchmark/metrics/classification_metrics.py +11 -8
  61. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  62. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  63. helm/benchmark/metrics/code_metrics.py +4 -3
  64. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  65. helm/benchmark/metrics/common_metric_specs.py +167 -0
  66. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  67. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  68. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  69. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  70. helm/benchmark/metrics/disinformation_metrics.py +6 -112
  71. helm/benchmark/metrics/dry_run_metrics.py +5 -3
  72. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  73. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  74. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  75. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  76. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  77. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  78. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  79. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  80. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  81. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  82. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  83. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  84. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  85. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  86. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  87. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  88. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  89. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  90. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  91. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  92. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  93. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  94. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  95. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  96. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  97. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  98. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  99. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  100. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  101. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  102. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  103. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  104. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  105. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  106. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  107. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  108. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  109. helm/benchmark/metrics/machine_translation_metrics.py +5 -5
  110. helm/benchmark/metrics/metric.py +93 -172
  111. helm/benchmark/metrics/metric_name.py +0 -1
  112. helm/benchmark/metrics/metric_service.py +16 -0
  113. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  114. helm/benchmark/metrics/ranking_metrics.py +6 -7
  115. helm/benchmark/metrics/reference_metric.py +148 -0
  116. helm/benchmark/metrics/summac/model_summac.py +0 -2
  117. helm/benchmark/metrics/summarization_metrics.py +8 -8
  118. helm/benchmark/metrics/test_classification_metrics.py +9 -6
  119. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  120. helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
  121. helm/benchmark/metrics/test_metric.py +2 -2
  122. helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
  123. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
  124. helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
  125. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
  126. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
  127. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  128. helm/benchmark/metrics/toxicity_utils.py +23 -0
  129. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  130. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  131. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  132. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  133. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  134. helm/benchmark/model_deployment_registry.py +164 -41
  135. helm/benchmark/model_metadata_registry.py +181 -35
  136. helm/benchmark/multi_gpu_runner.py +133 -0
  137. helm/benchmark/presentation/contamination.py +3 -3
  138. helm/benchmark/presentation/create_plots.py +8 -7
  139. helm/benchmark/presentation/run_display.py +50 -17
  140. helm/benchmark/presentation/schema.py +28 -46
  141. helm/benchmark/presentation/summarize.py +213 -96
  142. helm/benchmark/presentation/table.py +8 -8
  143. helm/benchmark/presentation/test_contamination.py +2 -2
  144. helm/benchmark/presentation/test_run_entry.py +14 -9
  145. helm/benchmark/presentation/test_summarize.py +5 -0
  146. helm/benchmark/run.py +66 -54
  147. helm/benchmark/run_expander.py +342 -31
  148. helm/benchmark/run_spec.py +93 -0
  149. helm/benchmark/run_spec_factory.py +162 -0
  150. helm/benchmark/run_specs/__init__.py +0 -0
  151. helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
  152. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  153. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  154. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  155. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  156. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  157. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  158. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  159. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  160. helm/benchmark/runner.py +116 -69
  161. helm/benchmark/runner_config_registry.py +21 -0
  162. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  163. helm/benchmark/scenarios/bold_scenario.py +2 -2
  164. helm/benchmark/scenarios/cleva_scenario.py +43 -46
  165. helm/benchmark/scenarios/code_scenario.py +3 -2
  166. helm/benchmark/scenarios/commonsense_scenario.py +171 -191
  167. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  168. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  169. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  170. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  171. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  172. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  173. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  174. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  175. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  176. helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
  177. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  178. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  179. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  180. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  181. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  182. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  183. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  184. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  185. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  186. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  187. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  188. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  189. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  190. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  191. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  192. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  193. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  194. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  195. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  196. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  197. helm/benchmark/scenarios/legalbench_scenario.py +123 -0
  198. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  199. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  200. helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
  201. helm/benchmark/scenarios/math_scenario.py +19 -2
  202. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  203. helm/benchmark/scenarios/numeracy_scenario.py +3 -3
  204. helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
  205. helm/benchmark/scenarios/raft_scenario.py +2 -6
  206. helm/benchmark/scenarios/scenario.py +14 -2
  207. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  208. helm/benchmark/scenarios/test_math_scenario.py +22 -0
  209. helm/benchmark/scenarios/test_scenario.py +6 -3
  210. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  211. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  212. helm/benchmark/scenarios/the_pile_scenario.py +6 -7
  213. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  214. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  215. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  216. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  217. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  218. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  219. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  220. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  221. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  222. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  223. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  224. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  225. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  226. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  227. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  228. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  229. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  230. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  231. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  232. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  233. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  234. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  235. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  236. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  237. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
  238. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  239. helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
  240. helm/benchmark/server.py +59 -2
  241. helm/benchmark/slurm_jobs.py +12 -0
  242. helm/benchmark/slurm_runner.py +79 -51
  243. helm/benchmark/static/benchmarking.js +3 -4
  244. helm/benchmark/static/contamination.yaml +1 -1
  245. helm/benchmark/static/images/organizations/together.png +0 -0
  246. helm/benchmark/static/json-urls.js +4 -0
  247. helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
  248. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  249. helm/benchmark/static/schema_lite.yaml +824 -0
  250. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  251. helm/benchmark/static/schema_unitxt.yaml +428 -0
  252. helm/benchmark/static/schema_vlm.yaml +576 -0
  253. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  254. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  255. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  256. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  257. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  258. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  259. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  260. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  261. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  262. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  263. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  264. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  265. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  266. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  267. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  268. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  269. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  270. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  271. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  272. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  273. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  274. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  275. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  276. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  277. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  278. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  279. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  280. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  281. helm/benchmark/static_build/config.js +4 -0
  282. helm/benchmark/static_build/index.html +20 -0
  283. helm/benchmark/test_data_preprocessor.py +3 -3
  284. helm/benchmark/test_model_deployment_definition.py +90 -0
  285. helm/benchmark/test_run_expander.py +1 -1
  286. helm/benchmark/tokenizer_config_registry.py +10 -14
  287. helm/benchmark/window_services/ai21_window_service.py +22 -33
  288. helm/benchmark/window_services/cohere_window_service.py +1 -63
  289. helm/benchmark/window_services/default_window_service.py +2 -35
  290. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  291. helm/benchmark/window_services/ice_window_service.py +0 -34
  292. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  293. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  294. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  295. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  296. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  297. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  298. helm/benchmark/window_services/local_window_service.py +21 -4
  299. helm/benchmark/window_services/no_decoding_window_service.py +32 -0
  300. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  301. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  302. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  303. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  304. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  305. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  306. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  307. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  308. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  309. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  310. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  311. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  312. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  313. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  314. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  315. helm/benchmark/window_services/test_utils.py +3 -2
  316. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  317. helm/benchmark/window_services/window_service.py +42 -0
  318. helm/benchmark/window_services/window_service_factory.py +24 -269
  319. helm/benchmark/window_services/yalm_window_service.py +0 -27
  320. helm/clients/__init__.py +0 -0
  321. helm/{proxy/clients → clients}/ai21_client.py +5 -12
  322. helm/clients/aleph_alpha_client.py +112 -0
  323. helm/{proxy/clients → clients}/anthropic_client.py +213 -24
  324. helm/clients/auto_client.py +215 -0
  325. helm/clients/bedrock_client.py +128 -0
  326. helm/clients/bedrock_utils.py +72 -0
  327. helm/{proxy/clients → clients}/client.py +67 -55
  328. helm/clients/clip_score_client.py +49 -0
  329. helm/clients/clip_scorers/__init__.py +0 -0
  330. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  331. helm/clients/clip_scorers/clip_scorer.py +50 -0
  332. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  333. helm/{proxy/clients → clients}/cohere_client.py +6 -17
  334. helm/clients/gcs_client.py +82 -0
  335. helm/{proxy/clients → clients}/google_client.py +7 -8
  336. helm/clients/google_translate_client.py +35 -0
  337. helm/{proxy/clients → clients}/http_model_client.py +6 -10
  338. helm/{proxy/clients → clients}/huggingface_client.py +134 -92
  339. helm/clients/image_generation/__init__.py +0 -0
  340. helm/clients/image_generation/adobe_vision_client.py +78 -0
  341. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  342. helm/clients/image_generation/cogview2/__init__.py +0 -0
  343. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  344. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  345. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  346. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  347. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  348. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  349. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  350. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  351. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  352. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  353. helm/clients/image_generation/cogview2_client.py +191 -0
  354. helm/clients/image_generation/dalle2_client.py +192 -0
  355. helm/clients/image_generation/dalle3_client.py +108 -0
  356. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  357. helm/clients/image_generation/dalle_mini/data.py +442 -0
  358. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  359. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  360. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  361. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  362. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  363. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  364. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  365. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  366. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  367. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  368. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  369. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  370. helm/clients/image_generation/dalle_mini_client.py +190 -0
  371. helm/clients/image_generation/deep_floyd_client.py +78 -0
  372. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  373. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  374. helm/clients/image_generation/lexica_client.py +86 -0
  375. helm/clients/image_generation/mindalle/__init__.py +0 -0
  376. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  377. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  378. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  379. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  380. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  381. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  382. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  383. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  384. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  385. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  386. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  387. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  388. helm/clients/image_generation/mindalle_client.py +115 -0
  389. helm/clients/image_generation/nudity_check_client.py +64 -0
  390. helm/clients/image_generation/together_image_generation_client.py +111 -0
  391. helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
  392. helm/{proxy/clients → clients}/megatron_client.py +13 -7
  393. helm/clients/mistral_client.py +134 -0
  394. helm/clients/moderation_api_client.py +109 -0
  395. helm/clients/open_lm_client.py +43 -0
  396. helm/clients/openai_client.py +302 -0
  397. helm/{proxy/clients → clients}/palmyra_client.py +15 -12
  398. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  399. helm/clients/simple_client.py +64 -0
  400. helm/{proxy/clients → clients}/test_auto_client.py +15 -15
  401. helm/clients/test_client.py +100 -0
  402. helm/clients/test_huggingface_client.py +70 -0
  403. helm/clients/test_simple_client.py +19 -0
  404. helm/{proxy/clients → clients}/test_together_client.py +23 -12
  405. helm/{proxy/clients → clients}/together_client.py +18 -71
  406. helm/clients/vertexai_client.py +391 -0
  407. helm/clients/vision_language/__init__.py +0 -0
  408. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  409. helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
  410. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  411. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  412. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  413. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  414. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  415. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  416. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  417. helm/clients/vision_language/open_flamingo_client.py +155 -0
  418. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  419. helm/clients/vllm_client.py +46 -0
  420. helm/common/cache.py +24 -179
  421. helm/common/cache_backend_config.py +47 -0
  422. helm/common/clip_score_request.py +41 -0
  423. helm/common/concurrency.py +32 -0
  424. helm/common/credentials_utils.py +28 -0
  425. helm/common/file_caches/__init__.py +0 -0
  426. helm/common/file_caches/file_cache.py +16 -0
  427. helm/common/file_caches/local_file_cache.py +61 -0
  428. helm/common/file_caches/test_local_file_cache.py +25 -0
  429. helm/common/file_upload_request.py +27 -0
  430. helm/common/general.py +29 -10
  431. helm/common/image_generation_parameters.py +25 -0
  432. helm/common/images_utils.py +24 -1
  433. helm/common/key_value_store.py +113 -0
  434. helm/common/media_object.py +13 -0
  435. helm/common/moderations_api_request.py +71 -0
  436. helm/common/mongo_key_value_store.py +88 -0
  437. helm/common/multimodal_request_utils.py +31 -0
  438. helm/common/nudity_check_request.py +29 -0
  439. helm/common/object_spec.py +2 -2
  440. helm/common/request.py +36 -27
  441. helm/common/test_general.py +6 -0
  442. helm/common/tokenization_request.py +6 -3
  443. helm/config/__init__.py +0 -0
  444. helm/config/model_deployments.yaml +1942 -0
  445. helm/config/model_metadata.yaml +2201 -0
  446. helm/config/tokenizer_configs.yaml +362 -0
  447. helm/proxy/accounts.py +31 -4
  448. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  449. helm/proxy/critique/model_critique_client.py +13 -5
  450. helm/proxy/example_queries.py +29 -17
  451. helm/proxy/retry.py +8 -2
  452. helm/proxy/server.py +77 -5
  453. helm/proxy/services/remote_service.py +31 -0
  454. helm/proxy/services/server_service.py +103 -20
  455. helm/proxy/services/service.py +34 -2
  456. helm/proxy/services/test_remote_service.py +7 -6
  457. helm/proxy/services/test_service.py +27 -18
  458. helm/proxy/test_accounts.py +32 -0
  459. helm/proxy/token_counters/auto_token_counter.py +37 -37
  460. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  461. helm/proxy/token_counters/token_counter.py +3 -5
  462. helm/py.typed +0 -0
  463. helm/tokenizers/__init__.py +0 -0
  464. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  465. helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
  466. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
  467. helm/tokenizers/auto_tokenizer.py +93 -0
  468. helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
  469. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  470. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  471. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
  472. helm/tokenizers/simple_tokenizer.py +33 -0
  473. helm/tokenizers/test_anthropic_tokenizer.py +82 -0
  474. helm/tokenizers/test_huggingface_tokenizer.py +136 -0
  475. helm/tokenizers/test_simple_tokenizer.py +33 -0
  476. helm/tokenizers/vertexai_tokenizer.py +97 -0
  477. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  478. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  479. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  480. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  481. crfm_helm-0.3.0.dist-info/RECORD +0 -396
  482. helm/benchmark/vlm_run_specs.py +0 -71
  483. helm/benchmark/window_services/anthropic_window_service.py +0 -68
  484. helm/benchmark/window_services/bloom_window_service.py +0 -35
  485. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  486. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  487. helm/benchmark/window_services/gptj_window_service.py +0 -38
  488. helm/benchmark/window_services/gptneox_window_service.py +0 -41
  489. helm/benchmark/window_services/http_model_window_service.py +0 -28
  490. helm/benchmark/window_services/huggingface_window_service.py +0 -59
  491. helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
  492. helm/benchmark/window_services/llama_window_service.py +0 -28
  493. helm/benchmark/window_services/luminous_window_service.py +0 -67
  494. helm/benchmark/window_services/megatron_window_service.py +0 -10
  495. helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
  496. helm/benchmark/window_services/openai_window_service.py +0 -13
  497. helm/benchmark/window_services/opt_window_service.py +0 -35
  498. helm/benchmark/window_services/palmyra_window_service.py +0 -45
  499. helm/benchmark/window_services/remote_window_service.py +0 -48
  500. helm/benchmark/window_services/santacoder_window_service.py +0 -27
  501. helm/benchmark/window_services/starcoder_window_service.py +0 -27
  502. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  503. helm/benchmark/window_services/t511b_window_service.py +0 -30
  504. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  505. helm/benchmark/window_services/ul2_window_service.py +0 -30
  506. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  507. helm/benchmark/window_services/wider_openai_window_service.py +0 -52
  508. helm/proxy/clients/aleph_alpha_client.py +0 -99
  509. helm/proxy/clients/auto_client.py +0 -461
  510. helm/proxy/clients/goose_ai_client.py +0 -100
  511. helm/proxy/clients/microsoft_client.py +0 -182
  512. helm/proxy/clients/openai_client.py +0 -206
  513. helm/proxy/clients/remote_model_registry.py +0 -28
  514. helm/proxy/clients/simple_client.py +0 -61
  515. helm/proxy/clients/test_anthropic_client.py +0 -63
  516. helm/proxy/clients/test_client.py +0 -31
  517. helm/proxy/clients/test_huggingface_client.py +0 -87
  518. helm/proxy/models.py +0 -963
  519. helm/proxy/test_models.py +0 -27
  520. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  521. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  522. helm/proxy/token_counters/free_token_counter.py +0 -12
  523. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  524. helm/proxy/token_counters/openai_token_counter.py +0 -22
  525. helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
  526. helm/proxy/token_counters/test_openai_token_counter.py +0 -79
  527. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  528. helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
  529. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  530. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  531. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  532. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  533. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  534. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  535. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  536. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  537. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  538. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  539. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  540. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  541. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  542. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  543. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  544. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  545. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  546. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/benchmark/metrics/basic_metrics.py
@@ -1,24 +1,19 @@
+ from collections import defaultdict
  import math
- from dataclasses import dataclass, replace
- from typing import List, Callable, Optional, Dict, Tuple, Set, cast
+ from dataclasses import dataclass
+ from typing import List, Dict, Set
  from urllib.parse import unquote
- from functools import partial

- import json
- import re
- import string
- import nltk
  import numpy as np
  import scipy
  import calibration as cal
- import importlib_resources as resources
- from nltk.metrics.scores import f_measure
- from nltk.tokenize import word_tokenize
- from nltk.translate.bleu_score import sentence_bleu
- from rouge_score import rouge_scorer
+ from helm.benchmark.adaptation.scenario_state import ScenarioState
+ from helm.benchmark.metrics.evaluate_reference_metrics import compute_reference_metrics
+ from helm.benchmark.metrics.efficiency_metrics import EfficiencyMetric
+ from helm.benchmark.metrics.reference_metric import ReferenceMetric

  from helm.common.hierarchical_logger import hlog
- from helm.common.request import Token, Sequence
+ from helm.common.request import Token, GeneratedOutput
  from helm.benchmark.adaptation.adapters.adapter_factory import (
      ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
      ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED,
@@ -29,177 +24,11 @@ from helm.benchmark.adaptation.adapter_spec import AdapterSpec
  from helm.benchmark.window_services.window_service import WindowService
  from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
  from helm.benchmark.window_services.tokenizer_service import TokenizerService
- from helm.benchmark.scenarios.scenario import CORRECT_TAG, Instance, Reference
- from helm.benchmark.scenarios.math_scenario import is_equiv, is_equiv_chain_of_thought
- from helm.benchmark.scenarios.code_scenario import CodeReference
- from helm.benchmark.metrics.cleva_metrics_helper import ChineseTokenizer
- from . import code_metrics_helper
- from .metric import Metric, get_unique_stat_by_name
- from .metric_name import MetricName
+ from helm.benchmark.scenarios.scenario import CORRECT_TAG, Instance
+ from .metric import Metric, MetricInterface, MetricResult, add_context, get_unique_stat_by_name
+ from .metric_name import MetricContext, MetricName
  from .metric_service import MetricService
- from .statistic import Stat
-
-
- try:
-     nltk.data.find("tokenizers/punkt")
- except LookupError:
-     nltk.download("punkt")  # Required for rouge
-
-
- EFFICIENCY_DATA_PACKAGE: str = "helm.benchmark.efficiency_data"
-
- INFERENCE_IDEALIZED_RUNTIMES_JSON_FILENAME: str = "inference_idealized_runtimes.json"
- INFERENCE_DENOISED_RUNTIMES_JSON_FILENAME: str = "inference_denoised_runtimes.json"
- TRAINING_EFFICIENCY_JSON_FILENAME: str = "training_efficiency.json"
-
-
- def compute_estimated_time_from_prompt_size_and_num_output_tokens(
-     request_state: RequestState,
-     inference_runtimes_dict: Dict[str, Dict],
-     num_prompt_tokens: int,
-     num_output_tokens: int,
- ) -> Optional[float]:
-     estimated_runtime: Optional[float]
-     if request_state.request.model in inference_runtimes_dict:
-         inference_runtimes_dict_for_model = inference_runtimes_dict[request_state.request.model]
-         runtime_per_output_token: float = inference_runtimes_dict_for_model["runtime_per_output_token"]
-         raw_runtimes_for_prompt_tokens: Dict[str, float] = inference_runtimes_dict_for_model[
-             "runtime_for_prompt_tokens"
-         ]
-         runtimes_for_prompt_tokens: Dict[int, float] = {int(k): v for (k, v) in raw_runtimes_for_prompt_tokens.items()}
-
-         runtime_for_prompt_tokens: Optional[float] = None
-         largest_num_tokens_in_efficiency_dict: int = max(runtimes_for_prompt_tokens.keys())
-         # Find the smallest num_prompt_tokens larger than the number of tokens in the given prompt,
-         # then scale runtime in dict by (num_prompt_tokens / key) to get more accurate estimate: we
-         # assume that we can encode the prompt at the same throughput as the smallest key larger than
-         # num_prompt_tokens, and number of compute operations scales linearly with num_prompt_tokens.
-         for key in sorted(runtimes_for_prompt_tokens.keys()):
-             if num_prompt_tokens <= key:
-                 runtime_for_prompt_tokens = runtimes_for_prompt_tokens[key] * (num_prompt_tokens / key)
-                 break
-         # If number of tokens in the prompt exceeds the largest key in the efficiency dict, then
-         # estimate the prompt encoding time by linearly scaling up the runtime for the largest
-         # key (this is reasonably accurate under certain simplifying assumptions).
-         if runtime_for_prompt_tokens is None:
-             runtime_for_prompt_tokens = runtimes_for_prompt_tokens[largest_num_tokens_in_efficiency_dict] * (
-                 num_prompt_tokens / largest_num_tokens_in_efficiency_dict
-             )
-         overhead: Optional[float] = inference_runtimes_dict_for_model.get("overhead")
-
-         # Idealized runtime is sum of the runtime of encoding the input tokens, the runtime of
-         # generating `num_output_tokens` (`runtime_per_output_token` * (`num_output_tokens` - 1))
-         # if number of output tokens is greater than 0, otherwise just `runtime_for_prompt_tokens`,
-         # and the overhead if available.
-         estimated_runtime = runtime_for_prompt_tokens
-         if num_output_tokens > 0:
-             estimated_runtime += runtime_per_output_token * (num_output_tokens - 1)
-         # Add overhead if it is available.
-         if overhead is not None:
-             estimated_runtime += overhead
-     else:
-         estimated_runtime = None
-
-     return estimated_runtime
-
-
- def pass_at_k_estimator(n: int, c: int, k: int) -> float:
-     """Calculates 1 - comb(n - c, k) / comb(n, k).
-
-     Numerically stable version defined in
-     https://arxiv.org/pdf/2107.03374.pdf
-     """
-     if n - c < k:
-         return 1.0
-     return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
-
-
- def normalize_text(text: str) -> str:
-     """Lower text and remove punctuation, articles and extra whitespace.
-     Copied from the [QuAC](http://quac.ai/) evaluation script found at
-     https://s3.amazonaws.com/my89public/quac/scorer.py"""
-
-     def remove_articles(text: str) -> str:
-         return re.sub(r"\b(a|an|the)\b", " ", text)
-
-     def white_space_fix(text: str) -> str:
-         return " ".join(text.split())
-
-     def remove_punc(text: str) -> str:
-         exclude = set(string.punctuation)
-         return "".join(ch for ch in text if ch not in exclude)
-
-     def lower(text: str) -> str:
-         return text.lower()
-
-     return white_space_fix(remove_articles(remove_punc(lower(text))))
-
-
- def exact_match(gold: str, pred: str) -> float:
-     if not pred:
-         return 0
-
-     return 1 if gold.strip() == pred.strip() else 0
-
-
- def quasi_exact_match(gold: str, pred: str) -> float:
-     if not pred:
-         return 0
-
-     return 1 if normalize_text(gold) == normalize_text(pred) else 0
-
-
- def prefix_exact_match(gold: str, pred: str) -> float:
-     """
-     The `prefix_exact_match` metric is particularly useful in the zero-shot setting, where the model is
-     not given examples of the expected outputs and tends to output more tokens than it should.
-
-     For example, for this zero-shot prompt from BoolQ,
-
-     Passage: Elmendorf Air Force Base (IATA: EDF, ICAO: PAED, FAA LID: EDF) is a United States military facility
-     in Anchorage, the largest city in Alaska. Originally known as Elmendorf Field, it became Elmendorf Air Force
-     Base after World War II, and in 2010 it merged with nearby Fort Richardson to form Joint Base Elmendorf-Richardson.
-     Question: Is there an air force base in anchorage alaska?
-     Answer:
-
-     the model could output up to `max_tokens` number of tokens "Yes, Elmendorf" instead of just "Yes".
-     """
-     if not pred:
-         return 0
-
-     return 1 if pred.strip().startswith(gold.strip()) else 0
-
-
- def quasi_prefix_exact_match(gold: str, pred: str) -> float:
-     """
-     Same thing as `prefix_exact_match` but we normalize the text before checking if the prefix match.
-     """
-     if not pred:
-         return 0
-
-     return 1 if normalize_text(pred).startswith(normalize_text(gold)) else 0
-
-
- def f1_score(gold: str, pred: str) -> float:
-     ret = f_measure(set(normalize_text(gold).split()), set(normalize_text(pred).split()))
-     if ret is None:  # answer is the empty string after normalizing
-         return 0.0
-
-     return ret
-
-
- def exact_match_indicator(gold: str, pred: str, indicator: str = " ") -> float:
-     """
-     Exact match, allowing for some preceding context.
-     For example, the following two answers are considered matching:
-     - Because of x and y, the answer is ## <answer>
-     - Given reasons y and z, the answer is ## <answer>
-     While the following is considered different from the earlier two
-     - Given reasons x and a, the answer is ## <other answer>
-     """
-     pred = pred.split(indicator)[-1].strip()
-     gold = gold.split(indicator)[-1].strip()
-     return exact_match(gold, pred)
+ from .statistic import Stat, merge_stat


  def get_num_bytes(tokens: List[Token]) -> int:
@@ -251,123 +80,6 @@ def convert_tokens_to_text(tokens: List[Token]) -> List[Dict]:
      return groups


- def rouge_score(gold: str, pred: str, rouge_type: str, scorer: rouge_scorer.RougeScorer) -> float:
-     scores = scorer.score(gold, pred)
-     return scores[rouge_type].fmeasure
-
-
- def get_rouge_function(rouge_type: str) -> Callable[[str, str], float]:
-     scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True)
-     return partial(rouge_score, scorer=scorer, rouge_type=rouge_type)
-
-
- def bleu_1(gold: str, pred: str) -> float:
-     return sentence_bleu([word_tokenize(gold)], word_tokenize(pred), weights=(1, 0, 0, 0))
-
-
- def chinese_bleu_1(gold: str, pred: str) -> float:
-     char_tokenizer = ChineseTokenizer()
-     return sentence_bleu([char_tokenizer.tokenize(gold)], char_tokenizer.tokenize(pred), weights=(1, 0, 0, 0))
-
-
- def get_chinese_rouge_function(rouge_type: str) -> Callable[[str, str], float]:
-     char_tokenizer = ChineseTokenizer()
-     scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True, tokenizer=char_tokenizer)
-     return partial(rouge_score, scorer=scorer, rouge_type=rouge_type)
-
-
- def cleva_math_result_match(gold: str, pred: str) -> float:
-     """
-     Exact match that only cares the last math expression.
-     Common math expressions are numbers and fractions.
-     """
-     pattern = r"[-+*/%\.\(\)\d]+"
-     matches = re.findall(pattern, pred)
-     if matches:
-         pred = matches[-1].lstrip(")")
-         # remove space in front or at the end
-         pred = pred.strip()
-     return exact_match(gold, pred)
-
-
- def bleu_4(gold: str, pred: str) -> float:
-     return sentence_bleu([word_tokenize(gold)], word_tokenize(pred), weights=(0, 0, 0, 1))
-
-
- def extract_set_from_text(
-     set_str: str,
-     set_start_str: str = " is ",
-     set_separator: str = " and ",
-     empty_set_str: str = "Nothing.",
- ) -> Set[str]:
-     """
-     Given a string, extract the set of strings implied by that string.
-     set_start_str denotes the start of the set
-     set_separator denotes the string separating set elements
-     empty_set_str is the string which denotes the empty set
-     """
-     if set_str == empty_set_str:
-         return set()
-     set_str = set_str.replace(".", "")
-     extracted_set = set(set_str.split(set_start_str)[-1].split(set_separator))
-     return extracted_set
-
-
- def extract_gold_pred_sets(gold: str, pred: str) -> Tuple[Set[str], Set[str]]:
-     """Extract the set of strings implied by the gold and pred strings"""
-     gold_set = extract_set_from_text(gold)
-     pred_set = extract_set_from_text(pred.split("\n")[0])
-     return gold_set, pred_set
-
-
- def iou_set_match(gold: str, pred: str) -> float:
-     """Compute the intersection over union of the gold and pred sets"""
-     gold_set, pred_set = extract_gold_pred_sets(gold, pred)
-     if len(gold_set) == 0:  # If gold is empty, just check if the pred set is also empty
-         return float(gold_set == pred_set)
-     return len(gold_set.intersection(pred_set)) / len(gold_set.union(pred_set))
-
-
- def f1_set_match(gold: str, pred: str) -> float:
-     """Compute the F1 score of the gold and pred sets"""
-     gold_set, pred_set = extract_gold_pred_sets(gold, pred)
-     if len(gold_set) == 0:  # If gold is empty, just check if the pred set is also empty
-         return float(gold_set == pred_set)
-     true_positives = gold_set.intersection(pred_set)
-     return 2 * len(true_positives) / (len(gold_set) + len(pred_set))
-
-
- def exact_set_match(gold: str, pred: str) -> float:
-     """Compute whether the sets generated exactly match"""
-     gold_set, pred_set = extract_gold_pred_sets(gold, pred)
-     return float(gold_set == pred_set)
-
-
- def absolute_value_difference(gold: str, pred: str) -> float:
-     """Compute the absolute value of the difference between two numbers (provided as strings),
-     or 0.0 if invalid input.
-     """
-
-     def maybe_int(text: str):
-         """Parse int, ignoring commas in numbers."""
-         try:
-             val = int(text.replace(",", ""))
-         except ValueError:
-             return 0.0
-         return val
-
-     gold_val = maybe_int(gold)
-     pred_val = maybe_int(pred)
-     return abs(gold_val - pred_val)
-
-
- def code_eval(gold: Tuple[str, Optional[Dict]], pred: str) -> float:
-     """Evaluate Code Correctness on test examples."""
-     assert gold[1] is not None  # gold[1]["canonical_solution"]
-     # Warning: will execute machine generated code; need to sandbox before executing
-     return float(code_metrics_helper.check_correctness(gold[1], pred, 3.0)["passed"])  # type: ignore
-
-
  def compute_perplexity_metrics(stats: Dict[MetricName, Stat]) -> List[Stat]:
      # TODO: find out the root cause and undo num_X > 0 check
      # https://github.com/stanford-crfm/benchmarking/issues/350
@@ -392,7 +104,37 @@ def compute_perplexity_metrics(stats: Dict[MetricName, Stat]) -> List[Stat]:
      return derived_stats


- class BasicMetric(Metric):
+ class InstancesPerSplitMetric(MetricInterface):
+     """Report the average num_instances in each MetricContext across train_trials."""
+
+     def evaluate(
+         self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
+     ) -> MetricResult:
+         adapter_spec = scenario_state.adapter_spec
+         global_stats: Dict[MetricName, Stat] = {}
+
+         for train_trial_index in range(adapter_spec.num_train_trials):
+             trial_stats: Dict[MetricName, Stat] = {}  # Statistics just for this trial
+             # Group instances in this train_trial by context.
+             instances_per_metric_context: Dict[MetricContext, Set[Instance]] = defaultdict(set)
+             for request_state in scenario_state.request_states:
+                 if request_state.train_trial_index == train_trial_index:
+                     instances_per_metric_context[MetricContext.from_instance(request_state.instance)].add(
+                         request_state.instance
+                     )
+             for context, instance_set in instances_per_metric_context.items():
+                 stat = Stat(MetricName("num_instances")).add(len(instance_set))
+                 merge_stat(trial_stats, add_context(stat, context))
+
+             # We take the mean value for each trial.
+             for stat in trial_stats.values():
+                 merge_stat(global_stats, stat.take_mean())
+
+         # There are no per-instance Stats.
+         return MetricResult(list(global_stats.values()), [])
+
+
+ class BasicGenerationMetric(Metric):
      """
      Defines basic metrics which don't require domain knowledge. This should be
      fairly comprehensive already, and we should try to use this as much as possible.
@@ -403,333 +145,11 @@ class BasicMetric(Metric):
403
145
 
404
146
  def __init__(self, names: List[str]):
405
147
  self.names: List[str] = names
406
-
407
- # For Efficiency metrics:
408
- # The `inference_efficiency.json` file contains a `runtime_per_output_token` value
409
- # (the estimated runtime of generating one output token) and a
410
- # `runtime_for_prompt_tokens` dict (a mapping from various num_prompt_tokens values to
411
- # the estimated runtime of encoding a prompt with that many tokens).
412
- # For example:
413
- # "openai/davinci": {
414
- # "runtime_per_output_token": 0.080,
415
- # "runtime_for_prompt_tokens": {
416
- # "1": 0.016,
417
- # "16": 0.018,
418
- # "32": 0.020,
419
- # ...
420
- #
421
- # These runtimes are generated by initializing Megatron with a model of the right size,
422
- # obtaining end-to-end generation times for different numbers of prompt and output tokens,
423
- # and then fitting a linear regression model to the runtimes: the resulting slope is the
424
- # runtime_per_output_token, which is the processing time for generating each output token,
425
- # and the y-intercept is the runtime_for_prompt_tokens, with different values for different
426
- # num_prompt_tokens values.
427
- # Profiling code and logs, and code to fit the regression model is available at
428
- # https://github.com/stanford-crfm/benchmarking_efficiency.
429
- data_package = resources.files(EFFICIENCY_DATA_PACKAGE)
430
- with data_package.joinpath(INFERENCE_IDEALIZED_RUNTIMES_JSON_FILENAME).open("r") as f:
431
- self.inference_idealized_runtimes_dict = json.load(f)
432
- with data_package.joinpath(INFERENCE_DENOISED_RUNTIMES_JSON_FILENAME).open("r") as f:
433
- self.inference_denoised_runtimes_dict = json.load(f)
434
-
435
- # We use estimated emitted CO2 during training (in tons of CO2) as a proxy metric
436
- # for training efficiency. We use reported metrics where applicable, otherwise
437
- # we estimate them from runtime information, type and number of hardware accelerators
438
- # used, region, etc.
439
- with data_package.joinpath(TRAINING_EFFICIENCY_JSON_FILENAME).open("r") as f:
440
- self.training_efficiency_dict = json.load(f)
148
+ self.efficiency_metric = EfficiencyMetric()
441
149
 
442
150
  def __repr__(self):
443
151
  return f"BasicMetric({','.join(self.names)})"
444
152
 
445
- def compute_reference_metrics(
- self, adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
- ) -> List[Stat]:
- """
- Setup:
-
- - Gold (correct references): G1 ... Gm
- - Predictions (completions): P1 ... Pk
-
- For each pair (G, P), we can define a ${score} (e.g., exact match, F1, BLEU).
-
- We define the following stats:
-
- - ${score}: max_i score(Gi, P1)
- - ${score}@k: max_{i,j} score(Gi, Pj)
- """
-
- def compute_metrics_helper(
- name: MetricName,
- score_func: Callable,
- group: Optional[str] = None,
- ) -> List[Stat]:
- if name.name == "pass": # Calculate pass@k for HumanEval from CodeScenario.
- score_func = cast(Callable[[Tuple[str, Optional[Dict]], str], float], score_func) # Make mypy happy.
- code_golds = cast(List[CodeReference], golds)
- results = [
- score_func((gold.output.text, gold.test_cases), pred) for gold in code_golds for pred in preds
- ]
- _len, _sum = len(results), int(sum(results)) # Cast to int to make type match.
- score_1 = pass_at_k_estimator(_len, _sum, 1)
- score_k = pass_at_k_estimator(_len, _sum, adapter_spec.num_outputs)
- elif name.name == "code_eval_acc":
- score_func = cast(Callable[[Tuple[str, Optional[Dict]], str], float], score_func) # Make mypy happy.
- code_golds = cast(List[CodeReference], golds)
- score_1 = max(score_func((gold.output.text, gold.test_cases), preds[0]) for gold in code_golds)
- score_k = max(
- score_func((gold.output.text, gold.test_cases), pred) for gold in code_golds for pred in preds
- )
- else:
- score_func = cast(Callable[[str, str], float], score_func) # Make mypy happy.
- score_1 = max(score_func(gold.output.text, preds[0]) for gold in golds)
- score_k = max(score_func(gold.output.text, pred) for gold in golds for pred in preds)
-
- metrics = [Stat(name).add(score_1)] # score_1 corresponds using one prediction
- if adapter_spec.num_outputs != 1:
- metrics.append(Stat(replace(name, name=f"{name.name}@{adapter_spec.num_outputs}")).add(score_k))
- return metrics
-
- # maps each string metric name to its associated function
- metric_fn_mapping: Dict[str, Callable] = {
- "exact_match": exact_match,
- "quasi_exact_match": quasi_exact_match,
- "prefix_exact_match": prefix_exact_match,
- "quasi_prefix_exact_match": quasi_prefix_exact_match,
- "exact_match_indicator": exact_match_indicator,
- "exact_set_match": exact_set_match,
- "iou_set_match": iou_set_match,
- "f1_set_match": f1_set_match,
- "math_equiv": is_equiv,
- "math_equiv_chain_of_thought": is_equiv_chain_of_thought,
- "code_eval_acc": code_eval,
- "pass": code_eval,
- "f1_score": f1_score,
- "rouge_1": get_rouge_function("rouge1"),
- "rouge_2": get_rouge_function("rouge2"),
- "rouge_l": get_rouge_function("rougeL"),
- "bleu_1": bleu_1,
- "bleu_4": bleu_4,
- "chinese_bleu_1": chinese_bleu_1,
- "chinese_rouge_1": get_chinese_rouge_function("rouge1"),
- "chinese_rouge_2": get_chinese_rouge_function("rouge2"),
- "cleva_math_result_match": cleva_math_result_match,
- "absolute_value_difference": absolute_value_difference,
- }
-
- stats: List[Stat] = []
-
- # Gold outputs
- golds: List[Reference] = [reference for reference in request_state.instance.references if reference.is_correct]
- assert len(golds) > 0
-
- # Predicted outputs
- assert request_state.result is not None
- sorted_completions: List[Sequence] = sorted(request_state.result.completions, key=lambda x: -x.logprob)
- preds: List[str] = [completion.text.strip() for completion in sorted_completions]
-
- # Apply mapping if exists (e.g., for multiple-choice questions A -> Boston, B -> New York)
- # Note: If 'A' and 'B' were the only possible choices, smaller language models like GPT-2 would
- # sometimes predict a random letter like 'M'.
- if request_state.output_mapping is not None:
- preds = [request_state.output_mapping.get(pred) for pred in preds] # type: ignore
-
- # Compute max_prob, the probability that the model assigns to its generated text.
- # Use the log prob of sorted_completions[0], which is the completion with the highest
- # log_prob. We use this since that's what's used for computing metrics like exact_match.
- # One subtlety is that when computing exact_match, we strip whitespace, so the actual
- # max_prob is the sum of all the probabilities in the set {x : strip(x) = prediction}.
- # In practice, we think this may not make much of a difference because models may not place
- # high probabilities on having additional spaces (should check this). Also, the sum
- # involves computing the log_prob for many completions which could be intractable.
- max_prob = np.exp(sorted_completions[0].logprob)
- stats.append(Stat(MetricName("max_prob")).add(max_prob))
-
- # Add other metrics
- for metric_name in self.names:
- if metric_name in metric_fn_mapping:
- stats.extend(compute_metrics_helper(MetricName(metric_name), metric_fn_mapping[metric_name]))
- else:
- raise NameError(f"{metric_name} is not in the list of metric functions.")
-
- return stats
-
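
The removed helper above scores HumanEval-style code generations with pass_at_k_estimator(n, c, k), where n is the number of scored (gold, prediction) evaluations and c the number that pass. That estimator is presumably the standard unbiased pass@k from the Codex paper; a self-contained sketch for reference, not the package's own implementation:

    import numpy as np

    def pass_at_k(n: int, c: int, k: int) -> float:
        # Unbiased estimate of P(at least one of k samples passes) = 1 - C(n - c, k) / C(n, k),
        # computed in a numerically stable product form.
        if n - c < k:
            return 1.0
        return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
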
- def compute_efficiency_metrics(
- self, adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
- ) -> List[Stat]:
- """Compute efficiency metrics for both inference and training.
- For inference, we record both the actual runtime and an estimated idealized runtime
- for the given request with an optimized software implementation run on A100 GPU(s),
- taking into account both the number of tokens in the prompt of the request, and the
- number of generated output tokens.
- For training, we report the estimated total metric tons of CO2 emitted to train the
- model. This is the same for each request."""
- # Compute efficiency metrics for inference.
- assert request_state.result is not None
-
- runtime: Optional[float] = None
- batch_size: Optional[int] = None
- # Compute efficiency metrics for inference.
- if request_state.result.request_time is not None:
- runtime = request_state.result.request_time
- batch_size = 1
- # For models that perform offline batch inference, effective runtime is batch_request_time, but also
- # record batch_size to provide nuance.
- if request_state.result.batch_request_time is not None and request_state.result.batch_size is not None:
- runtime = request_state.result.batch_request_time
- batch_size = request_state.result.batch_size
-
- # Compute total number of prompt and output tokens.
- # Fetch the right `Tokenizer` depending on the model defined in `AdapterSpec`
- # and calculate the number of tokens in the prompt.
- tokenizer_service: TokenizerService = metric_service
- window_service: WindowService = WindowServiceFactory.get_window_service(adapter_spec.model, tokenizer_service)
- prompt: str = request_state.request.prompt
- num_prompt_tokens: int = window_service.get_num_tokens(prompt)
-
- # Total number of tokens in the completion.
- num_completion_tokens: int = sum([len(completion.tokens) for completion in request_state.result.completions])
- # Don't include prompt in number of generated tokens (e.g., for language modeling).
- # Assume that tokens for different completions are generated sequentially (instead of batched) when
- # computing num_output_tokens (for the purpose of runtime estimation).
- num_output_tokens: int = num_completion_tokens
- if request_state.request.echo_prompt:
- # num_prompt_tokens > num_output_tokens can happen if tokenizer doesn't round trip.
- if num_prompt_tokens <= num_output_tokens:
- num_output_tokens -= num_prompt_tokens
- else:
- hlog(
- f"WARNING: num_prompt_tokens ({num_prompt_tokens}) > num_output_tokens ({num_output_tokens}) "
- f"for prompt: {prompt}"
- )
- num_output_tokens = 0
-
- idealized_runtime: Optional[float] = compute_estimated_time_from_prompt_size_and_num_output_tokens(
- request_state, self.inference_idealized_runtimes_dict, num_prompt_tokens, num_output_tokens
- )
-
- denoised_runtime: Optional[float] = compute_estimated_time_from_prompt_size_and_num_output_tokens(
- request_state, self.inference_denoised_runtimes_dict, num_prompt_tokens, num_output_tokens
- )
- # Denoised runtime for offline models is just runtime.
- # We divide by batch_size to get approximate per-input runtime.
- if runtime is not None and request_state.result.batch_size is not None:
- denoised_runtime = runtime / request_state.result.batch_size
-
- # Compute efficiency metrics for training.
- training_co2_cost: Optional[float]
- if request_state.request.model in self.training_efficiency_dict["carbon"]:
- training_co2_cost = self.training_efficiency_dict["carbon"][request_state.request.model]["value"]
- else:
- training_co2_cost = None
-
- training_energy_cost: Optional[float]
- if request_state.request.model in self.training_efficiency_dict["energy"]:
- training_energy_cost = self.training_efficiency_dict["energy"][request_state.request.model]["value"]
- else:
- training_energy_cost = None
-
- stats = [
- Stat(MetricName("num_prompt_tokens")).add(num_prompt_tokens),
- Stat(MetricName("num_completion_tokens")).add(num_completion_tokens),
- Stat(MetricName("num_output_tokens")).add(num_output_tokens),
- Stat(MetricName("training_co2_cost")).add(training_co2_cost),
- Stat(MetricName("training_energy_cost")).add(training_energy_cost),
- ]
- if runtime is not None:
- stats.append(Stat(MetricName("inference_runtime")).add(runtime))
- if batch_size is not None:
- stats.append(Stat(MetricName("batch_size")).add(batch_size))
- if denoised_runtime is not None:
- stats.append(Stat(MetricName("inference_denoised_runtime")).add(denoised_runtime))
- if idealized_runtime is not None:
- stats.append(Stat(MetricName("inference_idealized_runtime")).add(idealized_runtime))
- return stats
-
- def compute_finish_reason_metrics(
- self, adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
- ) -> List[Stat]:
- """Record how often generation finished due to reaching token limit, stop token(s), or end of text"""
- assert request_state.result is not None
- sequence = request_state.result.completions[0]
- valid_reasons = [
- "length",
- "stop",
- "endoftext",
- "unknown",
- ]
- if sequence.finish_reason is None or sequence.finish_reason["reason"] not in valid_reasons:
- reason = "unknown"
- else:
- reason = sequence.finish_reason["reason"]
- return [
- Stat(MetricName(f"finish_reason_{valid_reason}")).add(int(reason == valid_reason))
- for valid_reason in valid_reasons
- ]
-
- def compute_truncation_metrics(
- self, adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
- ) -> List[Stat]:
- """
- Record the number of training instances used in the prompt and whether
- even the prompt needed to be truncated (once we hit zero training instances).
- """
- return [
- Stat(MetricName("num_train_instances")).add(request_state.num_train_instances),
- Stat(MetricName("prompt_truncated")).add(request_state.prompt_truncated),
- ]
-
- def compute_all_general_metrics(
- self, adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
- ) -> List[Stat]:
- """
- Compute metrics that are common to both `evaluate_generation` and `evaluate_references`.
- """
- stats: List[Stat] = []
-
- stats.append(Stat(MetricName("num_references")).add(len(request_state.instance.references)))
-
- # Copy from adapter spec
- stats.append(Stat(MetricName("num_train_trials")).add(adapter_spec.num_train_trials))
-
- stats.extend(self.compute_efficiency_metrics(adapter_spec, request_state, metric_service))
- stats.extend(self.compute_finish_reason_metrics(adapter_spec, request_state, metric_service))
- stats.extend(self.compute_truncation_metrics(adapter_spec, request_state, metric_service))
-
- return stats
-
- def compute_language_modeling_metrics(
- self, adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
- ) -> List[Stat]:
- """Compute the logprob and normalization factors for the first completion"""
- assert request_state.result is not None
- sequence = request_state.result.completions[0]
-
- # Remove the empty tokens (typically generated by the AI21 tokenizer in the beginning of the text)
- #
- # Some more details about AI21 tokenizer: If the input prompt begins with a space, then
- # the tokenizer inserts an empty token to the beginning.
- # e.g. " burying him" -> ["▁"(0,0), "▁burying"(0,8), "▁him"(8,12)].
- # TODO(#1522): Update this comment once solved.
- # Since this empty token is introduced by our chunking approach, we need to remove it.
- tokens: List[Token]
- if request_state.num_conditioning_tokens > 0 and sequence.tokens[0].text == "":
- tokens = sequence.tokens[1:]
- else:
- tokens = sequence.tokens
- pred_tokens = tokens[request_state.num_conditioning_tokens :]
- logprob, num_perplexity_tokens, num_bytes = (
- sum(token.logprob for token in pred_tokens),
- len(pred_tokens),
- get_num_bytes(pred_tokens),
- )
-
- return [
- Stat(MetricName("logprob")).add(logprob),
- Stat(MetricName("num_perplexity_tokens")).add(num_perplexity_tokens),
- Stat(MetricName("num_bytes")).add(num_bytes),
- ]
-
  def evaluate_generation(
  self,
  adapter_spec: AdapterSpec,
@@ -739,15 +159,40 @@ class BasicMetric(Metric):
  ) -> List[Stat]:
  """Compute all metrics."""
  stats: List[Stat] = []
- stats.extend(self.compute_all_general_metrics(adapter_spec, request_state, metric_service))
+ stats.extend(compute_request_state_metrics(self.efficiency_metric, adapter_spec, request_state, metric_service))

  if len(request_state.instance.references) > 0:
- stats.extend(self.compute_reference_metrics(adapter_spec, request_state, metric_service))
+ stats.extend(compute_reference_metrics(self.names, adapter_spec, request_state, metric_service))

- stats.extend(self.compute_language_modeling_metrics(adapter_spec, request_state, metric_service))
+ stats.extend(compute_language_modeling_metrics(adapter_spec, request_state, metric_service))

  return stats

+ def derive_stats(self, stats_dict: Dict[MetricName, Stat]) -> List[Stat]:
+ """Derive perplexity metrics if applicable. We don't worry about splits and perturbations here."""
+ derived_stats: List[Stat] = []
+ derived_stats.extend(compute_perplexity_metrics(stats_dict))
+ return derived_stats
+
+ def derive_per_instance_stats(self, per_instance_stats: Dict[Instance, List[Stat]]) -> List[Stat]:
+ """Derive calibration metrics if applicable. We don't worry about splits and perturbations here."""
+ derived_stats: List[Stat] = []
+ derived_stats.extend(compute_calibration_metrics(per_instance_stats))
+ return derived_stats
+
+
+ class BasicReferenceMetric(ReferenceMetric):
+ """
+ Defines basic metrics for Scenarios that use one Request per Reference instead of
+ one per Instance.
+ """
+
+ def __init__(self):
+ self.efficiency_metric = EfficiencyMetric()
+
+ def __repr__(self):
+ return "BasicReferenceMetric"
+
  def evaluate_references(
  self,
  adapter_spec: AdapterSpec,
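
Taken together, the hunks above and below move BasicMetric's general-purpose helpers out to module-level functions and introduce BasicReferenceMetric. A caller-side sketch of the new pattern, using only names that appear in this diff (the metric names passed in are illustrative):

    efficiency_metric = EfficiencyMetric()
    stats = compute_request_state_metrics(efficiency_metric, adapter_spec, request_state, metric_service)
    stats.extend(compute_reference_metrics(["exact_match", "f1_score"], adapter_spec, request_state, metric_service))
    stats.extend(compute_language_modeling_metrics(adapter_spec, request_state, metric_service))
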
@@ -777,7 +222,7 @@ class BasicMetric(Metric):
  assert len(request_state.result.completions) == 1

  reference_index = request_state.reference_index
- sequence: Sequence = request_state.result.completions[0]
+ sequence: GeneratedOutput = request_state.result.completions[0]
  reference: str = request_state.instance.references[reference_index].output.text

  # Find the span of the completion that matches the reference.
@@ -799,7 +244,9 @@ class BasicMetric(Metric):
  num_choices = len(references)

  tokenizer_service: TokenizerService = metric_service
- window_service: WindowService = WindowServiceFactory.get_window_service(adapter_spec.model, tokenizer_service)
+ window_service: WindowService = WindowServiceFactory.get_window_service(
+ adapter_spec.model_deployment, tokenizer_service
+ )
  reference_stats: Dict[ReferenceKey, ReferenceStat] = {}
  for request_state in reference_request_states:
  assert request_state.reference_index is not None and request_state.request_mode is not None
@@ -822,8 +269,14 @@ class BasicMetric(Metric):
  raise ValueError(f"Unknown adapter method: {adapter_spec.method}")

  stats: List[Stat] = []
- stats.extend(self.compute_all_general_metrics(adapter_spec, request_state, metric_service))

+ general_metrics: Dict[MetricName, Stat] = {}
+ for request_state in reference_request_states:
+ for stat in compute_request_state_metrics(
+ self.efficiency_metric, adapter_spec, request_state, metric_service
+ ):
+ merge_stat(general_metrics, stat)
+ stats.extend(general_metrics.values())
  max_prob = np.max(scipy.special.softmax(reference_scores))

  # Multiple references may attain the same maximal score; in such cases,
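
For the max_prob stat above, the per-reference scores are turned into a probability distribution with a softmax and the largest probability is recorded (this later feeds the calibration metrics). A tiny illustration with made-up scores:

    import numpy as np
    import scipy.special

    reference_scores = [-1.2, -0.3, -2.5]  # hypothetical per-reference scores (e.g., log-likelihoods)
    max_prob = float(np.max(scipy.special.softmax(reference_scores)))  # probability of the top-scoring choice
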
@@ -842,18 +295,96 @@ class BasicMetric(Metric):
  )
  return stats

- def derive_stats(self, stats_dict: Dict[MetricName, Stat]) -> List[Stat]:
- """Derive perplexity metrics if applicable. We don't worry about splits and perturbations here."""
- derived_stats: List[Stat] = []
- derived_stats.extend(compute_perplexity_metrics(stats_dict))
- return derived_stats

- def derive_per_instance_stats(self, per_instance_stats: Dict[Instance, List[Stat]]) -> List[Stat]:
- """Derive calibration metrics if applicable. We don't worry about splits and perturbations here."""
- derived_stats: List[Stat] = []
- derived_stats.extend(compute_calibration_metrics(per_instance_stats))
- derived_stats.append(Stat(MetricName("num_instances")).add(len(per_instance_stats)))
- return derived_stats
+ def compute_request_state_metrics(
+ efficiency_metric: EfficiencyMetric,
+ adapter_spec: AdapterSpec,
+ request_state: RequestState,
+ metric_service: MetricService,
+ ) -> List[Stat]:
+ """
+ Compute metrics that are common to both `evaluate_generation` and `evaluate_references`.
+ """
+ stats: List[Stat] = []
+
+ stats.append(Stat(MetricName("num_references")).add(len(request_state.instance.references)))
+
+ # Copy from adapter spec
+ stats.append(Stat(MetricName("num_train_trials")).add(adapter_spec.num_train_trials))
+
+ stats.extend(efficiency_metric.compute_efficiency_metrics(adapter_spec, request_state, metric_service))
+ stats.extend(_compute_finish_reason_metrics(adapter_spec, request_state, metric_service))
+ stats.extend(_compute_truncation_metrics(adapter_spec, request_state, metric_service))
+
+ return stats
+
+
+ def _compute_finish_reason_metrics(
+ adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
+ ) -> List[Stat]:
+ """Record how often generation finished due to reaching token limit, stop token(s), or end of text"""
+ assert request_state.result is not None
+ sequence = request_state.result.completions[0]
+ valid_reasons = [
+ "length",
+ "stop",
+ "endoftext",
+ "unknown",
+ ]
+ if sequence.finish_reason is None or sequence.finish_reason["reason"] not in valid_reasons:
+ reason = "unknown"
+ else:
+ reason = sequence.finish_reason["reason"]
+ return [
+ Stat(MetricName(f"finish_reason_{valid_reason}")).add(int(reason == valid_reason))
+ for valid_reason in valid_reasons
+ ]
+
+
+ def _compute_truncation_metrics(
+ adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
+ ) -> List[Stat]:
+ """
+ Record the number of training instances used in the prompt and whether
+ even the prompt needed to be truncated (once we hit zero training instances).
+ """
+ return [
+ Stat(MetricName("num_train_instances")).add(request_state.num_train_instances),
+ Stat(MetricName("prompt_truncated")).add(request_state.prompt_truncated),
+ ]
+
+
+ def compute_language_modeling_metrics(
+ adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
+ ) -> List[Stat]:
+ """Compute the logprob and normalization factors for the first completion"""
+ assert request_state.result is not None
+ sequence = request_state.result.completions[0]
+
+ # Remove the empty tokens (typically generated by the AI21 tokenizer in the beginning of the text)
+ #
+ # Some more details about AI21 tokenizer: If the input prompt begins with a space, then
+ # the tokenizer inserts an empty token to the beginning.
+ # e.g. " burying him" -> ["▁"(0,0), "▁burying"(0,8), "▁him"(8,12)].
+ # TODO(#1522): Update this comment once solved.
+ # Since this empty token is introduced by our chunking approach, we need to remove it.
+ tokens: List[Token]
+ if request_state.num_conditioning_tokens > 0 and sequence.tokens[0].text == "":
+ tokens = sequence.tokens[1:]
+ else:
+ tokens = sequence.tokens
+ pred_tokens = tokens[request_state.num_conditioning_tokens :]
+ logprob, num_perplexity_tokens, num_bytes = (
+ sum(token.logprob for token in pred_tokens),
+ len(pred_tokens),
+ get_num_bytes(pred_tokens),
+ )
+
+ return [
+ Stat(MetricName("logprob")).add(logprob),
+ Stat(MetricName("num_perplexity_tokens")).add(num_perplexity_tokens),
+ Stat(MetricName("num_bytes")).add(num_bytes),
+ ]
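
The logprob, num_perplexity_tokens, and num_bytes stats returned here feed the derived perplexity metrics (compute_perplexity_metrics, called from BasicMetric.derive_stats above). The exact formulas live in that helper; a sketch of the standard derivations these stats support, stated as an assumption rather than this package's code:

    import math

    def derive_perplexity_stats(logprob: float, num_perplexity_tokens: int, num_bytes: int) -> dict:
        return {
            "perplexity": math.exp(-logprob / num_perplexity_tokens),  # per-token perplexity
            "bits_per_byte": -logprob / (num_bytes * math.log(2)),     # compression-style view
            "logprob_per_byte": logprob / num_bytes,                   # length-normalized log-probability
        }
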


  def _has_non_zero_valued_logprobs(per_instance_stats: Dict[Instance, List[Stat]]) -> bool: