crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import os
|
|
2
1
|
from typing import Dict, Optional, List
|
|
3
2
|
from dataclasses import dataclass
|
|
4
3
|
|
|
@@ -7,10 +6,12 @@ import yaml
|
|
|
7
6
|
|
|
8
7
|
from helm.common.hierarchical_logger import hlog
|
|
9
8
|
from helm.common.object_spec import ObjectSpec
|
|
10
|
-
from helm.
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
9
|
+
from helm.benchmark.model_metadata_registry import (
|
|
10
|
+
ModelMetadata,
|
|
11
|
+
get_model_metadata,
|
|
12
|
+
get_unknown_model_metadata,
|
|
13
|
+
register_model_metadata,
|
|
14
|
+
)
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class ClientSpec(ObjectSpec):
|
|
@@ -23,65 +24,99 @@ class WindowServiceSpec(ObjectSpec):
|
|
|
23
24
|
|
|
24
25
|
@dataclass(frozen=True)
|
|
25
26
|
class ModelDeployment:
|
|
26
|
-
"""
|
|
27
|
-
|
|
28
|
-
A model can have multiple model deployments.
|
|
27
|
+
"""
|
|
28
|
+
A model deployment is an accessible instance of this model (e.g., a hosted endpoint).
|
|
29
|
+
A model can have multiple model deployments.
|
|
30
|
+
"""
|
|
29
31
|
|
|
30
32
|
name: str
|
|
31
|
-
"""Name of the model deployment.""
|
|
33
|
+
"""Name of the model deployment. Usually formatted as "<hosting_group>/<engine_name>".
|
|
34
|
+
Example: "huggingface/t5-11b"."""
|
|
32
35
|
|
|
33
36
|
client_spec: ClientSpec
|
|
34
37
|
"""Specification for instantiating the client for this model deployment."""
|
|
35
38
|
|
|
36
39
|
model_name: Optional[str] = None
|
|
37
|
-
"""Name of the model that this model deployment is for.
|
|
38
|
-
|
|
39
|
-
If unset, defaults to the the same value as `name`."""
|
|
40
|
+
"""Name of the model that this model deployment is for. Refers to the field "name" in the Model class.
|
|
41
|
+
If unset, defaults to the same value as `name`."""
|
|
40
42
|
|
|
41
43
|
tokenizer_name: Optional[str] = None
|
|
42
|
-
"""Tokenizer for this model deployment.
|
|
43
|
-
|
|
44
|
-
If unset, auto-inferred by the WindowService."""
|
|
44
|
+
"""Tokenizer for this model deployment. If unset, auto-inferred by the WindowService."""
|
|
45
45
|
|
|
46
46
|
window_service_spec: Optional[WindowServiceSpec] = None
|
|
47
|
-
"""Specification for instantiating the window service for this model deployment"""
|
|
47
|
+
"""Specification for instantiating the window service for this model deployment."""
|
|
48
48
|
|
|
49
49
|
max_sequence_length: Optional[int] = None
|
|
50
50
|
"""Maximum sequence length for this model deployment."""
|
|
51
51
|
|
|
52
52
|
max_request_length: Optional[int] = None
|
|
53
53
|
"""Maximum request length for this model deployment.
|
|
54
|
-
|
|
55
54
|
If unset, defaults to the same value as max_sequence_length."""
|
|
56
55
|
|
|
56
|
+
max_sequence_and_generated_tokens_length: Optional[int] = None
|
|
57
|
+
"""The max length of the model input and output tokens.
|
|
58
|
+
Some models (like Anthropic/Claude and Megatron) have a specific limit sequence length + max_token.
|
|
59
|
+
If unset, defaults to INT_MAX (i.e., no limit)."""
|
|
60
|
+
|
|
61
|
+
deprecated: bool = False
|
|
62
|
+
"""Whether this model deployment is deprecated."""
|
|
63
|
+
|
|
64
|
+
@property
|
|
65
|
+
def host_organization(self) -> str:
|
|
66
|
+
"""
|
|
67
|
+
Extracts the host group from the model deployment name.
|
|
68
|
+
Example: "huggingface" from "huggingface/t5-11b"
|
|
69
|
+
This can be different from the creator organization (for example "together")
|
|
70
|
+
"""
|
|
71
|
+
return self.name.split("/")[0]
|
|
72
|
+
|
|
73
|
+
@property
|
|
74
|
+
def engine(self) -> str:
|
|
75
|
+
"""
|
|
76
|
+
Extracts the model engine from the model deployment name.
|
|
77
|
+
Example: 'ai21/j1-jumbo' => 'j1-jumbo'
|
|
78
|
+
"""
|
|
79
|
+
return self.name.split("/")[1]
|
|
80
|
+
|
|
81
|
+
def __post_init__(self):
|
|
82
|
+
if not self.model_name:
|
|
83
|
+
object.__setattr__(self, "model_name", self.name)
|
|
84
|
+
|
|
57
85
|
|
|
58
86
|
@dataclass(frozen=True)
|
|
59
87
|
class ModelDeployments:
|
|
60
88
|
model_deployments: List[ModelDeployment]
|
|
61
89
|
|
|
62
90
|
|
|
63
|
-
|
|
91
|
+
ALL_MODEL_DEPLOYMENTS: List[ModelDeployment] = []
|
|
92
|
+
DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT: Dict[str, ModelDeployment] = {
|
|
93
|
+
deployment.name: deployment for deployment in ALL_MODEL_DEPLOYMENTS
|
|
94
|
+
}
|
|
64
95
|
|
|
65
96
|
|
|
66
97
|
def register_model_deployment(model_deployment: ModelDeployment) -> None:
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
98
|
+
DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[model_deployment.name] = model_deployment
|
|
99
|
+
ALL_MODEL_DEPLOYMENTS.append(model_deployment)
|
|
100
|
+
|
|
101
|
+
model_name: str = model_deployment.model_name or model_deployment.name
|
|
102
|
+
|
|
103
|
+
model_metadata: ModelMetadata
|
|
104
|
+
try:
|
|
105
|
+
model_metadata = get_model_metadata(model_name)
|
|
106
|
+
except ValueError:
|
|
107
|
+
hlog(
|
|
108
|
+
f"WARNING: Could not find model metadata for model {model_name} of model deployment {model_deployment.name}"
|
|
77
109
|
)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
110
|
+
model_metadata = get_unknown_model_metadata(model_name)
|
|
111
|
+
register_model_metadata(model_metadata)
|
|
112
|
+
deployment_names: List[str] = model_metadata.deployment_names or [model_metadata.name]
|
|
113
|
+
if model_deployment.name not in deployment_names:
|
|
114
|
+
if model_metadata.deployment_names is None:
|
|
115
|
+
model_metadata.deployment_names = []
|
|
116
|
+
model_metadata.deployment_names.append(model_deployment.name)
|
|
81
117
|
|
|
82
118
|
|
|
83
119
|
def register_model_deployments_from_path(path: str) -> None:
|
|
84
|
-
global _name_to_model_deployment
|
|
85
120
|
hlog(f"Reading model deployments from {path}...")
|
|
86
121
|
with open(path, "r") as f:
|
|
87
122
|
raw = yaml.safe_load(f)
|
|
@@ -90,12 +125,100 @@ def register_model_deployments_from_path(path: str) -> None:
|
|
|
90
125
|
register_model_deployment(model_deployment)
|
|
91
126
|
|
|
92
127
|
|
|
93
|
-
def
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
128
|
+
def get_model_deployment(name: str, warn_deprecated: bool = False) -> ModelDeployment:
|
|
129
|
+
if name not in DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT:
|
|
130
|
+
raise ValueError(f"Model deployment {name} not found")
|
|
131
|
+
deployment: ModelDeployment = DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[name]
|
|
132
|
+
if deployment.deprecated and warn_deprecated:
|
|
133
|
+
hlog(f"WARNING: DEPLOYMENT Model deployment {name} is deprecated")
|
|
134
|
+
return deployment
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_model_deployment_host_organization(name: str) -> str:
|
|
138
|
+
"""Return the host organization name based on the model deployment name.
|
|
139
|
+
|
|
140
|
+
Example: "huggingface/t5-11b" -> "huggingface"""
|
|
141
|
+
deployment: ModelDeployment = get_model_deployment(name)
|
|
142
|
+
return deployment.host_organization
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_model_names_with_tokenizer(tokenizer_name: str) -> List[str]:
|
|
146
|
+
"""Return the names of all models with the given tokenizer."""
|
|
147
|
+
deployments: List[ModelDeployment] = [
|
|
148
|
+
deployment for deployment in ALL_MODEL_DEPLOYMENTS if deployment.tokenizer_name == tokenizer_name
|
|
149
|
+
]
|
|
150
|
+
return [deployment.model_name or deployment.name for deployment in deployments]
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def get_default_model_deployment_for_model(
|
|
154
|
+
model_name: str, warn_arg_deprecated: bool = False, ignore_deprecated: bool = False
|
|
155
|
+
) -> Optional[str]:
|
|
156
|
+
"""Returns a valid model deployment name corresponding to the given model arg.
|
|
157
|
+
This is used as a backwards compatibility layer for model names that are now moved to model deployments.
|
|
158
|
+
Example: "anthropic/claude-v1.3" => "anthropic/claude-v1.3"
|
|
159
|
+
Example: "meta/llama-7b" => "together/llama-7b"
|
|
160
|
+
|
|
161
|
+
The process to find a model deployment name is as follows:
|
|
162
|
+
1. If there is a model deployment with the same name as the model arg, use it.
|
|
163
|
+
2. If there is at least one deployment for the model, use the first one that is available.
|
|
164
|
+
3. If there are no deployments for the model, returns None.
|
|
165
|
+
|
|
166
|
+
This function will also try to find a model deployment name that is not deprecated.
|
|
167
|
+
If there are no non-deprecated deployments, it will return the first deployment (even if it's deprecated).
|
|
168
|
+
If ignore_deprecated is True, this function will return None if the model deployment is deprecated.
|
|
169
|
+
|
|
170
|
+
If warn_arg_deprecated is True, this function will print a warning if the model deployment name is not the same
|
|
171
|
+
as the model arg. This is to remind the user that the model name is deprecated and should be replaced with
|
|
172
|
+
the model deployment name (in their config).
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
model_arg: The model arg to convert to a model deployment name.
|
|
176
|
+
warn_arg_deprecated: Whether to print a warning if the model deployment name is not the same as the model arg.
|
|
177
|
+
ignore_deprecated: Whether to return None if the model deployment is deprecated.
|
|
178
|
+
"""
|
|
179
|
+
|
|
180
|
+
# If there is a model deployment with the same name as the model arg, use it.
|
|
181
|
+
if model_name in DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT:
|
|
182
|
+
deployment: ModelDeployment = DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[model_name]
|
|
183
|
+
if deployment.deprecated and ignore_deprecated:
|
|
184
|
+
if warn_arg_deprecated:
|
|
185
|
+
hlog(f"WARNING: Model deployment {model_name} is deprecated")
|
|
186
|
+
return None
|
|
187
|
+
return deployment.name
|
|
188
|
+
|
|
189
|
+
# If there is at least one deployment for the model, use the first one that is available.
|
|
190
|
+
available_deployments: List[ModelDeployment] = [
|
|
191
|
+
deployment for deployment in ALL_MODEL_DEPLOYMENTS if deployment.model_name == model_name
|
|
192
|
+
]
|
|
193
|
+
if len(available_deployments) > 0:
|
|
194
|
+
available_deployment_names: List[str] = [deployment.name for deployment in available_deployments]
|
|
195
|
+
if warn_arg_deprecated:
|
|
196
|
+
hlog("WARNING: Model name is deprecated. Please use the model deployment name instead.")
|
|
197
|
+
hlog(f"Available model deployments for model {model_name}: {available_deployment_names}")
|
|
198
|
+
|
|
199
|
+
# Additionally, if there is a non-deprecated deployment, use it.
|
|
200
|
+
non_deprecated_deployments: List[ModelDeployment] = [
|
|
201
|
+
deployment for deployment in available_deployments if not deployment.deprecated
|
|
202
|
+
]
|
|
203
|
+
if len(non_deprecated_deployments) > 0:
|
|
204
|
+
chosen_deployment = non_deprecated_deployments[0]
|
|
205
|
+
# There are no non-deprecated deployments, so there are two options:
|
|
206
|
+
# 1. If we can return an empty string, return it. (no model deployment is available)
|
|
207
|
+
# 2. If we can't return an empty string, return the first deployment (even if it's deprecated).
|
|
208
|
+
elif ignore_deprecated:
|
|
209
|
+
return None
|
|
210
|
+
else:
|
|
211
|
+
chosen_deployment = available_deployments[0]
|
|
212
|
+
if warn_arg_deprecated:
|
|
213
|
+
hlog(f"WARNING: All model deployments for model {model_name} are deprecated.")
|
|
214
|
+
if warn_arg_deprecated:
|
|
215
|
+
hlog(
|
|
216
|
+
f"Choosing {chosen_deployment.name} (the first one) as "
|
|
217
|
+
f"the default model deployment for model {model_name}"
|
|
218
|
+
)
|
|
219
|
+
hlog("If you want to use a different model deployment, please specify it explicitly.")
|
|
220
|
+
return chosen_deployment.name
|
|
221
|
+
|
|
222
|
+
# Some models are added but have no deployments yet.
|
|
223
|
+
# In this case, we return None.
|
|
224
|
+
return None
|
|
@@ -1,50 +1,142 @@
|
|
|
1
|
-
import
|
|
2
|
-
from typing import Optional, List
|
|
1
|
+
from typing import Dict, Optional, List
|
|
3
2
|
from dataclasses import dataclass, field
|
|
4
3
|
from datetime import date
|
|
5
4
|
|
|
6
5
|
import dacite
|
|
7
6
|
import yaml
|
|
8
7
|
|
|
9
|
-
from helm.proxy.models import ALL_MODELS, MODEL_NAME_TO_MODEL, Model
|
|
10
8
|
|
|
9
|
+
# Different modalities
|
|
10
|
+
TEXT_MODEL_TAG: str = "TEXT_MODEL_TAG"
|
|
11
|
+
IMAGE_MODEL_TAG: str = "IMAGE_MODEL_TAG"
|
|
12
|
+
CODE_MODEL_TAG: str = "CODE_MODEL_TAG"
|
|
13
|
+
EMBEDDING_MODEL_TAG: str = "EMBEDDING_MODEL_TAG"
|
|
11
14
|
|
|
12
|
-
|
|
15
|
+
# Some model APIs have limited functionalities
|
|
16
|
+
FULL_FUNCTIONALITY_TEXT_MODEL_TAG: str = "FULL_FUNCTIONALITY_TEXT_MODEL_TAG"
|
|
17
|
+
LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG: str = "LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG"
|
|
13
18
|
|
|
19
|
+
# ChatML format
|
|
20
|
+
CHATML_MODEL_TAG: str = "CHATML_MODEL_TAG"
|
|
14
21
|
|
|
15
|
-
|
|
22
|
+
# OpenAI Chat format
|
|
23
|
+
OPENAI_CHATGPT_MODEL_TAG: str = "OPENAI_CHATGPT_MODEL_TAG"
|
|
24
|
+
|
|
25
|
+
# Mistral instruction-following format
|
|
26
|
+
MISTRAL_MODEL_TAG: str = "MISTRAL_MODEL_TAG"
|
|
27
|
+
|
|
28
|
+
# For Anthropic models
|
|
29
|
+
ANTHROPIC_CLAUDE_1_MODEL_TAG: str = "ANTHROPIC_CLAUDE_1_MODEL_TAG"
|
|
30
|
+
ANTHROPIC_CLAUDE_2_MODEL_TAG: str = "ANTHROPIC_CLAUDE_2_MODEL_TAG"
|
|
31
|
+
ANTHROPIC_CLAUDE_3_MODEL_TAG: str = "ANTHROPIC_CLAUDE_3_MODEL_TAG"
|
|
32
|
+
|
|
33
|
+
GOOGLE_PALM_2_MODEL_TAG: str = "GOOGLE_PALM_2_MODEL_TAG"
|
|
34
|
+
GOOGLE_GEMINI_MODEL_TAG: str = "GOOGLE_GEMINI_MODEL_TAG"
|
|
35
|
+
GOOGLE_GEMMA_INSTRUCT_MODEL_TAG: str = "GOOGLE_GEMMA_INSTRUCT_MODEL_TAG"
|
|
36
|
+
|
|
37
|
+
# Models which emit garbage tokens when temperature=0.
|
|
38
|
+
BUGGY_TEMP_0_TAG: str = "BUGGY_TEMP_0_TAG"
|
|
39
|
+
|
|
40
|
+
# Models that are used for ablations and fine-grained analyses.
|
|
41
|
+
# These models are selected specifically because of their low marginal cost to evaluate.
|
|
42
|
+
ABLATION_MODEL_TAG: str = "ABLATION_MODEL_TAG"
|
|
43
|
+
|
|
44
|
+
# Some models (e.g., T5) have stripped newlines.
|
|
45
|
+
# So we cannot use \n as a stop sequence for these models.
|
|
46
|
+
NO_NEWLINES_TAG: str = "NO_NEWLINES_TAG"
|
|
47
|
+
|
|
48
|
+
# Some models (e.g., UL2) require a prefix (e.g., [NLG]) in the
|
|
49
|
+
# prompts to indicate the mode before doing inference.
|
|
50
|
+
NLG_PREFIX_TAG: str = "NLG_PREFIX_TAG"
|
|
51
|
+
|
|
52
|
+
# Some models can follow instructions.
|
|
53
|
+
INSTRUCTION_FOLLOWING_MODEL_TAG: str = "INSTRUCTION_FOLLOWING_MODEL_TAG"
|
|
54
|
+
|
|
55
|
+
# For text-to-image models
|
|
56
|
+
TEXT_TO_IMAGE_MODEL_TAG: str = "TEXT_TO_IMAGE_MODEL_TAG"
|
|
57
|
+
|
|
58
|
+
# For Vision-langauge models (VLMs)
|
|
59
|
+
VISION_LANGUAGE_MODEL_TAG: str = "VISION_LANGUAGE_MODEL_TAG"
|
|
60
|
+
# IDEFICS require a special prompt format (see `IDEFICSInstructRunExpander`)
|
|
61
|
+
IDEFICS_INSTRUCT_MODEL_TAG: str = "IDEFICS_INSTRUCT_MODEL_TAG"
|
|
62
|
+
IDEFICS_MODEL_TAG: str = "IDEFICS_MODEL_TAG"
|
|
63
|
+
# Llava should use a special prompt format (see `LlavaRunExpander`)
|
|
64
|
+
LLAVA_MODEL_TAG: str = "LLAVA_MODEL_TAG"
|
|
65
|
+
# OpenFlamingo has a special prompt format (see `OpenFlamingoRunExpander`)
|
|
66
|
+
OPEN_FLAMINGO_MODEL_TAG: str = "OPEN_FLAMINGO_MODEL_TAG"
|
|
67
|
+
# Some VLMs do not support multiple images in the prompt
|
|
68
|
+
LIMITED_FUNCTIONALITY_VLM_TAG: str = "LIMITED_FUNCTIONALITY_VLM_TAG"
|
|
69
|
+
FULL_FUNCTIONALITY_VLM_TAG: str = "FULL_FUNCTIONALITY_VLM_TAG"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Frozen is set to false as the model_deployment_registry.py file
|
|
73
|
+
# might populate the deployment_names field.
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(frozen=False)
|
|
16
77
|
class ModelMetadata:
|
|
17
78
|
name: str
|
|
18
|
-
"""Name of the model e.g
|
|
79
|
+
"""Name of the model group (e.g., "openai/davinci"). This is the name of the model,
|
|
80
|
+
not the name of the deployment.
|
|
81
|
+
Usually formatted as "<creator_organization>/<engine_name>". Example: "ai21/j1-jumbo"."""
|
|
19
82
|
|
|
20
|
-
|
|
21
|
-
"""
|
|
83
|
+
creator_organization_name: str
|
|
84
|
+
"""Name of the organization that created the model."""
|
|
22
85
|
|
|
23
|
-
|
|
24
|
-
"""
|
|
86
|
+
display_name: str
|
|
87
|
+
"""Name that is going to be displayed to the user (on the website, etc.)."""
|
|
25
88
|
|
|
26
|
-
|
|
27
|
-
|
|
89
|
+
description: str
|
|
90
|
+
"""Description of the model, to be displayed on the website."""
|
|
28
91
|
|
|
29
|
-
|
|
30
|
-
"""
|
|
92
|
+
access: str
|
|
93
|
+
"""Description of the access level of the model. Should be one of the following:
|
|
94
|
+
- "open": the model is open-source and can be downloaded from the internet.
|
|
95
|
+
- "closed": not accessible
|
|
96
|
+
- "limited": accessible with an API key.
|
|
97
|
+
If there are multiple deployments, this should be the most permissive access across all deployments."""
|
|
31
98
|
|
|
32
|
-
release_date: Optional[date]
|
|
33
|
-
"""
|
|
99
|
+
release_date: Optional[date]
|
|
100
|
+
"""Release date of the model."""
|
|
34
101
|
|
|
35
|
-
|
|
36
|
-
"""
|
|
102
|
+
tags: List[str] = field(default_factory=list)
|
|
103
|
+
"""Tags corresponding to the properties of the model."""
|
|
37
104
|
|
|
105
|
+
num_parameters: Optional[int] = None
|
|
106
|
+
"""Number of parameters in the model.
|
|
38
107
|
This should be a string as the number of parameters is usually a round number (175B),
|
|
39
108
|
but we set it as an int for plotting purposes."""
|
|
40
109
|
|
|
41
|
-
|
|
42
|
-
"""
|
|
110
|
+
deployment_names: Optional[List[str]] = None
|
|
111
|
+
"""List of the model deployments for this model. Should at least contain one model deployment.
|
|
112
|
+
Refers to the field "name" in the ModelDeployment class. Defaults to a single model deployment
|
|
113
|
+
with the same name as the model."""
|
|
114
|
+
|
|
115
|
+
@property
|
|
116
|
+
def creator_organization(self) -> str:
|
|
117
|
+
"""
|
|
118
|
+
Extracts the creator organization from the model name.
|
|
119
|
+
Example: 'ai21/j1-jumbo' => 'ai21'
|
|
120
|
+
This can be different from the hosting organization.
|
|
121
|
+
"""
|
|
122
|
+
return self.name.split("/")[0]
|
|
123
|
+
|
|
124
|
+
@property
|
|
125
|
+
def engine(self) -> str:
|
|
126
|
+
"""
|
|
127
|
+
Extracts the model engine from the model name.
|
|
128
|
+
Example: 'ai21/j1-jumbo' => 'j1-jumbo'
|
|
129
|
+
"""
|
|
130
|
+
return self.name.split("/")[1]
|
|
43
131
|
|
|
44
132
|
|
|
45
133
|
@dataclass(frozen=True)
|
|
46
134
|
class ModelMetadataList:
|
|
47
|
-
models: List[ModelMetadata]
|
|
135
|
+
models: List[ModelMetadata] = field(default_factory=list)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
ALL_MODELS_METADATA: List[ModelMetadata] = []
|
|
139
|
+
MODEL_NAME_TO_MODEL_METADATA: Dict[str, ModelMetadata] = {model.name: model for model in ALL_MODELS_METADATA}
|
|
48
140
|
|
|
49
141
|
|
|
50
142
|
def register_model_metadata_from_path(path: str) -> None:
|
|
@@ -55,17 +147,71 @@ def register_model_metadata_from_path(path: str) -> None:
|
|
|
55
147
|
# serialization format for dates
|
|
56
148
|
model_metadata_list = dacite.from_dict(ModelMetadataList, raw)
|
|
57
149
|
for model_metadata in model_metadata_list.models:
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
def
|
|
68
|
-
"""
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
150
|
+
register_model_metadata(model_metadata)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def register_model_metadata(model_metadata: ModelMetadata) -> None:
|
|
154
|
+
"""Register a single model configuration."""
|
|
155
|
+
ALL_MODELS_METADATA.append(model_metadata)
|
|
156
|
+
MODEL_NAME_TO_MODEL_METADATA[model_metadata.name] = model_metadata
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def get_model_metadata(model_name: str) -> ModelMetadata:
|
|
160
|
+
"""Return the `ModelMetadata` for the model name."""
|
|
161
|
+
if model_name not in MODEL_NAME_TO_MODEL_METADATA:
|
|
162
|
+
raise ValueError(f"No model with name: {model_name}")
|
|
163
|
+
|
|
164
|
+
return MODEL_NAME_TO_MODEL_METADATA[model_name]
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def get_all_models() -> List[str]:
|
|
168
|
+
"""Return all model names."""
|
|
169
|
+
return list(MODEL_NAME_TO_MODEL_METADATA.keys())
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def get_model_names_with_tag(tag: str) -> List[str]:
|
|
173
|
+
"""Return all model names of models with the given tag."""
|
|
174
|
+
return [model.name for model in ALL_MODELS_METADATA if tag in model.tags]
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def model_has_tag(model_name: str, tag: str) -> bool:
|
|
178
|
+
"""Return True if the model has the given tag. False otherwise."""
|
|
179
|
+
return tag in get_model_metadata(model_name).tags
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def get_all_text_models() -> List[str]:
|
|
183
|
+
"""Return all model names of text models."""
|
|
184
|
+
return get_model_names_with_tag(TEXT_MODEL_TAG)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def get_all_code_models() -> List[str]:
|
|
188
|
+
"""Return all model names of code models."""
|
|
189
|
+
return get_model_names_with_tag(CODE_MODEL_TAG)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def get_all_instruction_following_models() -> List[str]:
|
|
193
|
+
"""Return all model names of instruction following models."""
|
|
194
|
+
return get_model_names_with_tag(INSTRUCTION_FOLLOWING_MODEL_TAG)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def is_text_to_image_model(model_name: str) -> bool:
|
|
198
|
+
"""Returns True if the model is a text-to-image model. False otherwise."""
|
|
199
|
+
return model_has_tag(model_name, TEXT_TO_IMAGE_MODEL_TAG)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def is_vlm(model_name: str) -> bool:
|
|
203
|
+
"""Returns True if the model is a vision-language model (VLM). False otherwise."""
|
|
204
|
+
return model_has_tag(model_name, VISION_LANGUAGE_MODEL_TAG)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def get_unknown_model_metadata(helm_model_name: str) -> ModelMetadata:
|
|
208
|
+
"""Return placeholder ModelMetadata for an unknown model."""
|
|
209
|
+
return ModelMetadata(
|
|
210
|
+
name=helm_model_name,
|
|
211
|
+
creator_organization_name="Unknown",
|
|
212
|
+
display_name=helm_model_name,
|
|
213
|
+
description=helm_model_name,
|
|
214
|
+
access="open",
|
|
215
|
+
release_date=date.today(),
|
|
216
|
+
tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG],
|
|
217
|
+
)
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
import signal
|
|
2
|
+
import threading
|
|
3
|
+
import traceback
|
|
4
|
+
from typing import List
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
import torch
|
|
8
|
+
import torch.multiprocessing as multiprocessing
|
|
9
|
+
from concurrent.futures import ProcessPoolExecutor as Pool
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
from helm.benchmark.config_registry import (
|
|
13
|
+
register_configs_from_directory,
|
|
14
|
+
register_builtin_configs_from_helm_package,
|
|
15
|
+
)
|
|
16
|
+
from helm.benchmark.executor import ExecutionSpec
|
|
17
|
+
from helm.benchmark.runner import Runner, RunSpec, RunnerError
|
|
18
|
+
from helm.common.hierarchical_logger import hlog, htrack_block
|
|
19
|
+
from helm.benchmark.runner_config_registry import RUNNER_CONFIG
|
|
20
|
+
|
|
21
|
+
_MAX_CONCURRENT_WORKERS_ENV_NAME = "HELM_MAX_CONCURRENT_WORKERS"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# From
|
|
25
|
+
# https://stackoverflow.com/questions/71300294/how-to-terminate-pythons-processpoolexecutor-when-parent-process-dies
|
|
26
|
+
def start_thread_to_terminate_when_parent_process_dies(ppid):
|
|
27
|
+
pid = os.getpid()
|
|
28
|
+
|
|
29
|
+
def f():
|
|
30
|
+
while True:
|
|
31
|
+
try:
|
|
32
|
+
os.kill(ppid, 0)
|
|
33
|
+
except OSError:
|
|
34
|
+
os.kill(pid, signal.SIGTERM)
|
|
35
|
+
time.sleep(1)
|
|
36
|
+
|
|
37
|
+
thread = threading.Thread(target=f, daemon=True)
|
|
38
|
+
thread.start()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def initialize_worker(gpu_id: int):
|
|
42
|
+
hlog(f"Worker {gpu_id} initializing")
|
|
43
|
+
|
|
44
|
+
# Wait for 0.1 seconds to ensure all workers are initialized with different CUDA_VISIBLE_DEVICES
|
|
45
|
+
time.sleep(0.1)
|
|
46
|
+
|
|
47
|
+
# Pin GPU to worker process
|
|
48
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
|
49
|
+
|
|
50
|
+
# Necessary for code_metrics in humaneval to work properly
|
|
51
|
+
multiprocessing.set_start_method("fork", force=True)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class MultiGPURunner(Runner):
|
|
55
|
+
"""Runner that runs the entire benchmark on multiple GPUs.
|
|
56
|
+
|
|
57
|
+
This is a thin wrapper around `Runner` that runs the entire benchmark on
|
|
58
|
+
multiple GPUs using `multiprocessing`.
|
|
59
|
+
|
|
60
|
+
Note that this runner will load multiple models into memory at the same
|
|
61
|
+
time if your running configuration specifies that, similar to the `Runner`
|
|
62
|
+
class. `SlurmRunner` on the other hand will load at most one model on a
|
|
63
|
+
GPU"""
|
|
64
|
+
|
|
65
|
+
def __init__(
|
|
66
|
+
self,
|
|
67
|
+
execution_spec: ExecutionSpec,
|
|
68
|
+
output_path: str,
|
|
69
|
+
suite: str,
|
|
70
|
+
skip_instances: bool,
|
|
71
|
+
cache_instances: bool,
|
|
72
|
+
cache_instances_only: bool,
|
|
73
|
+
skip_completed_runs: bool,
|
|
74
|
+
exit_on_error: bool,
|
|
75
|
+
):
|
|
76
|
+
super().__init__(
|
|
77
|
+
execution_spec=execution_spec,
|
|
78
|
+
output_path=output_path,
|
|
79
|
+
suite=suite,
|
|
80
|
+
skip_instances=skip_instances,
|
|
81
|
+
cache_instances=cache_instances,
|
|
82
|
+
cache_instances_only=cache_instances_only,
|
|
83
|
+
skip_completed_runs=skip_completed_runs,
|
|
84
|
+
exit_on_error=exit_on_error,
|
|
85
|
+
)
|
|
86
|
+
# Configure max concurrent worker jobs from the environment variable.
|
|
87
|
+
env_max_concurrent_workers = os.getenv(_MAX_CONCURRENT_WORKERS_ENV_NAME)
|
|
88
|
+
self.max_concurrent_workers = (
|
|
89
|
+
int(env_max_concurrent_workers)
|
|
90
|
+
if env_max_concurrent_workers
|
|
91
|
+
else (
|
|
92
|
+
RUNNER_CONFIG.helm_max_concurrent_workers
|
|
93
|
+
if RUNNER_CONFIG.helm_max_concurrent_workers > 0
|
|
94
|
+
else torch.cuda.device_count()
|
|
95
|
+
)
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
def safe_run_one(self, run_spec: RunSpec):
|
|
99
|
+
register_builtin_configs_from_helm_package()
|
|
100
|
+
if self.executor.execution_spec.local_path is not None:
|
|
101
|
+
register_configs_from_directory(self.executor.execution_spec.local_path)
|
|
102
|
+
|
|
103
|
+
try:
|
|
104
|
+
with htrack_block(f"Running {run_spec.name}"):
|
|
105
|
+
self.run_one(run_spec)
|
|
106
|
+
except Exception as e:
|
|
107
|
+
hlog(f"Error when running {run_spec.name}:\n{traceback.format_exc()}")
|
|
108
|
+
return e
|
|
109
|
+
|
|
110
|
+
def run_all(self, run_specs: List[RunSpec]):
|
|
111
|
+
"""Run the entire benchmark on multiple GPU"""
|
|
112
|
+
|
|
113
|
+
# Set the start method to forkserver to avoid issues with CUDA.
|
|
114
|
+
multiprocessing.set_start_method("forkserver")
|
|
115
|
+
|
|
116
|
+
with Pool(
|
|
117
|
+
max_workers=self.max_concurrent_workers,
|
|
118
|
+
initializer=start_thread_to_terminate_when_parent_process_dies,
|
|
119
|
+
initargs=(os.getpid(),),
|
|
120
|
+
) as pool:
|
|
121
|
+
# Pin GPUs to each worker process
|
|
122
|
+
pool.map(initialize_worker, [i for i in range(self.max_concurrent_workers)])
|
|
123
|
+
|
|
124
|
+
# Run all queued tasks
|
|
125
|
+
error_msgs = list(tqdm(pool.map(self.safe_run_one, run_specs), total=len(run_specs), disable=None))
|
|
126
|
+
|
|
127
|
+
# Raise exception for failed runs, if any.
|
|
128
|
+
failed_run_names = [
|
|
129
|
+
run_spec.name for error_msg, run_spec in zip(error_msgs, run_specs) if error_msg is not None
|
|
130
|
+
]
|
|
131
|
+
if failed_run_names:
|
|
132
|
+
failed_runs_str = ", ".join([f'"{run_name}"' for run_name in failed_run_names])
|
|
133
|
+
raise RunnerError(f"Failed runs: [{failed_runs_str}]")
|
|
@@ -2,10 +2,10 @@ from dataclasses import dataclass
|
|
|
2
2
|
from typing import List, Optional
|
|
3
3
|
import dacite
|
|
4
4
|
import importlib_resources as resources
|
|
5
|
-
import yaml
|
|
5
|
+
import yaml
|
|
6
6
|
|
|
7
7
|
from helm.common.hierarchical_logger import htrack, hlog
|
|
8
|
-
from helm.
|
|
8
|
+
from helm.benchmark.model_metadata_registry import MODEL_NAME_TO_MODEL_METADATA
|
|
9
9
|
from helm.benchmark.presentation.schema import Schema
|
|
10
10
|
|
|
11
11
|
|
|
@@ -70,7 +70,7 @@ def validate_contamination(contamination: Contamination, schema: Schema):
|
|
|
70
70
|
"""Make sure models and groups in contamination are defined according to `schema`."""
|
|
71
71
|
for point in contamination.points:
|
|
72
72
|
for model in point.models:
|
|
73
|
-
if model not in
|
|
73
|
+
if model not in MODEL_NAME_TO_MODEL_METADATA:
|
|
74
74
|
hlog(f"WARNING: model {model} not defined in schema")
|
|
75
75
|
for group in point.groups:
|
|
76
76
|
if group not in schema.name_to_run_group:
|