crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +8 -2
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +36 -0
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +214 -16
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +14 -16
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +203 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +12 -72
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1069 -546
- helm/config/model_metadata.yaml +753 -31
- helm/config/tokenizer_configs.yaml +142 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/benchmark/metrics/evaluate_reference_metrics.py (new file)
@@ -0,0 +1,376 @@
+from dataclasses import replace
+from typing import Callable, Dict, List, Optional, Set, Tuple, cast
+import numpy as np
+from functools import partial
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.metrics.cleva_metrics_helper import ChineseTokenizer
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+from helm.benchmark.scenarios.code_scenario import CodeReference
+from helm.benchmark.scenarios.scenario import Reference
+from helm.common.request import GeneratedOutput
+from helm.benchmark.scenarios.math_scenario import is_equiv, is_equiv_chain_of_thought
+from nltk.metrics.scores import f_measure
+from nltk.translate.bleu_score import sentence_bleu
+from nltk.tokenize import word_tokenize
+from rouge_score import rouge_scorer
+import re
+import string
+from . import code_metrics_helper
+import nltk
+
+try:
+    nltk.data.find("tokenizers/punkt")
+except LookupError:
+    nltk.download("punkt")  # Required for rouge
+
+
+def pass_at_k_estimator(n: int, c: int, k: int) -> float:
+    """Calculates 1 - comb(n - c, k) / comb(n, k).
+
+    Numerically stable version defined in
+    https://arxiv.org/pdf/2107.03374.pdf
+    """
+    if n - c < k:
+        return 1.0
+    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
+
+
+def normalize_text(text: str) -> str:
+    """Lower text and remove punctuation, articles and extra whitespace.
+    Copied from the [QuAC](http://quac.ai/) evaluation script found at
+    https://s3.amazonaws.com/my89public/quac/scorer.py"""
+
+    def remove_articles(text: str) -> str:
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text: str) -> str:
+        return " ".join(text.split())
+
+    def remove_punc(text: str) -> str:
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text: str) -> str:
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(text))))
+
+
+def exact_match(gold: str, pred: str) -> float:
+    if not pred:
+        return 0
+
+    return 1 if gold.strip() == pred.strip() else 0
+
+
+def quasi_exact_match(gold: str, pred: str) -> float:
+    if not pred:
+        return 0
+
+    return 1 if normalize_text(gold) == normalize_text(pred) else 0
+
+
+def prefix_exact_match(gold: str, pred: str) -> float:
+    """
+    The `prefix_exact_match` metric is particularly useful in the zero-shot setting, where the model is
+    not given examples of the expected outputs and tends to output more tokens than it should.
+
+    For example, for this zero-shot prompt from BoolQ,
+
+    Passage: Elmendorf Air Force Base (IATA: EDF, ICAO: PAED, FAA LID: EDF) is a United States military facility
+    in Anchorage, the largest city in Alaska. Originally known as Elmendorf Field, it became Elmendorf Air Force
+    Base after World War II, and in 2010 it merged with nearby Fort Richardson to form Joint Base Elmendorf-Richardson.
+    Question: Is there an air force base in anchorage alaska?
+    Answer:
+
+    the model could output up to `max_tokens` number of tokens "Yes, Elmendorf" instead of just "Yes".
+    """
+    if not pred:
+        return 0
+
+    return 1 if pred.strip().startswith(gold.strip()) else 0
+
+
+def quasi_prefix_exact_match(gold: str, pred: str) -> float:
+    """
+    Same thing as `prefix_exact_match` but we normalize the text before checking if the prefix match.
+    """
+    if not pred:
+        return 0
+
+    return 1 if normalize_text(pred).startswith(normalize_text(gold)) else 0
+
+
+def f1_score(gold: str, pred: str) -> float:
+    ret = f_measure(set(normalize_text(gold).split()), set(normalize_text(pred).split()))
+    if ret is None:  # answer is the empty string after normalizing
+        return 0.0
+
+    return ret
+
+
+def exact_match_indicator(gold: str, pred: str, indicator: str = " ") -> float:
+    """
+    Exact match, allowing for some preceding context.
+    For example, the following two answers are considered matching:
+    - Because of x and y, the answer is ## <answer>
+    - Given reasons y and z, the answer is ## <answer>
+    While the following is considered different from the earlier two
+    - Given reasons x and a, the answer is ## <other answer>
+    """
+    pred = pred.split(indicator)[-1].strip()
+    gold = gold.split(indicator)[-1].strip()
+    return exact_match(gold, pred)
+
+
+def final_number_exact_match(gold: str, pred: str) -> float:
+    """
+    Returns 1 iff the final number in gold and pred match.
+    Similar to exact_match_indicator.
+    Example:
+    - gold = "The answer is 15."
+    - pred = "The answer is 15 eggs."
+    - Returns 1
+    """
+
+    def get_final_number(x: str) -> str:
+        matches = re.findall(r"-?[\d,]+(?:.\d+)?", x)
+        if not matches:
+            return ""
+        return matches[-1].replace(",", "")
+
+    return exact_match(get_final_number(gold), get_final_number(pred))
+
+
+def rouge_score(gold: str, pred: str, rouge_type: str, scorer: rouge_scorer.RougeScorer) -> float:
+    scores = scorer.score(gold, pred)
+    return scores[rouge_type].fmeasure
+
+
+def get_rouge_function(rouge_type: str) -> Callable[[str, str], float]:
+    scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True)
+    return partial(rouge_score, scorer=scorer, rouge_type=rouge_type)
+
+
+def bleu_1(gold: str, pred: str) -> float:
+    return sentence_bleu([word_tokenize(gold)], word_tokenize(pred), weights=(1, 0, 0, 0))
+
+
+def chinese_bleu_1(gold: str, pred: str) -> float:
+    char_tokenizer = ChineseTokenizer()
+    return sentence_bleu([char_tokenizer.tokenize(gold)], char_tokenizer.tokenize(pred), weights=(1, 0, 0, 0))
+
+
+def get_chinese_rouge_function(rouge_type: str) -> Callable[[str, str], float]:
+    char_tokenizer = ChineseTokenizer()
+    scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True, tokenizer=char_tokenizer)
+    return partial(rouge_score, scorer=scorer, rouge_type=rouge_type)
+
+
+def cleva_math_result_match(gold: str, pred: str) -> float:
+    """
+    Exact match that only cares the last math expression.
+    Common math expressions are numbers and fractions.
+    """
+    pattern = r"[-+*/%\.\(\)\d]+"
+    matches = re.findall(pattern, pred)
+    if matches:
+        pred = matches[-1].lstrip(")")
+    # remove space in front or at the end
+    pred = pred.strip()
+    return exact_match(gold, pred)
+
+
+def bleu_4(gold: str, pred: str) -> float:
+    return sentence_bleu([word_tokenize(gold)], word_tokenize(pred), weights=(0, 0, 0, 1))
+
+
+def extract_set_from_text(
+    set_str: str,
+    set_start_str: str = " is ",
+    set_separator: str = " and ",
+    empty_set_str: str = "Nothing.",
+) -> Set[str]:
+    """
+    Given a string, extract the set of strings implied by that string.
+    set_start_str denotes the start of the set
+    set_separator denotes the string separating set elements
+    empty_set_str is the string which denotes the empty set
+    """
+    if set_str == empty_set_str:
+        return set()
+    set_str = set_str.replace(".", "")
+    extracted_set = set(set_str.split(set_start_str)[-1].split(set_separator))
+    return extracted_set
+
+
+def extract_gold_pred_sets(gold: str, pred: str) -> Tuple[Set[str], Set[str]]:
+    """Extract the set of strings implied by the gold and pred strings"""
+    gold_set = extract_set_from_text(gold)
+    pred_set = extract_set_from_text(pred.split("\n")[0])
+    return gold_set, pred_set
+
+
+def iou_set_match(gold: str, pred: str) -> float:
+    """Compute the intersection over union of the gold and pred sets"""
+    gold_set, pred_set = extract_gold_pred_sets(gold, pred)
+    if len(gold_set) == 0:  # If gold is empty, just check if the pred set is also empty
+        return float(gold_set == pred_set)
+    return len(gold_set.intersection(pred_set)) / len(gold_set.union(pred_set))
+
+
+def f1_set_match(gold: str, pred: str) -> float:
+    """Compute the F1 score of the gold and pred sets"""
+    gold_set, pred_set = extract_gold_pred_sets(gold, pred)
+    if len(gold_set) == 0:  # If gold is empty, just check if the pred set is also empty
+        return float(gold_set == pred_set)
+    true_positives = gold_set.intersection(pred_set)
+    return 2 * len(true_positives) / (len(gold_set) + len(pred_set))
+
+
+def exact_set_match(gold: str, pred: str) -> float:
+    """Compute whether the sets generated exactly match"""
+    gold_set, pred_set = extract_gold_pred_sets(gold, pred)
+    return float(gold_set == pred_set)
+
+
+def absolute_value_difference(gold: str, pred: str) -> float:
+    """Compute the absolute value of the difference between two numbers (provided as strings),
+    or 0.0 if invalid input.
+    """
+
+    def maybe_int(text: str):
+        """Parse int, ignoring commas in numbers."""
+        try:
+            val = int(text.replace(",", ""))
+        except ValueError:
+            return 0.0
+        return val
+
+    gold_val = maybe_int(gold)
+    pred_val = maybe_int(pred)
+    return abs(gold_val - pred_val)
+
+
+def code_eval(gold: Tuple[str, Optional[Dict]], pred: str) -> float:
+    """Evaluate Code Correctness on test examples."""
+    assert gold[1] is not None  # gold[1]["canonical_solution"]
+    # Warning: will execute machine generated code; need to sandbox before executing
+    return float(code_metrics_helper.check_correctness(gold[1], pred, 3.0)["passed"])  # type: ignore
+
+
+# TODO This should probably be made into an implementation of MetricInterface. For now it lives here
+# just to separate it from basic_metrics.py.
+def compute_reference_metrics(
+    names: List[str], adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
+) -> List[Stat]:
+    """
+    Setup:
+
+    - Gold (correct references): G1 ... Gm
+    - Predictions (completions): P1 ... Pk
+
+    For each pair (G, P), we can define a ${score} (e.g., exact match, F1, BLEU).
+
+    We define the following stats:
+
+    - ${score}: max_i score(Gi, P1)
+    - ${score}@k: max_{i,j} score(Gi, Pj)
+    """
+
+    def compute_metrics_helper(
+        name: MetricName,
+        score_func: Callable,
+        group: Optional[str] = None,
+    ) -> List[Stat]:
+        if name.name == "pass":  # Calculate pass@k for HumanEval from CodeScenario.
+            score_func = cast(Callable[[Tuple[str, Optional[Dict]], str], float], score_func)  # Make mypy happy.
+            code_golds = cast(List[CodeReference], golds)
+            results = [score_func((gold.output.text, gold.test_cases), pred) for gold in code_golds for pred in preds]
+            _len, _sum = len(results), int(sum(results))  # Cast to int to make type match.
+            score_1 = pass_at_k_estimator(_len, _sum, 1)
+            score_k = pass_at_k_estimator(_len, _sum, adapter_spec.num_outputs)
+        elif name.name == "code_eval_acc":
+            score_func = cast(Callable[[Tuple[str, Optional[Dict]], str], float], score_func)  # Make mypy happy.
+            code_golds = cast(List[CodeReference], golds)
+            score_1 = max(score_func((gold.output.text, gold.test_cases), preds[0]) for gold in code_golds)
+            score_k = max(
+                score_func((gold.output.text, gold.test_cases), pred) for gold in code_golds for pred in preds
+            )
+        else:
+            score_func = cast(Callable[[str, str], float], score_func)  # Make mypy happy.
+            score_1 = max(score_func(gold.output.text, preds[0]) for gold in golds)
+            score_k = max(score_func(gold.output.text, pred) for gold in golds for pred in preds)
+
+        metrics = [Stat(name).add(score_1)]  # score_1 corresponds using one prediction
+        if adapter_spec.num_outputs != 1:
+            metrics.append(Stat(replace(name, name=f"{name.name}@{adapter_spec.num_outputs}")).add(score_k))
+        return metrics
+
+    # maps each string metric name to its associated function
+    metric_fn_mapping: Dict[str, Callable] = {
+        "exact_match": exact_match,
+        "quasi_exact_match": quasi_exact_match,
+        "prefix_exact_match": prefix_exact_match,
+        "quasi_prefix_exact_match": quasi_prefix_exact_match,
+        "exact_match_indicator": exact_match_indicator,
+        "final_number_exact_match": final_number_exact_match,
+        "exact_set_match": exact_set_match,
+        "iou_set_match": iou_set_match,
+        "f1_set_match": f1_set_match,
+        "math_equiv": is_equiv,
+        "math_equiv_chain_of_thought": is_equiv_chain_of_thought,
+        "code_eval_acc": code_eval,
+        "pass": code_eval,
+        "f1_score": f1_score,
+        "rouge_1": get_rouge_function("rouge1"),
+        "rouge_2": get_rouge_function("rouge2"),
+        "rouge_l": get_rouge_function("rougeL"),
+        "bleu_1": bleu_1,
+        "bleu_4": bleu_4,
+        "chinese_bleu_1": chinese_bleu_1,
+        "chinese_rouge_1": get_chinese_rouge_function("rouge1"),
+        "chinese_rouge_2": get_chinese_rouge_function("rouge2"),
+        "cleva_math_result_match": cleva_math_result_match,
+        "absolute_value_difference": absolute_value_difference,
+    }
+
+    stats: List[Stat] = []
+
+    # Gold outputs
+    golds: List[Reference] = [reference for reference in request_state.instance.references if reference.is_correct]
+    assert len(golds) > 0
+
+    # Predicted outputs
+    assert request_state.result is not None
+    sorted_completions: List[GeneratedOutput] = sorted(request_state.result.completions, key=lambda x: -x.logprob)
+    preds: List[str] = [completion.text.strip() for completion in sorted_completions]
+
+    # Apply mapping if exists (e.g., for multiple-choice questions A -> Boston, B -> New York)
+    # Note: If 'A' and 'B' were the only possible choices, smaller language models like GPT-2 would
+    # sometimes predict a random letter like 'M'.
+    if request_state.output_mapping is not None:
+        preds = [request_state.output_mapping.get(pred) for pred in preds]  # type: ignore
+
+    # Compute max_prob, the probability that the model assigns to its generated text.
+    # Use the log prob of sorted_completions[0], which is the completion with the highest
+    # log_prob. We use this since that's what's used for computing metrics like exact_match.
+    # One subtlety is that when computing exact_match, we strip whitespace, so the actual
+    # max_prob is the sum of all the probabilities in the set {x : strip(x) = prediction}.
+    # In practice, we think this may not make much of a difference because models may not place
+    # high probabilities on having additional spaces (should check this). Also, the sum
+    # involves computing the log_prob for many completions which could be intractable.
+    max_prob = np.exp(sorted_completions[0].logprob)
+    stats.append(Stat(MetricName("max_prob")).add(max_prob))
+
+    # Add other metrics
+    for metric_name in names:
+        if metric_name in metric_fn_mapping:
+            stats.extend(compute_metrics_helper(MetricName(metric_name), metric_fn_mapping[metric_name]))
+        else:
+            raise NameError(f"{metric_name} is not in the list of metric functions.")
+
+    return stats
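
For context, here is how the pass@k estimator introduced in evaluate_reference_metrics.py behaves. This is a minimal sketch: the function body is copied from the hunk above, and the sample counts (10 generated completions, 3 of which pass) are purely illustrative.

```python
import numpy as np

def pass_at_k_estimator(n: int, c: int, k: int) -> float:
    # Numerically stable form of 1 - comb(n - c, k) / comb(n, k), as in the diff above.
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# Illustrative numbers: n=10 sampled completions, c=3 of them pass the tests.
print(pass_at_k_estimator(10, 3, 1))  # ~0.3, same as c / n when k=1
print(pass_at_k_estimator(10, 3, 5))  # ~0.917, chance that at least one of 5 draws passes
```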

helm/benchmark/metrics/image_generation/aesthetics_metrics.py (new file)
@@ -0,0 +1,54 @@
+from statistics import mean
+from typing import List, Optional
+
+from helm.common.images_utils import is_blacked_out_image
+from helm.common.request import RequestResult
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.metrics.statistic import Stat
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from .aesthetics_scorer import AestheticsScorer
+from helm.common.multimodal_request_utils import gather_generated_image_locations
+
+
+class AestheticsMetric(Metric):
+    """
+    Defines metrics for LAION's CLIP-based aesthetics predictor for images
+    (https://github.com/LAION-AI/aesthetic-predictor).
+    """
+
+    def __init__(self):
+        self._aesthetics_scorer: Optional[AestheticsScorer] = None
+
+    def __repr__(self):
+        return "AestheticsMetric()"
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        assert request_state.result is not None
+        request_result: RequestResult = request_state.result
+        image_locations: List[str] = gather_generated_image_locations(request_result)
+        if len(image_locations) == 0:
+            return []
+
+        if self._aesthetics_scorer is None:
+            self._aesthetics_scorer = AestheticsScorer()
+
+        # Compute the aesthetics score for each generated image. Skip blacked out images.
+        scores: List[float] = [
+            self._aesthetics_scorer.compute_aesthetics_score(location)
+            for location in image_locations
+            if not is_blacked_out_image(location)
+        ]
+        stats: List[Stat] = [
+            Stat(MetricName("expected_aesthetics_score")).add(mean(scores) if len(scores) > 0 else 0),
+            Stat(MetricName("max_aesthetics_score")).add(max(scores) if len(scores) > 0 else 0),
+        ]
+        return stats

helm/benchmark/metrics/image_generation/aesthetics_scorer.py (new file)
@@ -0,0 +1,66 @@
+from urllib.request import urlretrieve
+import os
+
+import torch
+
+from helm.common.general import ensure_directory_exists
+from helm.common.gpu_utils import get_torch_device
+from helm.common.images_utils import open_image
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.benchmark.runner import get_cached_models_path
+
+
+class AestheticsScorer:
+    """
+    LAION's CLIP-based aesthetics predictor for images (https://github.com/LAION-AI/aesthetic-predictor).
+    Adapted from
+    https://colab.research.google.com/github/LAION-AI/aesthetic-predictor/blob/main/asthetics_predictor.ipynb.
+    """
+
+    MODEL_URL_TEMPLATE: str = (
+        "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_{clip_model}_linear.pth?raw=true"
+    )
+
+    @staticmethod
+    def load_model(clip_model="vit_l_14"):
+        """Load the aesthetics model."""
+        cache_folder: str = os.path.join(get_cached_models_path(), "emb_reader")
+        ensure_directory_exists(cache_folder)
+        model_path: str = os.path.join(cache_folder, f"sa_0_4_{clip_model}_linear.pth")
+
+        if not os.path.exists(model_path):
+            model_url: str = os.path.join(AestheticsScorer.MODEL_URL_TEMPLATE.format(clip_model=clip_model))
+            urlretrieve(model_url, model_path)
+
+        if clip_model == "vit_l_14":
+            m = torch.nn.Linear(768, 1)
+        elif clip_model == "vit_b_32":
+            m = torch.nn.Linear(512, 1)
+        else:
+            raise ValueError(f"Invalid model: {clip_model}")
+
+        s = torch.load(model_path)
+        m.load_state_dict(s)
+        m.eval()
+        return m
+
+    def __init__(self):
+        try:
+            import clip
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["heim"])
+
+        # Load the CLIP and aesthetics model
+        self._device: torch.device = get_torch_device()
+        self._model, self._preprocess = clip.load("ViT-L/14", device=self._device)
+        self._aesthetics_model = self.load_model().to(self._device)
+
+    def compute_aesthetics_score(self, image_location: str) -> float:
+        """
+        Compute the aesthetics score. Returns a value between 1 and 10.
+        """
+        image = self._preprocess(open_image(image_location)).unsqueeze(0).to(self._device)
+        with torch.no_grad():
+            image_features = self._model.encode_image(image)
+            image_features /= image_features.norm(dim=-1, keepdim=True)
+        return self._aesthetics_model(image_features.float()).detach().item()
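
A minimal usage sketch for the scorer above, assuming the optional HEIM dependencies (torch and the clip package) are installed; the image path "generated.png" is a hypothetical placeholder.

```python
from helm.benchmark.metrics.image_generation.aesthetics_scorer import AestheticsScorer

# Loads CLIP ViT-L/14 and downloads the LAION linear head on first use.
scorer = AestheticsScorer()

# "generated.png" is a placeholder for a locally generated image file.
score = scorer.compute_aesthetics_score("generated.png")
print(f"aesthetics score: {score:.2f}")  # roughly on a 1-10 scale
```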

helm/benchmark/metrics/image_generation/clip_score_metrics.py (new file)
@@ -0,0 +1,73 @@
+from statistics import mean
+from typing import List
+
+from helm.common.general import singleton
+from helm.common.request import RequestResult
+from helm.common.clip_score_request import DEFAULT_CLIP_SCORE_MODEL, CLIPScoreResult, CLIPScoreRequest
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.metrics.statistic import Stat
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
+from helm.common.images_utils import is_blacked_out_image
+from helm.common.multimodal_request_utils import gather_generated_image_locations
+
+
+class CLIPScoreMetric(Metric):
+    """
+    Defines CLIPScore-based metrics (https://arxiv.org/abs/2104.08718).
+    CLIPScore is a reference free metric that can be used to evaluate the correlation between an image
+    caption and the content of the image. It has been found to be highly correlated with human judgement.
+    """
+
+    def __init__(self, multilingual: bool = False):
+        self._multilingual: bool = multilingual
+
+    def __repr__(self):
+        return f"CLIPScoreMetric(multilingual={self._multilingual})"
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        def get_metric_name(base_name: str) -> str:
+            if self._multilingual:
+                base_name = f"{base_name}_multilingual"
+            return base_name
+
+        assert request_state.result is not None
+        request_result: RequestResult = request_state.result
+
+        prompt: str = request_state.request.prompt
+        perturbation_name: str = request_state.instance.perturbation.name if request_state.instance.perturbation else ""
+        if (
+            request_state.instance.contrast_inputs is not None
+            and len(request_state.instance.contrast_inputs) > 0
+            and perturbation_name in ["translate", "dialect", "mild_mix"]
+        ):
+            prompt = singleton(request_state.instance.contrast_inputs).text
+
+        # Truncate the prompt using the CLIP tokenizer before feeding into the CLIP model.
+        # Otherwise, the library will throw an error.
+        model = DEFAULT_CLIP_SCORE_MODEL
+        prompt = WindowServiceFactory.get_window_service(model, metric_service).truncate_from_right(prompt)
+
+        scores: List[float] = []
+        image_locations: List[str] = gather_generated_image_locations(request_result)
+        for location in image_locations:
+            if not is_blacked_out_image(location):
+                result: CLIPScoreResult = metric_service.compute_clip_score(
+                    CLIPScoreRequest(prompt, location, model=model, multilingual=self._multilingual)
+                )
+                scores.append(result.score)
+
+        stats: List[Stat] = [
+            Stat(MetricName(get_metric_name("expected_clip_score"))).add(mean(scores) if len(scores) > 0 else 0),
+            Stat(MetricName(get_metric_name("max_clip_score"))).add(max(scores) if len(scores) > 0 else 0),
+        ]
+        return stats

helm/benchmark/metrics/image_generation/denoised_runtime_metric.py (new file)
@@ -0,0 +1,42 @@
+from collections import defaultdict
+from tqdm import tqdm
+from typing import Dict
+import math
+import numpy as np
+
+from helm.common.request import RequestResult
+from helm.benchmark.scenarios.scenario import Instance
+from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.metrics.statistic import Stat
+from helm.benchmark.metrics.metric import MetricInterface, MetricResult
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+
+
+class DenoisedRuntimeMetric(MetricInterface):
+    def __repr__(self):
+        return "DenoisedRuntimeMetric()"
+
+    def evaluate(
+        self,
+        scenario_state: ScenarioState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+        parallelism: int,
+    ) -> MetricResult:
+
+        instance_to_min_request_times: Dict[Instance, float] = defaultdict(lambda: math.inf)
+        for request_state in tqdm(scenario_state.request_states):
+            assert request_state.result is not None
+            request_result: RequestResult = request_state.result
+
+            assert request_result.request_time is not None
+            request_time: float = request_result.request_time
+
+            instance: Instance = request_state.instance
+            instance_to_min_request_times[instance] = min(instance_to_min_request_times[instance], request_time)
+
+        denoised_runtime: float = float(np.mean(list(instance_to_min_request_times.values())))
+        return MetricResult(
+            aggregated_stats=[Stat(MetricName("denoised_runtime")).add(denoised_runtime)], per_instance_stats=[]
+        )
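
The aggregation above takes, for each instance, the minimum request time observed across its requests and then averages those minima, so transient slowdowns are filtered out. A toy illustration with made-up timings:

```python
import numpy as np

# Hypothetical request times (seconds); each instance may be requested more than once.
request_times = {
    "instance_1": [2.1, 0.9, 1.4],
    "instance_2": [0.7, 0.8],
}
denoised_runtime = float(np.mean([min(times) for times in request_times.values()]))
print(denoised_runtime)  # (0.9 + 0.7) / 2 = 0.8
```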

helm/benchmark/metrics/image_generation/detection_metrics.py (new file)
@@ -0,0 +1,57 @@
+from typing import List, Dict, Any
+import json
+from statistics import mean
+
+from helm.common.request import RequestResult
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.metrics.statistic import Stat
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.common.multimodal_request_utils import gather_generated_image_locations
+from .detectors.vitdet import ViTDetDetector
+
+
+class DetectionMetric(Metric):
+    """
+    Define metrics following DALL-EVAL (https://arxiv.org/abs/2202.04053),
+    which measure whether generated images contain the correct objects, counts, and relations
+    as specified in input text prompts.
+    """
+
+    def __init__(self):
+        self._detection_model = None
+
+    def __repr__(self):
+        return "DetectionMetric()"
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        assert request_state.result is not None
+        request_result: RequestResult = request_state.result
+        image_locations: List[str] = gather_generated_image_locations(request_result)
+        if len(image_locations) == 0:
+            return []
+
+        if self._detection_model is None:
+            self._detection_model = ViTDetDetector()
+
+        instance = request_state.instance
+        references: Dict[str, Any] = {**json.loads(instance.references[0].output.text), "skill": instance.sub_split}
+
+        prompt: str = request_state.request.prompt
+        scores: List[float] = []
+        for image_location in image_locations:
+            score: float = self._detection_model.compute_score(prompt, image_location, references)
+            scores.append(score)
+
+        stats: List[Stat] = [
+            Stat(MetricName("detection_correct_frac")).add(mean(scores) if len(scores) > 0 else 0),
+        ]
+        return stats