crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546):
  1. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +37 -2
  5. helm/benchmark/adaptation/adapters/adapter.py +4 -42
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
  9. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
  10. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
  11. helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
  12. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  13. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  14. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
  15. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
  16. helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
  17. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  18. helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
  19. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
  20. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
  21. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  22. helm/benchmark/adaptation/prompt.py +7 -1
  23. helm/benchmark/adaptation/request_state.py +6 -1
  24. helm/benchmark/adaptation/scenario_state.py +6 -2
  25. helm/benchmark/annotation/annotator.py +43 -0
  26. helm/benchmark/annotation/annotator_factory.py +61 -0
  27. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  28. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  29. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  30. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  31. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  32. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  33. helm/benchmark/annotation_executor.py +124 -0
  34. helm/benchmark/augmentations/cleva_perturbation.py +7 -14
  35. helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
  36. helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
  37. helm/benchmark/augmentations/data_augmenter.py +0 -2
  38. helm/benchmark/augmentations/dialect_perturbation.py +2 -2
  39. helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
  40. helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
  41. helm/benchmark/augmentations/gender_perturbation.py +3 -3
  42. helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
  43. helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
  44. helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
  45. helm/benchmark/augmentations/person_name_perturbation.py +0 -7
  46. helm/benchmark/augmentations/perturbation.py +20 -7
  47. helm/benchmark/augmentations/perturbation_description.py +1 -1
  48. helm/benchmark/augmentations/space_perturbation.py +2 -2
  49. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  50. helm/benchmark/augmentations/synonym_perturbation.py +2 -2
  51. helm/benchmark/augmentations/test_perturbation.py +11 -7
  52. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  53. helm/benchmark/augmentations/typos_perturbation.py +2 -2
  54. helm/benchmark/config_registry.py +38 -0
  55. helm/benchmark/executor.py +46 -16
  56. helm/benchmark/huggingface_registration.py +37 -7
  57. helm/benchmark/metrics/basic_metrics.py +172 -641
  58. helm/benchmark/metrics/bbq_metrics.py +3 -4
  59. helm/benchmark/metrics/bias_metrics.py +6 -6
  60. helm/benchmark/metrics/classification_metrics.py +11 -8
  61. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  62. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  63. helm/benchmark/metrics/code_metrics.py +4 -3
  64. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  65. helm/benchmark/metrics/common_metric_specs.py +167 -0
  66. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  67. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  68. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  69. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  70. helm/benchmark/metrics/disinformation_metrics.py +6 -112
  71. helm/benchmark/metrics/dry_run_metrics.py +5 -3
  72. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  73. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  74. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  75. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  76. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  77. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  78. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  79. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  80. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  81. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  82. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  83. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  84. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  85. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  86. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  87. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  88. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  89. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  90. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  91. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  92. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  93. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  94. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  95. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  96. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  97. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  98. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  99. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  100. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  101. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  102. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  103. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  104. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  105. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  106. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  107. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  108. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  109. helm/benchmark/metrics/machine_translation_metrics.py +5 -5
  110. helm/benchmark/metrics/metric.py +93 -172
  111. helm/benchmark/metrics/metric_name.py +0 -1
  112. helm/benchmark/metrics/metric_service.py +16 -0
  113. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  114. helm/benchmark/metrics/ranking_metrics.py +6 -7
  115. helm/benchmark/metrics/reference_metric.py +148 -0
  116. helm/benchmark/metrics/summac/model_summac.py +0 -2
  117. helm/benchmark/metrics/summarization_metrics.py +8 -8
  118. helm/benchmark/metrics/test_classification_metrics.py +9 -6
  119. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  120. helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
  121. helm/benchmark/metrics/test_metric.py +2 -2
  122. helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
  123. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
  124. helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
  125. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
  126. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
  127. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  128. helm/benchmark/metrics/toxicity_utils.py +23 -0
  129. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  130. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  131. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  132. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  133. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  134. helm/benchmark/model_deployment_registry.py +164 -41
  135. helm/benchmark/model_metadata_registry.py +181 -35
  136. helm/benchmark/multi_gpu_runner.py +133 -0
  137. helm/benchmark/presentation/contamination.py +3 -3
  138. helm/benchmark/presentation/create_plots.py +8 -7
  139. helm/benchmark/presentation/run_display.py +50 -17
  140. helm/benchmark/presentation/schema.py +28 -46
  141. helm/benchmark/presentation/summarize.py +213 -96
  142. helm/benchmark/presentation/table.py +8 -8
  143. helm/benchmark/presentation/test_contamination.py +2 -2
  144. helm/benchmark/presentation/test_run_entry.py +14 -9
  145. helm/benchmark/presentation/test_summarize.py +5 -0
  146. helm/benchmark/run.py +66 -54
  147. helm/benchmark/run_expander.py +342 -31
  148. helm/benchmark/run_spec.py +93 -0
  149. helm/benchmark/run_spec_factory.py +162 -0
  150. helm/benchmark/run_specs/__init__.py +0 -0
  151. helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
  152. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  153. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  154. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  155. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  156. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  157. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  158. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  159. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  160. helm/benchmark/runner.py +116 -69
  161. helm/benchmark/runner_config_registry.py +21 -0
  162. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  163. helm/benchmark/scenarios/bold_scenario.py +2 -2
  164. helm/benchmark/scenarios/cleva_scenario.py +43 -46
  165. helm/benchmark/scenarios/code_scenario.py +3 -2
  166. helm/benchmark/scenarios/commonsense_scenario.py +171 -191
  167. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  168. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  169. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  170. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  171. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  172. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  173. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  174. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  175. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  176. helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
  177. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  178. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  179. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  180. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  181. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  182. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  183. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  184. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  185. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  186. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  187. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  188. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  189. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  190. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  191. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  192. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  193. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  194. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  195. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  196. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  197. helm/benchmark/scenarios/legalbench_scenario.py +123 -0
  198. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  199. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  200. helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
  201. helm/benchmark/scenarios/math_scenario.py +19 -2
  202. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  203. helm/benchmark/scenarios/numeracy_scenario.py +3 -3
  204. helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
  205. helm/benchmark/scenarios/raft_scenario.py +2 -6
  206. helm/benchmark/scenarios/scenario.py +14 -2
  207. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  208. helm/benchmark/scenarios/test_math_scenario.py +22 -0
  209. helm/benchmark/scenarios/test_scenario.py +6 -3
  210. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  211. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  212. helm/benchmark/scenarios/the_pile_scenario.py +6 -7
  213. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  214. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  215. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  216. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  217. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  218. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  219. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  220. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  221. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  222. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  223. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  224. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  225. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  226. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  227. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  228. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  229. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  230. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  231. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  232. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  233. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  234. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  235. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  236. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  237. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
  238. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  239. helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
  240. helm/benchmark/server.py +59 -2
  241. helm/benchmark/slurm_jobs.py +12 -0
  242. helm/benchmark/slurm_runner.py +79 -51
  243. helm/benchmark/static/benchmarking.js +3 -4
  244. helm/benchmark/static/contamination.yaml +1 -1
  245. helm/benchmark/static/images/organizations/together.png +0 -0
  246. helm/benchmark/static/json-urls.js +4 -0
  247. helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
  248. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  249. helm/benchmark/static/schema_lite.yaml +824 -0
  250. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  251. helm/benchmark/static/schema_unitxt.yaml +428 -0
  252. helm/benchmark/static/schema_vlm.yaml +576 -0
  253. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  254. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  255. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  256. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  257. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  258. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  259. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  260. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  261. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  262. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  263. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  264. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  265. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  266. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  267. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  268. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  269. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  270. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  271. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  272. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  273. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  274. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  275. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  276. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  277. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  278. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  279. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  280. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  281. helm/benchmark/static_build/config.js +4 -0
  282. helm/benchmark/static_build/index.html +20 -0
  283. helm/benchmark/test_data_preprocessor.py +3 -3
  284. helm/benchmark/test_model_deployment_definition.py +90 -0
  285. helm/benchmark/test_run_expander.py +1 -1
  286. helm/benchmark/tokenizer_config_registry.py +10 -14
  287. helm/benchmark/window_services/ai21_window_service.py +22 -33
  288. helm/benchmark/window_services/cohere_window_service.py +1 -63
  289. helm/benchmark/window_services/default_window_service.py +2 -35
  290. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  291. helm/benchmark/window_services/ice_window_service.py +0 -34
  292. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  293. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  294. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  295. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  296. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  297. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  298. helm/benchmark/window_services/local_window_service.py +21 -4
  299. helm/benchmark/window_services/no_decoding_window_service.py +32 -0
  300. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  301. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  302. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  303. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  304. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  305. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  306. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  307. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  308. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  309. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  310. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  311. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  312. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  313. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  314. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  315. helm/benchmark/window_services/test_utils.py +3 -2
  316. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  317. helm/benchmark/window_services/window_service.py +42 -0
  318. helm/benchmark/window_services/window_service_factory.py +24 -269
  319. helm/benchmark/window_services/yalm_window_service.py +0 -27
  320. helm/clients/__init__.py +0 -0
  321. helm/{proxy/clients → clients}/ai21_client.py +5 -12
  322. helm/clients/aleph_alpha_client.py +112 -0
  323. helm/{proxy/clients → clients}/anthropic_client.py +213 -24
  324. helm/clients/auto_client.py +215 -0
  325. helm/clients/bedrock_client.py +128 -0
  326. helm/clients/bedrock_utils.py +72 -0
  327. helm/{proxy/clients → clients}/client.py +67 -55
  328. helm/clients/clip_score_client.py +49 -0
  329. helm/clients/clip_scorers/__init__.py +0 -0
  330. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  331. helm/clients/clip_scorers/clip_scorer.py +50 -0
  332. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  333. helm/{proxy/clients → clients}/cohere_client.py +6 -17
  334. helm/clients/gcs_client.py +82 -0
  335. helm/{proxy/clients → clients}/google_client.py +7 -8
  336. helm/clients/google_translate_client.py +35 -0
  337. helm/{proxy/clients → clients}/http_model_client.py +6 -10
  338. helm/{proxy/clients → clients}/huggingface_client.py +134 -92
  339. helm/clients/image_generation/__init__.py +0 -0
  340. helm/clients/image_generation/adobe_vision_client.py +78 -0
  341. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  342. helm/clients/image_generation/cogview2/__init__.py +0 -0
  343. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  344. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  345. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  346. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  347. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  348. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  349. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  350. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  351. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  352. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  353. helm/clients/image_generation/cogview2_client.py +191 -0
  354. helm/clients/image_generation/dalle2_client.py +192 -0
  355. helm/clients/image_generation/dalle3_client.py +108 -0
  356. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  357. helm/clients/image_generation/dalle_mini/data.py +442 -0
  358. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  359. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  360. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  361. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  362. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  363. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  364. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  365. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  366. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  367. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  368. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  369. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  370. helm/clients/image_generation/dalle_mini_client.py +190 -0
  371. helm/clients/image_generation/deep_floyd_client.py +78 -0
  372. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  373. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  374. helm/clients/image_generation/lexica_client.py +86 -0
  375. helm/clients/image_generation/mindalle/__init__.py +0 -0
  376. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  377. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  378. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  379. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  380. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  381. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  382. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  383. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  384. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  385. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  386. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  387. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  388. helm/clients/image_generation/mindalle_client.py +115 -0
  389. helm/clients/image_generation/nudity_check_client.py +64 -0
  390. helm/clients/image_generation/together_image_generation_client.py +111 -0
  391. helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
  392. helm/{proxy/clients → clients}/megatron_client.py +13 -7
  393. helm/clients/mistral_client.py +134 -0
  394. helm/clients/moderation_api_client.py +109 -0
  395. helm/clients/open_lm_client.py +43 -0
  396. helm/clients/openai_client.py +302 -0
  397. helm/{proxy/clients → clients}/palmyra_client.py +15 -12
  398. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  399. helm/clients/simple_client.py +64 -0
  400. helm/{proxy/clients → clients}/test_auto_client.py +15 -15
  401. helm/clients/test_client.py +100 -0
  402. helm/clients/test_huggingface_client.py +70 -0
  403. helm/clients/test_simple_client.py +19 -0
  404. helm/{proxy/clients → clients}/test_together_client.py +23 -12
  405. helm/{proxy/clients → clients}/together_client.py +18 -71
  406. helm/clients/vertexai_client.py +391 -0
  407. helm/clients/vision_language/__init__.py +0 -0
  408. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  409. helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
  410. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  411. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  412. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  413. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  414. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  415. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  416. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  417. helm/clients/vision_language/open_flamingo_client.py +155 -0
  418. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  419. helm/clients/vllm_client.py +46 -0
  420. helm/common/cache.py +24 -179
  421. helm/common/cache_backend_config.py +47 -0
  422. helm/common/clip_score_request.py +41 -0
  423. helm/common/concurrency.py +32 -0
  424. helm/common/credentials_utils.py +28 -0
  425. helm/common/file_caches/__init__.py +0 -0
  426. helm/common/file_caches/file_cache.py +16 -0
  427. helm/common/file_caches/local_file_cache.py +61 -0
  428. helm/common/file_caches/test_local_file_cache.py +25 -0
  429. helm/common/file_upload_request.py +27 -0
  430. helm/common/general.py +29 -10
  431. helm/common/image_generation_parameters.py +25 -0
  432. helm/common/images_utils.py +24 -1
  433. helm/common/key_value_store.py +113 -0
  434. helm/common/media_object.py +13 -0
  435. helm/common/moderations_api_request.py +71 -0
  436. helm/common/mongo_key_value_store.py +88 -0
  437. helm/common/multimodal_request_utils.py +31 -0
  438. helm/common/nudity_check_request.py +29 -0
  439. helm/common/object_spec.py +2 -2
  440. helm/common/request.py +36 -27
  441. helm/common/test_general.py +6 -0
  442. helm/common/tokenization_request.py +6 -3
  443. helm/config/__init__.py +0 -0
  444. helm/config/model_deployments.yaml +1942 -0
  445. helm/config/model_metadata.yaml +2201 -0
  446. helm/config/tokenizer_configs.yaml +362 -0
  447. helm/proxy/accounts.py +31 -4
  448. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  449. helm/proxy/critique/model_critique_client.py +13 -5
  450. helm/proxy/example_queries.py +29 -17
  451. helm/proxy/retry.py +8 -2
  452. helm/proxy/server.py +77 -5
  453. helm/proxy/services/remote_service.py +31 -0
  454. helm/proxy/services/server_service.py +103 -20
  455. helm/proxy/services/service.py +34 -2
  456. helm/proxy/services/test_remote_service.py +7 -6
  457. helm/proxy/services/test_service.py +27 -18
  458. helm/proxy/test_accounts.py +32 -0
  459. helm/proxy/token_counters/auto_token_counter.py +37 -37
  460. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  461. helm/proxy/token_counters/token_counter.py +3 -5
  462. helm/py.typed +0 -0
  463. helm/tokenizers/__init__.py +0 -0
  464. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  465. helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
  466. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
  467. helm/tokenizers/auto_tokenizer.py +93 -0
  468. helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
  469. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  470. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  471. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
  472. helm/tokenizers/simple_tokenizer.py +33 -0
  473. helm/tokenizers/test_anthropic_tokenizer.py +82 -0
  474. helm/tokenizers/test_huggingface_tokenizer.py +136 -0
  475. helm/tokenizers/test_simple_tokenizer.py +33 -0
  476. helm/tokenizers/vertexai_tokenizer.py +97 -0
  477. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  478. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  479. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  480. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  481. crfm_helm-0.3.0.dist-info/RECORD +0 -396
  482. helm/benchmark/vlm_run_specs.py +0 -71
  483. helm/benchmark/window_services/anthropic_window_service.py +0 -68
  484. helm/benchmark/window_services/bloom_window_service.py +0 -35
  485. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  486. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  487. helm/benchmark/window_services/gptj_window_service.py +0 -38
  488. helm/benchmark/window_services/gptneox_window_service.py +0 -41
  489. helm/benchmark/window_services/http_model_window_service.py +0 -28
  490. helm/benchmark/window_services/huggingface_window_service.py +0 -59
  491. helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
  492. helm/benchmark/window_services/llama_window_service.py +0 -28
  493. helm/benchmark/window_services/luminous_window_service.py +0 -67
  494. helm/benchmark/window_services/megatron_window_service.py +0 -10
  495. helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
  496. helm/benchmark/window_services/openai_window_service.py +0 -13
  497. helm/benchmark/window_services/opt_window_service.py +0 -35
  498. helm/benchmark/window_services/palmyra_window_service.py +0 -45
  499. helm/benchmark/window_services/remote_window_service.py +0 -48
  500. helm/benchmark/window_services/santacoder_window_service.py +0 -27
  501. helm/benchmark/window_services/starcoder_window_service.py +0 -27
  502. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  503. helm/benchmark/window_services/t511b_window_service.py +0 -30
  504. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  505. helm/benchmark/window_services/ul2_window_service.py +0 -30
  506. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  507. helm/benchmark/window_services/wider_openai_window_service.py +0 -52
  508. helm/proxy/clients/aleph_alpha_client.py +0 -99
  509. helm/proxy/clients/auto_client.py +0 -461
  510. helm/proxy/clients/goose_ai_client.py +0 -100
  511. helm/proxy/clients/microsoft_client.py +0 -182
  512. helm/proxy/clients/openai_client.py +0 -206
  513. helm/proxy/clients/remote_model_registry.py +0 -28
  514. helm/proxy/clients/simple_client.py +0 -61
  515. helm/proxy/clients/test_anthropic_client.py +0 -63
  516. helm/proxy/clients/test_client.py +0 -31
  517. helm/proxy/clients/test_huggingface_client.py +0 -87
  518. helm/proxy/models.py +0 -963
  519. helm/proxy/test_models.py +0 -27
  520. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  521. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  522. helm/proxy/token_counters/free_token_counter.py +0 -12
  523. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  524. helm/proxy/token_counters/openai_token_counter.py +0 -22
  525. helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
  526. helm/proxy/token_counters/test_openai_token_counter.py +0 -79
  527. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  528. helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
  529. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  530. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  531. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  532. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  533. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  534. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  535. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  536. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  537. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  538. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  539. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  540. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  541. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  542. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  543. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  544. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  545. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  546. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
@@ -1,4 +1,3 @@
1
- import os
2
1
  from typing import Dict, Optional, List
3
2
  from dataclasses import dataclass
4
3
 
@@ -7,10 +6,12 @@ import yaml
7
6
 
8
7
  from helm.common.hierarchical_logger import hlog
9
8
  from helm.common.object_spec import ObjectSpec
10
- from helm.proxy.models import ALL_MODELS, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, MODEL_NAME_TO_MODEL, TEXT_MODEL_TAG, Model
11
-
12
-
13
- MODEL_DEPLOYMENTS_FILE = "model_deployments.yaml"
9
+ from helm.benchmark.model_metadata_registry import (
10
+ ModelMetadata,
11
+ get_model_metadata,
12
+ get_unknown_model_metadata,
13
+ register_model_metadata,
14
+ )
14
15
 
15
16
 
16
17
  class ClientSpec(ObjectSpec):
@@ -23,65 +24,99 @@ class WindowServiceSpec(ObjectSpec):
23
24
 
24
25
  @dataclass(frozen=True)
25
26
  class ModelDeployment:
26
- """A model deployment is an accessible instance of this model (e.g. a hosted endpoint).
27
-
28
- A model can have multiple model deployments."""
27
+ """
28
+ A model deployment is an accessible instance of this model (e.g., a hosted endpoint).
29
+ A model can have multiple model deployments.
30
+ """
29
31
 
30
32
  name: str
31
- """Name of the model deployment."""
33
+ """Name of the model deployment. Usually formatted as "<hosting_group>/<engine_name>".
34
+ Example: "huggingface/t5-11b"."""
32
35
 
33
36
  client_spec: ClientSpec
34
37
  """Specification for instantiating the client for this model deployment."""
35
38
 
36
39
  model_name: Optional[str] = None
37
- """Name of the model that this model deployment is for.
38
-
39
- If unset, defaults to the the same value as `name`."""
40
+ """Name of the model that this model deployment is for. Refers to the field "name" in the Model class.
41
+ If unset, defaults to the same value as `name`."""
40
42
 
41
43
  tokenizer_name: Optional[str] = None
42
- """Tokenizer for this model deployment.
43
-
44
- If unset, auto-inferred by the WindowService."""
44
+ """Tokenizer for this model deployment. If unset, auto-inferred by the WindowService."""
45
45
 
46
46
  window_service_spec: Optional[WindowServiceSpec] = None
47
- """Specification for instantiating the window service for this model deployment"""
47
+ """Specification for instantiating the window service for this model deployment."""
48
48
 
49
49
  max_sequence_length: Optional[int] = None
50
50
  """Maximum sequence length for this model deployment."""
51
51
 
52
52
  max_request_length: Optional[int] = None
53
53
  """Maximum request length for this model deployment.
54
-
55
54
  If unset, defaults to the same value as max_sequence_length."""
56
55
 
56
+ max_sequence_and_generated_tokens_length: Optional[int] = None
57
+ """The max length of the model input and output tokens.
58
+ Some models (like Anthropic/Claude and Megatron) have a specific limit sequence length + max_token.
59
+ If unset, defaults to INT_MAX (i.e., no limit)."""
60
+
61
+ deprecated: bool = False
62
+ """Whether this model deployment is deprecated."""
63
+
64
+ @property
65
+ def host_organization(self) -> str:
66
+ """
67
+ Extracts the host group from the model deployment name.
68
+ Example: "huggingface" from "huggingface/t5-11b"
69
+ This can be different from the creator organization (for example "together")
70
+ """
71
+ return self.name.split("/")[0]
72
+
73
+ @property
74
+ def engine(self) -> str:
75
+ """
76
+ Extracts the model engine from the model deployment name.
77
+ Example: 'ai21/j1-jumbo' => 'j1-jumbo'
78
+ """
79
+ return self.name.split("/")[1]
80
+
81
+ def __post_init__(self):
82
+ if not self.model_name:
83
+ object.__setattr__(self, "model_name", self.name)
84
+
57
85
 
58
86
  @dataclass(frozen=True)
59
87
  class ModelDeployments:
60
88
  model_deployments: List[ModelDeployment]
61
89
 
62
90
 
63
- _name_to_model_deployment: Dict[str, ModelDeployment] = {}
91
+ ALL_MODEL_DEPLOYMENTS: List[ModelDeployment] = []
92
+ DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT: Dict[str, ModelDeployment] = {
93
+ deployment.name: deployment for deployment in ALL_MODEL_DEPLOYMENTS
94
+ }
64
95
 
65
96
 
66
97
  def register_model_deployment(model_deployment: ModelDeployment) -> None:
67
- hlog(f"Registered model deployment {model_deployment.name}")
68
- _name_to_model_deployment[model_deployment.name] = model_deployment
69
-
70
- # Auto-register a model with this name if none exists
71
- model_name = model_deployment.model_name or model_deployment.name
72
- if model_name not in MODEL_NAME_TO_MODEL:
73
- model = Model(
74
- group="unknown",
75
- name=model_name,
76
- tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG],
98
+ DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[model_deployment.name] = model_deployment
99
+ ALL_MODEL_DEPLOYMENTS.append(model_deployment)
100
+
101
+ model_name: str = model_deployment.model_name or model_deployment.name
102
+
103
+ model_metadata: ModelMetadata
104
+ try:
105
+ model_metadata = get_model_metadata(model_name)
106
+ except ValueError:
107
+ hlog(
108
+ f"WARNING: Could not find model metadata for model {model_name} of model deployment {model_deployment.name}"
77
109
  )
78
- MODEL_NAME_TO_MODEL[model_name] = model
79
- ALL_MODELS.append(model)
80
- hlog(f"Registered default metadata for model {model_name}")
110
+ model_metadata = get_unknown_model_metadata(model_name)
111
+ register_model_metadata(model_metadata)
112
+ deployment_names: List[str] = model_metadata.deployment_names or [model_metadata.name]
113
+ if model_deployment.name not in deployment_names:
114
+ if model_metadata.deployment_names is None:
115
+ model_metadata.deployment_names = []
116
+ model_metadata.deployment_names.append(model_deployment.name)
81
117
 
82
118
 
83
119
  def register_model_deployments_from_path(path: str) -> None:
84
- global _name_to_model_deployment
85
120
  hlog(f"Reading model deployments from {path}...")
86
121
  with open(path, "r") as f:
87
122
  raw = yaml.safe_load(f)
@@ -90,12 +125,100 @@ def register_model_deployments_from_path(path: str) -> None:
90
125
  register_model_deployment(model_deployment)
91
126
 
92
127
 
93
- def maybe_register_model_deployments_from_base_path(base_path: str) -> None:
94
- """Register model deployments from prod_env/model_deployments.yaml"""
95
- path = os.path.join(base_path, MODEL_DEPLOYMENTS_FILE)
96
- if os.path.exists(path):
97
- register_model_deployments_from_path(path)
98
-
99
-
100
- def get_model_deployment(name: str) -> Optional[ModelDeployment]:
101
- return _name_to_model_deployment.get(name)
128
+ def get_model_deployment(name: str, warn_deprecated: bool = False) -> ModelDeployment:
129
+ if name not in DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT:
130
+ raise ValueError(f"Model deployment {name} not found")
131
+ deployment: ModelDeployment = DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[name]
132
+ if deployment.deprecated and warn_deprecated:
133
+ hlog(f"WARNING: DEPLOYMENT Model deployment {name} is deprecated")
134
+ return deployment
135
+
136
+
137
+ def get_model_deployment_host_organization(name: str) -> str:
138
+ """Return the host organization name based on the model deployment name.
139
+
140
+ Example: "huggingface/t5-11b" -> "huggingface"""
141
+ deployment: ModelDeployment = get_model_deployment(name)
142
+ return deployment.host_organization
143
+
144
+
145
+ def get_model_names_with_tokenizer(tokenizer_name: str) -> List[str]:
146
+ """Return the names of all models with the given tokenizer."""
147
+ deployments: List[ModelDeployment] = [
148
+ deployment for deployment in ALL_MODEL_DEPLOYMENTS if deployment.tokenizer_name == tokenizer_name
149
+ ]
150
+ return [deployment.model_name or deployment.name for deployment in deployments]
151
+
152
+
153
+ def get_default_model_deployment_for_model(
154
+ model_name: str, warn_arg_deprecated: bool = False, ignore_deprecated: bool = False
155
+ ) -> Optional[str]:
156
+ """Returns a valid model deployment name corresponding to the given model arg.
157
+ This is used as a backwards compatibility layer for model names that are now moved to model deployments.
158
+ Example: "anthropic/claude-v1.3" => "anthropic/claude-v1.3"
159
+ Example: "meta/llama-7b" => "together/llama-7b"
160
+
161
+ The process to find a model deployment name is as follows:
162
+ 1. If there is a model deployment with the same name as the model arg, use it.
163
+ 2. If there is at least one deployment for the model, use the first one that is available.
164
+ 3. If there are no deployments for the model, returns None.
165
+
166
+ This function will also try to find a model deployment name that is not deprecated.
167
+ If there are no non-deprecated deployments, it will return the first deployment (even if it's deprecated).
168
+ If ignore_deprecated is True, this function will return None if the model deployment is deprecated.
169
+
170
+ If warn_arg_deprecated is True, this function will print a warning if the model deployment name is not the same
171
+ as the model arg. This is to remind the user that the model name is deprecated and should be replaced with
172
+ the model deployment name (in their config).
173
+
174
+ Args:
175
+ model_arg: The model arg to convert to a model deployment name.
176
+ warn_arg_deprecated: Whether to print a warning if the model deployment name is not the same as the model arg.
177
+ ignore_deprecated: Whether to return None if the model deployment is deprecated.
178
+ """
179
+
180
+ # If there is a model deployment with the same name as the model arg, use it.
181
+ if model_name in DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT:
182
+ deployment: ModelDeployment = DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[model_name]
183
+ if deployment.deprecated and ignore_deprecated:
184
+ if warn_arg_deprecated:
185
+ hlog(f"WARNING: Model deployment {model_name} is deprecated")
186
+ return None
187
+ return deployment.name
188
+
189
+ # If there is at least one deployment for the model, use the first one that is available.
190
+ available_deployments: List[ModelDeployment] = [
191
+ deployment for deployment in ALL_MODEL_DEPLOYMENTS if deployment.model_name == model_name
192
+ ]
193
+ if len(available_deployments) > 0:
194
+ available_deployment_names: List[str] = [deployment.name for deployment in available_deployments]
195
+ if warn_arg_deprecated:
196
+ hlog("WARNING: Model name is deprecated. Please use the model deployment name instead.")
197
+ hlog(f"Available model deployments for model {model_name}: {available_deployment_names}")
198
+
199
+ # Additionally, if there is a non-deprecated deployment, use it.
200
+ non_deprecated_deployments: List[ModelDeployment] = [
201
+ deployment for deployment in available_deployments if not deployment.deprecated
202
+ ]
203
+ if len(non_deprecated_deployments) > 0:
204
+ chosen_deployment = non_deprecated_deployments[0]
205
+ # There are no non-deprecated deployments, so there are two options:
206
+ # 1. If we can return an empty string, return it. (no model deployment is available)
207
+ # 2. If we can't return an empty string, return the first deployment (even if it's deprecated).
208
+ elif ignore_deprecated:
209
+ return None
210
+ else:
211
+ chosen_deployment = available_deployments[0]
212
+ if warn_arg_deprecated:
213
+ hlog(f"WARNING: All model deployments for model {model_name} are deprecated.")
214
+ if warn_arg_deprecated:
215
+ hlog(
216
+ f"Choosing {chosen_deployment.name} (the first one) as "
217
+ f"the default model deployment for model {model_name}"
218
+ )
219
+ hlog("If you want to use a different model deployment, please specify it explicitly.")
220
+ return chosen_deployment.name
221
+
222
+ # Some models are added but have no deployments yet.
223
+ # In this case, we return None.
224
+ return None
@@ -1,50 +1,142 @@
1
- import os
2
- from typing import Optional, List
1
+ from typing import Dict, Optional, List
3
2
  from dataclasses import dataclass, field
4
3
  from datetime import date
5
4
 
6
5
  import dacite
7
6
  import yaml
8
7
 
9
- from helm.proxy.models import ALL_MODELS, MODEL_NAME_TO_MODEL, Model
10
8
 
9
+ # Different modalities
10
+ TEXT_MODEL_TAG: str = "TEXT_MODEL_TAG"
11
+ IMAGE_MODEL_TAG: str = "IMAGE_MODEL_TAG"
12
+ CODE_MODEL_TAG: str = "CODE_MODEL_TAG"
13
+ EMBEDDING_MODEL_TAG: str = "EMBEDDING_MODEL_TAG"
11
14
 
12
- MODEL_METADATA_FILE = "model_metadata.yaml"
15
+ # Some model APIs have limited functionalities
16
+ FULL_FUNCTIONALITY_TEXT_MODEL_TAG: str = "FULL_FUNCTIONALITY_TEXT_MODEL_TAG"
17
+ LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG: str = "LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG"
13
18
 
19
+ # ChatML format
20
+ CHATML_MODEL_TAG: str = "CHATML_MODEL_TAG"
14
21
 
15
- @dataclass(frozen=True)
22
+ # OpenAI Chat format
23
+ OPENAI_CHATGPT_MODEL_TAG: str = "OPENAI_CHATGPT_MODEL_TAG"
24
+
25
+ # Mistral instruction-following format
26
+ MISTRAL_MODEL_TAG: str = "MISTRAL_MODEL_TAG"
27
+
28
+ # For Anthropic models
29
+ ANTHROPIC_CLAUDE_1_MODEL_TAG: str = "ANTHROPIC_CLAUDE_1_MODEL_TAG"
30
+ ANTHROPIC_CLAUDE_2_MODEL_TAG: str = "ANTHROPIC_CLAUDE_2_MODEL_TAG"
31
+ ANTHROPIC_CLAUDE_3_MODEL_TAG: str = "ANTHROPIC_CLAUDE_3_MODEL_TAG"
32
+
33
+ GOOGLE_PALM_2_MODEL_TAG: str = "GOOGLE_PALM_2_MODEL_TAG"
34
+ GOOGLE_GEMINI_MODEL_TAG: str = "GOOGLE_GEMINI_MODEL_TAG"
35
+ GOOGLE_GEMMA_INSTRUCT_MODEL_TAG: str = "GOOGLE_GEMMA_INSTRUCT_MODEL_TAG"
36
+
37
+ # Models which emit garbage tokens when temperature=0.
38
+ BUGGY_TEMP_0_TAG: str = "BUGGY_TEMP_0_TAG"
39
+
40
+ # Models that are used for ablations and fine-grained analyses.
41
+ # These models are selected specifically because of their low marginal cost to evaluate.
42
+ ABLATION_MODEL_TAG: str = "ABLATION_MODEL_TAG"
43
+
44
+ # Some models (e.g., T5) have stripped newlines.
45
+ # So we cannot use \n as a stop sequence for these models.
46
+ NO_NEWLINES_TAG: str = "NO_NEWLINES_TAG"
47
+
48
+ # Some models (e.g., UL2) require a prefix (e.g., [NLG]) in the
49
+ # prompts to indicate the mode before doing inference.
50
+ NLG_PREFIX_TAG: str = "NLG_PREFIX_TAG"
51
+
52
+ # Some models can follow instructions.
53
+ INSTRUCTION_FOLLOWING_MODEL_TAG: str = "INSTRUCTION_FOLLOWING_MODEL_TAG"
54
+
55
+ # For text-to-image models
56
+ TEXT_TO_IMAGE_MODEL_TAG: str = "TEXT_TO_IMAGE_MODEL_TAG"
57
+
58
+ # For Vision-langauge models (VLMs)
59
+ VISION_LANGUAGE_MODEL_TAG: str = "VISION_LANGUAGE_MODEL_TAG"
60
+ # IDEFICS require a special prompt format (see `IDEFICSInstructRunExpander`)
61
+ IDEFICS_INSTRUCT_MODEL_TAG: str = "IDEFICS_INSTRUCT_MODEL_TAG"
62
+ IDEFICS_MODEL_TAG: str = "IDEFICS_MODEL_TAG"
63
+ # Llava should use a special prompt format (see `LlavaRunExpander`)
64
+ LLAVA_MODEL_TAG: str = "LLAVA_MODEL_TAG"
65
+ # OpenFlamingo has a special prompt format (see `OpenFlamingoRunExpander`)
66
+ OPEN_FLAMINGO_MODEL_TAG: str = "OPEN_FLAMINGO_MODEL_TAG"
67
+ # Some VLMs do not support multiple images in the prompt
68
+ LIMITED_FUNCTIONALITY_VLM_TAG: str = "LIMITED_FUNCTIONALITY_VLM_TAG"
69
+ FULL_FUNCTIONALITY_VLM_TAG: str = "FULL_FUNCTIONALITY_VLM_TAG"
70
+
71
+
72
+ # Frozen is set to false as the model_deployment_registry.py file
73
+ # might populate the deployment_names field.
74
+
75
+
76
+ @dataclass(frozen=False)
16
77
  class ModelMetadata:
17
78
  name: str
18
- """Name of the model e.g. "meta/llama-2"."""
79
+ """Name of the model group (e.g., "openai/davinci"). This is the name of the model,
80
+ not the name of the deployment.
81
+ Usually formatted as "<creator_organization>/<engine_name>". Example: "ai21/j1-jumbo"."""
19
82
 
20
- creator_organization: Optional[str] = None
21
- """Organization that originally created the model (e.g. "meta")."""
83
+ creator_organization_name: str
84
+ """Name of the organization that created the model."""
22
85
 
23
- access: Optional[str] = None
24
- """How this model is available (e.g., limited).
86
+ display_name: str
87
+ """Name that is going to be displayed to the user (on the website, etc.)."""
25
88
 
26
- If there are multiple deployments, this should be the most permissive access across
27
- all deployments."""
89
+ description: str
90
+ """Description of the model, to be displayed on the website."""
28
91
 
29
- todo: bool = False
30
- """Whether we have yet to evaluate this model."""
92
+ access: str
93
+ """Description of the access level of the model. Should be one of the following:
94
+ - "open": the model is open-source and can be downloaded from the internet.
95
+ - "closed": not accessible
96
+ - "limited": accessible with an API key.
97
+ If there are multiple deployments, this should be the most permissive access across all deployments."""
31
98
 
32
- release_date: Optional[date] = None
33
- """When the model was released."""
99
+ release_date: Optional[date]
100
+ """Release date of the model."""
34
101
 
35
- num_parameters: Optional[int] = None
36
- """The number of model parameters.
102
+ tags: List[str] = field(default_factory=list)
103
+ """Tags corresponding to the properties of the model."""
37
104
 
105
+ num_parameters: Optional[int] = None
106
+ """Number of parameters in the model.
38
107
  This should be a string as the number of parameters is usually a round number (175B),
39
108
  but we set it as an int for plotting purposes."""
40
109
 
41
- tags: List[str] = field(default_factory=list)
42
- """"""
110
+ deployment_names: Optional[List[str]] = None
111
+ """List of the model deployments for this model. Should at least contain one model deployment.
112
+ Refers to the field "name" in the ModelDeployment class. Defaults to a single model deployment
113
+ with the same name as the model."""
114
+
115
+ @property
116
+ def creator_organization(self) -> str:
117
+ """
118
+ Extracts the creator organization from the model name.
119
+ Example: 'ai21/j1-jumbo' => 'ai21'
120
+ This can be different from the hosting organization.
121
+ """
122
+ return self.name.split("/")[0]
123
+
124
+ @property
125
+ def engine(self) -> str:
126
+ """
127
+ Extracts the model engine from the model name.
128
+ Example: 'ai21/j1-jumbo' => 'j1-jumbo'
129
+ """
130
+ return self.name.split("/")[1]
43
131
 
44
132
 
45
133
  @dataclass(frozen=True)
46
134
  class ModelMetadataList:
47
- models: List[ModelMetadata]
135
+ models: List[ModelMetadata] = field(default_factory=list)
136
+
137
+
138
+ ALL_MODELS_METADATA: List[ModelMetadata] = []
139
+ MODEL_NAME_TO_MODEL_METADATA: Dict[str, ModelMetadata] = {model.name: model for model in ALL_MODELS_METADATA}
48
140
 
49
141
 
50
142
  def register_model_metadata_from_path(path: str) -> None:
@@ -55,17 +147,71 @@ def register_model_metadata_from_path(path: str) -> None:
55
147
  # serialization format for dates
56
148
  model_metadata_list = dacite.from_dict(ModelMetadataList, raw)
57
149
  for model_metadata in model_metadata_list.models:
58
- model = Model(
59
- group="none", # TODO: Group should be part of model deployment, not model
60
- name=model_metadata.name,
61
- tags=model_metadata.tags,
62
- )
63
- MODEL_NAME_TO_MODEL[model_metadata.name] = model
64
- ALL_MODELS.append(model)
65
-
66
-
67
- def maybe_register_model_metadata_from_base_path(base_path: str) -> None:
68
- """Register model metadata from prod_env/model_metadata.yaml"""
69
- path = os.path.join(base_path, MODEL_METADATA_FILE)
70
- if os.path.exists(path):
71
- register_model_metadata_from_path(path)
150
+ register_model_metadata(model_metadata)
151
+
152
+
153
+ def register_model_metadata(model_metadata: ModelMetadata) -> None:
154
+ """Register a single model configuration."""
155
+ ALL_MODELS_METADATA.append(model_metadata)
156
+ MODEL_NAME_TO_MODEL_METADATA[model_metadata.name] = model_metadata
157
+
158
+
159
+ def get_model_metadata(model_name: str) -> ModelMetadata:
160
+ """Return the `ModelMetadata` for the model name."""
161
+ if model_name not in MODEL_NAME_TO_MODEL_METADATA:
162
+ raise ValueError(f"No model with name: {model_name}")
163
+
164
+ return MODEL_NAME_TO_MODEL_METADATA[model_name]
165
+
166
+
167
+ def get_all_models() -> List[str]:
168
+ """Return all model names."""
169
+ return list(MODEL_NAME_TO_MODEL_METADATA.keys())
170
+
171
+
172
+ def get_model_names_with_tag(tag: str) -> List[str]:
173
+ """Return all model names of models with the given tag."""
174
+ return [model.name for model in ALL_MODELS_METADATA if tag in model.tags]
175
+
176
+
177
+ def model_has_tag(model_name: str, tag: str) -> bool:
178
+ """Return True if the model has the given tag. False otherwise."""
179
+ return tag in get_model_metadata(model_name).tags
180
+
181
+
182
+ def get_all_text_models() -> List[str]:
183
+ """Return all model names of text models."""
184
+ return get_model_names_with_tag(TEXT_MODEL_TAG)
185
+
186
+
187
+ def get_all_code_models() -> List[str]:
188
+ """Return all model names of code models."""
189
+ return get_model_names_with_tag(CODE_MODEL_TAG)
190
+
191
+
192
+ def get_all_instruction_following_models() -> List[str]:
193
+ """Return all model names of instruction following models."""
194
+ return get_model_names_with_tag(INSTRUCTION_FOLLOWING_MODEL_TAG)
195
+
196
+
197
+ def is_text_to_image_model(model_name: str) -> bool:
198
+ """Returns True if the model is a text-to-image model. False otherwise."""
199
+ return model_has_tag(model_name, TEXT_TO_IMAGE_MODEL_TAG)
200
+
201
+
202
+ def is_vlm(model_name: str) -> bool:
203
+ """Returns True if the model is a vision-language model (VLM). False otherwise."""
204
+ return model_has_tag(model_name, VISION_LANGUAGE_MODEL_TAG)
205
+
206
+
207
+ def get_unknown_model_metadata(helm_model_name: str) -> ModelMetadata:
208
+ """Return placeholder ModelMetadata for an unknown model."""
209
+ return ModelMetadata(
210
+ name=helm_model_name,
211
+ creator_organization_name="Unknown",
212
+ display_name=helm_model_name,
213
+ description=helm_model_name,
214
+ access="open",
215
+ release_date=date.today(),
216
+ tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG],
217
+ )
@@ -0,0 +1,133 @@
1
+ import signal
2
+ import threading
3
+ import traceback
4
+ from typing import List
5
+ import os
6
+ import time
7
+ import torch
8
+ import torch.multiprocessing as multiprocessing
9
+ from concurrent.futures import ProcessPoolExecutor as Pool
10
+ from tqdm import tqdm
11
+
12
+ from helm.benchmark.config_registry import (
13
+ register_configs_from_directory,
14
+ register_builtin_configs_from_helm_package,
15
+ )
16
+ from helm.benchmark.executor import ExecutionSpec
17
+ from helm.benchmark.runner import Runner, RunSpec, RunnerError
18
+ from helm.common.hierarchical_logger import hlog, htrack_block
19
+ from helm.benchmark.runner_config_registry import RUNNER_CONFIG
20
+
21
+ _MAX_CONCURRENT_WORKERS_ENV_NAME = "HELM_MAX_CONCURRENT_WORKERS"
22
+
23
+
24
+ # From
25
+ # https://stackoverflow.com/questions/71300294/how-to-terminate-pythons-processpoolexecutor-when-parent-process-dies
26
+ def start_thread_to_terminate_when_parent_process_dies(ppid):
27
+ pid = os.getpid()
28
+
29
+ def f():
30
+ while True:
31
+ try:
32
+ os.kill(ppid, 0)
33
+ except OSError:
34
+ os.kill(pid, signal.SIGTERM)
35
+ time.sleep(1)
36
+
37
+ thread = threading.Thread(target=f, daemon=True)
38
+ thread.start()
39
+
40
+
41
+ def initialize_worker(gpu_id: int):
42
+ hlog(f"Worker {gpu_id} initializing")
43
+
44
+ # Wait for 0.1 seconds to ensure all workers are initialized with different CUDA_VISIBLE_DEVICES
45
+ time.sleep(0.1)
46
+
47
+ # Pin GPU to worker process
48
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
49
+
50
+ # Necessary for code_metrics in humaneval to work properly
51
+ multiprocessing.set_start_method("fork", force=True)
52
+
53
+
54
+ class MultiGPURunner(Runner):
55
+ """Runner that runs the entire benchmark on multiple GPUs.
56
+
57
+ This is a thin wrapper around `Runner` that runs the entire benchmark on
58
+ multiple GPUs using `multiprocessing`.
59
+
60
+ Note that this runner will load multiple models into memory at the same
61
+ time if your running configuration specifies that, similar to the `Runner`
62
+ class. `SlurmRunner` on the other hand will load at most one model on a
63
+ GPU"""
64
+
65
+ def __init__(
66
+ self,
67
+ execution_spec: ExecutionSpec,
68
+ output_path: str,
69
+ suite: str,
70
+ skip_instances: bool,
71
+ cache_instances: bool,
72
+ cache_instances_only: bool,
73
+ skip_completed_runs: bool,
74
+ exit_on_error: bool,
75
+ ):
76
+ super().__init__(
77
+ execution_spec=execution_spec,
78
+ output_path=output_path,
79
+ suite=suite,
80
+ skip_instances=skip_instances,
81
+ cache_instances=cache_instances,
82
+ cache_instances_only=cache_instances_only,
83
+ skip_completed_runs=skip_completed_runs,
84
+ exit_on_error=exit_on_error,
85
+ )
86
+ # Configure max concurrent worker jobs from the environment variable.
87
+ env_max_concurrent_workers = os.getenv(_MAX_CONCURRENT_WORKERS_ENV_NAME)
88
+ self.max_concurrent_workers = (
89
+ int(env_max_concurrent_workers)
90
+ if env_max_concurrent_workers
91
+ else (
92
+ RUNNER_CONFIG.helm_max_concurrent_workers
93
+ if RUNNER_CONFIG.helm_max_concurrent_workers > 0
94
+ else torch.cuda.device_count()
95
+ )
96
+ )
97
+
98
+ def safe_run_one(self, run_spec: RunSpec):
99
+ register_builtin_configs_from_helm_package()
100
+ if self.executor.execution_spec.local_path is not None:
101
+ register_configs_from_directory(self.executor.execution_spec.local_path)
102
+
103
+ try:
104
+ with htrack_block(f"Running {run_spec.name}"):
105
+ self.run_one(run_spec)
106
+ except Exception as e:
107
+ hlog(f"Error when running {run_spec.name}:\n{traceback.format_exc()}")
108
+ return e
109
+
110
+ def run_all(self, run_specs: List[RunSpec]):
111
+ """Run the entire benchmark on multiple GPU"""
112
+
113
+ # Set the start method to forkserver to avoid issues with CUDA.
114
+ multiprocessing.set_start_method("forkserver")
115
+
116
+ with Pool(
117
+ max_workers=self.max_concurrent_workers,
118
+ initializer=start_thread_to_terminate_when_parent_process_dies,
119
+ initargs=(os.getpid(),),
120
+ ) as pool:
121
+ # Pin GPUs to each worker process
122
+ pool.map(initialize_worker, [i for i in range(self.max_concurrent_workers)])
123
+
124
+ # Run all queued tasks
125
+ error_msgs = list(tqdm(pool.map(self.safe_run_one, run_specs), total=len(run_specs), disable=None))
126
+
127
+ # Raise exception for failed runs, if any.
128
+ failed_run_names = [
129
+ run_spec.name for error_msg, run_spec in zip(error_msgs, run_specs) if error_msg is not None
130
+ ]
131
+ if failed_run_names:
132
+ failed_runs_str = ", ".join([f'"{run_name}"' for run_name in failed_run_names])
133
+ raise RunnerError(f"Failed runs: [{failed_runs_str}]")
@@ -2,10 +2,10 @@ from dataclasses import dataclass
2
2
  from typing import List, Optional
3
3
  import dacite
4
4
  import importlib_resources as resources
5
- import yaml # type: ignore
5
+ import yaml
6
6
 
7
7
  from helm.common.hierarchical_logger import htrack, hlog
8
- from helm.proxy.models import MODEL_NAME_TO_MODEL
8
+ from helm.benchmark.model_metadata_registry import MODEL_NAME_TO_MODEL_METADATA
9
9
  from helm.benchmark.presentation.schema import Schema
10
10
 
11
11
 
@@ -70,7 +70,7 @@ def validate_contamination(contamination: Contamination, schema: Schema):
70
70
  """Make sure models and groups in contamination are defined according to `schema`."""
71
71
  for point in contamination.points:
72
72
  for model in point.models:
73
- if model not in MODEL_NAME_TO_MODEL:
73
+ if model not in MODEL_NAME_TO_MODEL_METADATA:
74
74
  hlog(f"WARNING: model {model} not defined in schema")
75
75
  for group in point.groups:
76
76
  if group not in schema.name_to_run_group: