crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Dict, Any
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from helm.benchmark.runner import get_cached_models_path
|
|
7
|
+
from helm.common.general import ensure_file_downloaded, hlog
|
|
8
|
+
from helm.common.images_utils import open_image
|
|
9
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
10
|
+
from helm.common.gpu_utils import get_torch_device
|
|
11
|
+
from .base_detector import BaseDetector
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# LazyConfig YAML describing the ViTDet model architecture (hosted on Google Drive).
MODEL_CONFIG_DOWNLOAD_URL: str = "https://drive.google.com/uc?id=1MLuwQ0ZN0gJQ42oVCc0aFz6Rneb1g3Rt"
# Official detectron2 ViTDet-B Mask R-CNN checkpoint trained on COCO.
MODEL_CHECKPOINT_DOWNLOAD_URL: str = (
    "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/mask_rcnn_vitdet_b/f325346929/model_final_61ccd1.pkl"
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ViTDetDetector(BaseDetector):
    """Object detector backed by detectron2's ViTDet Mask R-CNN (COCO).

    Scores text-to-image generations on three skills by comparing detected
    COCO instances against reference metadata:
    - "object": is the single most confident detection the expected class?
    - "count": does the number of detections of a class match the expected count?
    - "spatial": is the spatial relation between two objects as expected?
    """

    def __init__(self):
        # detectron2 is an optional dependency, installed via the "heim" extra.
        try:
            from detectron2.checkpoint import DetectionCheckpointer
            from detectron2.config import LazyConfig
            from detectron2.config import instantiate
            from detectron2.data.catalog import MetadataCatalog
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        super().__init__()

        # Download the model config and checkpoint into the shared model cache.
        cache_path: str = get_cached_models_path()
        cfg_path: str = os.path.join(cache_path, "vitdet_model.yaml")
        ensure_file_downloaded(source_url=MODEL_CONFIG_DOWNLOAD_URL, target_path=cfg_path)
        cfg = LazyConfig.load(cfg_path)

        model_path: str = os.path.join(cache_path, "vitdet_model.pkl")
        ensure_file_downloaded(source_url=MODEL_CHECKPOINT_DOWNLOAD_URL, target_path=model_path)
        cfg.train.init_checkpoint = model_path

        # NOTE(review): .cuda() assumes a CUDA device is available, even though
        # self._device below comes from get_torch_device() — confirm on CPU-only hosts.
        model = instantiate(cfg.model).cuda()
        model = model.eval()
        # Inference only: freeze all parameters.
        for p in model.parameters():
            p.requires_grad = False
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)

        self._cfg = cfg
        self._model = model
        self._device: torch.device = get_torch_device()
        hlog("Initialized the ViTDet model.")

        # COCO classes
        self._coco_classes = MetadataCatalog.get("coco_2017_val").thing_classes

    def forward_model(self, image_location: str) -> Any:
        """Run the detector on one image and return its detectron2 `Instances`."""
        try:
            from detectron2.data.common import DatasetFromList, MapDataset
            from detectron2.config import instantiate
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        image = open_image(image_location)
        # Build a one-item dataset so the configured test-time mapper
        # (resizing/normalization) is applied exactly as in detectron2 evaluation.
        dataset_dicts = [
            {
                "file_name": image_location,
                "width": image.width,
                "height": image.height,
            }
        ]
        dataset = DatasetFromList(dataset_dicts, copy=False)
        mapper = instantiate(self._cfg.dataloader.test.mapper)
        dataset = MapDataset(dataset, mapper)
        inputs = [dataset[0]]
        outputs = self._model(inputs)
        return outputs[0]["instances"]

    def compute_score(self, caption: str, image_location: str, references: Dict[str, Any]) -> float:
        """Dispatch to the skill-specific scorer based on references["skill"].

        Raises:
            NotImplementedError: if the skill is not "object", "count", or "spatial".
        """
        # hlog(f'compute score for prompt: {caption}, file: {image_location}, skill: {references["skill"]}')
        instances = self.forward_model(image_location)
        if references["skill"] == "object":
            return self.compute_score_object(instances, references)
        if references["skill"] == "count":
            return self.compute_score_count(instances, references)
        if references["skill"] == "spatial":
            return self.compute_score_spatial(instances, references)
        raise NotImplementedError(references["skill"])

    def compute_score_object(self, instances, references) -> float:
        """Return 1.0 iff the top-scoring detection matches references["object"]."""
        gt_class_name = references["object"]
        gt_class = self._coco_classes.index(gt_class_name)
        if len(instances.scores) == 0:
            # No detections at all: score 0.
            pred_id = None
            pred_score = torch.zeros(())
            pred_class = None
            pred_class_name = None
            correct = 0.0
        else:
            pred_id = instances.scores.max(-1).indices
            pred_score = instances.scores[pred_id]  # (num_instances,) -> () # noqa
            pred_class = instances.pred_classes[pred_id]  # (num_instances,) -> ()
            pred_class_name = self._coco_classes[pred_class.item()]  # noqa

            correct = float(pred_class == gt_class)

        # hlog(f"pred_class: {pred_class_name}, gt_class: {gt_class_name}, correct: {correct}")
        return correct

    def compute_score_count(self, instances, references) -> float:
        """Return 1.0 iff the number of detections of the target class equals references["count"]."""
        # assume that there is only one type of object
        gt_class_name = references["object"]
        gt_class_idx = self._coco_classes.index(gt_class_name)
        gt_count = references["count"]
        if len(instances.scores) == 0:
            pred_count = 0
            correct = 0.0
        else:
            pred_count = (instances.pred_classes == gt_class_idx).sum().item()
            correct = float(pred_count == gt_count)
        return correct

    def compute_score_spatial(self, instances, references) -> float:
        """Score the spatial relation between the two objects in references["objects"].

        Requires exactly one detection of each object class; otherwise scores 0.
        """
        gt_class_name_1, gt_class_name_2 = references["objects"]
        gt_class_idx_1 = self._coco_classes.index(gt_class_name_1)
        gt_class_idx_2 = self._coco_classes.index(gt_class_name_2)
        # e.g. "left_of" -> "left"
        relation = references["relation"].split("_")[0]

        if len(instances.scores) == 0:
            correct = 0
            pred_rel = "no_pred"
        else:
            pred_count_1 = (instances.pred_classes == gt_class_idx_1).sum().item()
            pred_count_2 = (instances.pred_classes == gt_class_idx_2).sum().item()
            if pred_count_1 != 1 or pred_count_2 != 1:
                # Ambiguous: each class must be detected exactly once.
                correct = 0
                pred_rel = "obj_count_mismatch"
            else:
                # Top-left corners (x, y) of the two boxes.
                x11, y11 = instances.pred_boxes[instances.pred_classes == gt_class_idx_1].tensor[0, :2]
                x21, y21 = instances.pred_boxes[instances.pred_classes == gt_class_idx_2].tensor[0, :2]

                x_diff = x11 - x21
                y_diff = y11 - y21

                # FIXME: The code below mimics dall-eval logic. I don't think
                # we need to follow it. Does the case of two objects of same
                # category make sense? Also, I don't know why we need to
                # to ensure something is more "right" than it is "above".
                if gt_class_name_1 == gt_class_name_2:
                    # Same category: only check the axis of the relation, not direction.
                    if abs(x_diff) > abs(y_diff):
                        if relation in ["left", "right"]:
                            correct = 1
                            pred_rel = "relation_correct"
                        else:
                            pred_rel = "relation_incorrect"
                            correct = 0
                    else:
                        if relation in ["above", "below"]:
                            pred_rel = "relation_correct"
                            correct = 1
                        else:
                            pred_rel = "relation_incorrect"
                            correct = 0
                else:
                    # NOTE(review): x11 < x21 is labeled "right" and y11 > y21 "above",
                    # which looks inverted for image coordinates — this mirrors the
                    # dall-eval convention referenced above; confirm before changing.
                    if abs(x_diff) > abs(y_diff):
                        if x11 < x21:
                            pred_rel = "right"
                        else:
                            pred_rel = "left"
                    else:
                        if y11 > y21:
                            pred_rel = "above"
                        else:
                            pred_rel = "below"

                    if relation == pred_rel:
                        correct = 1
                    else:
                        correct = 0
        return correct
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from helm.common.request import RequestResult
|
|
4
|
+
from helm.benchmark.adaptation.request_state import RequestState
|
|
5
|
+
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
|
|
6
|
+
from helm.benchmark.metrics.statistic import Stat
|
|
7
|
+
from helm.benchmark.metrics.metric import Metric
|
|
8
|
+
from helm.benchmark.metrics.metric_name import MetricName
|
|
9
|
+
from helm.benchmark.metrics.metric_service import MetricService
|
|
10
|
+
from helm.common.multimodal_request_utils import gather_generated_image_locations
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class EfficiencyMetric(Metric):
    """Efficiency metrics for text-to-image models: prompt length and image count.

    The inference runtime itself is computed in BasicMetric.
    """

    def __repr__(self):
        return "EfficiencyMetric()"

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        assert request_state.result is not None
        result: RequestResult = request_state.result

        # Nothing to report when no images were actually generated.
        locations: List[str] = gather_generated_image_locations(result)
        if not locations:
            return []

        prompt: str = request_state.request.prompt
        return [
            Stat(MetricName("prompt_length")).add(len(prompt)),
            Stat(MetricName("num_generated_images")).add(len(result.completions)),
        ]
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
from tqdm import tqdm
|
|
2
|
+
from typing import Dict, List, Set, Optional
|
|
3
|
+
import math
|
|
4
|
+
import os
|
|
5
|
+
import shutil
|
|
6
|
+
|
|
7
|
+
from helm.common.general import ensure_directory_exists, generate_unique_id, get_file_name, hlog
|
|
8
|
+
from helm.common.gpu_utils import is_cuda_available, get_torch_device
|
|
9
|
+
from helm.common.request import RequestResult
|
|
10
|
+
from helm.benchmark.augmentations.perturbation_description import PerturbationDescription
|
|
11
|
+
from helm.benchmark.scenarios.scenario import Instance
|
|
12
|
+
from helm.benchmark.adaptation.scenario_state import ScenarioState
|
|
13
|
+
from helm.benchmark.metrics.statistic import Stat
|
|
14
|
+
from helm.benchmark.metrics.metric import MetricInterface, MetricResult
|
|
15
|
+
from helm.benchmark.metrics.metric_name import MetricName
|
|
16
|
+
from helm.benchmark.metrics.metric_service import MetricService
|
|
17
|
+
from helm.common.images_utils import is_blacked_out_image, copy_image
|
|
18
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FidelityMetric(MetricInterface):
    """
    Frechet Inception Distance (FID) is a measure of similarity between two sets of images.
    Inception Score (IS) measures quality and diversity of images.
    Both metrics require a large number of samples to compute.

    @misc{Seitzer2020FID,
      author={Maximilian Seitzer},
      title={{pytorch-fid: FID Score for PyTorch}},
      month={August},
      year={2020},
      note={Version 0.3.0},
      howpublished={https://github.com/mseitzer/pytorch-fid},
    }

    @misc{obukhov2020torchfidelity,
      author={Anton Obukhov and Maximilian Seitzer and Po-Wei Wu and Semen Zhydenko and Jonathan Kyl
      and Elvis Yu-Jing Lin},
      year=2020,
      title={High-fidelity performance metrics for generative models in PyTorch},
      url={https://github.com/toshas/torch-fidelity},
      publisher={Zenodo},
      version={v0.3.0},
      doi={10.5281/zenodo.4957738},
      note={Version: 0.3.0, DOI: 10.5281/zenodo.4957738}
    }
    """

    # All images (gold and generated) are resized to this size before scoring,
    # since the fidelity libraries compare same-sized image sets.
    IMAGE_WIDTH: int = 512
    IMAGE_HEIGHT: int = 512

    def __repr__(self):
        return "FidelityMetric()"

    def evaluate(
        self,
        scenario_state: ScenarioState,
        metric_service: MetricService,
        eval_cache_path: str,
        parallelism: int,
    ) -> MetricResult:
        """Compute FID, Inception Score, and (when enough samples) KID per perturbation.

        Copies gold and generated images into temporary directories under
        `eval_cache_path`, runs pytorch-fid and torch-fidelity over them, and
        cleans the directories up afterwards. Returns aggregated stats only
        (per-instance stats are empty, as these are set-level metrics).
        """
        # pytorch-fid and torch-fidelity are optional dependencies ("heim" extra).
        try:
            import torch_fidelity
            from pytorch_fid.fid_score import calculate_fid_given_paths
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        dest_path: str
        unique_perturbations: Set[Optional[PerturbationDescription]] = set()

        # Fresh uniquely-named directory for the gold (reference) images.
        gold_images_path: str = os.path.join(eval_cache_path, generate_unique_id())
        ensure_directory_exists(gold_images_path)

        # The library requires the gold and generated images to be in two separate directories.
        # Gather the gold images and the unique perturbations
        num_gold_images: int = 0
        for request_state in tqdm(scenario_state.request_states):
            instance: Instance = request_state.instance
            unique_perturbations.add(instance.perturbation)

            for reference in instance.references:
                # Only correct references carry gold images.
                if not reference.is_correct:
                    continue

                assert (
                    reference.output.multimedia_content is not None
                    and reference.output.multimedia_content.media_objects[0].location is not None
                )
                file_path: str = reference.output.multimedia_content.media_objects[0].location
                dest_path = os.path.join(gold_images_path, get_file_name(file_path))
                copy_image(file_path, dest_path, width=self.IMAGE_WIDTH, height=self.IMAGE_HEIGHT)
                num_gold_images += 1
        hlog(f"Resized {num_gold_images} gold images to {self.IMAGE_WIDTH}x{self.IMAGE_HEIGHT}.")

        # Compute the FID for each perturbation group
        stats: List[Stat] = []
        for perturbation in unique_perturbations:
            perturbation_name: str = "" if perturbation is None else str(perturbation)
            # Separate uniquely-named directory per perturbation's generated images.
            generated_images_path: str = os.path.join(eval_cache_path, generate_unique_id())
            ensure_directory_exists(generated_images_path)

            num_generated_images: int = 0
            for request_state in tqdm(scenario_state.request_states):
                if request_state.instance.perturbation != perturbation:
                    continue

                assert request_state.result is not None
                request_result: RequestResult = request_state.result

                # Gather the model-generated images
                for image in request_result.completions:
                    assert image.multimodal_content is not None
                    location = image.multimodal_content.media_objects[0].location
                    # Skip missing images and images blacked out by the safety filter.
                    if location is not None and not is_blacked_out_image(location):
                        dest_path = os.path.join(generated_images_path, get_file_name(location))
                        copy_image(location, dest_path, width=self.IMAGE_WIDTH, height=self.IMAGE_HEIGHT)
                        num_generated_images += 1

            # KID is only computed when there are at least 1000 generated images.
            compute_kid: bool = num_generated_images >= 1000
            hlog(f"Resized {num_generated_images} images to {self.IMAGE_WIDTH}x{self.IMAGE_HEIGHT}.")

            try:
                hlog(f"Computing FID between {generated_images_path} and {gold_images_path}...")
                fid: float = calculate_fid_given_paths(
                    paths=[generated_images_path, gold_images_path],
                    device=get_torch_device(),
                    # Following defaults set in
                    # https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py#L54
                    batch_size=50,
                    dims=2048,
                    num_workers=8,
                )
                hlog(f"Done. FID score: {fid}")

                # The torch_fidelity library fails when there are too few images (i.e., `max_eval_instances` is small).
                hlog("Computing the other fidelity metrics...")
                metrics_dict: Dict[str, float] = torch_fidelity.calculate_metrics(
                    input1=generated_images_path,
                    input2=gold_images_path,
                    isc=True,
                    fid=False,
                    kid=compute_kid,
                    ppl=False,  # Requires `GenerativeModel`
                    cuda=is_cuda_available(),
                    save_cpu_ram=not is_cuda_available(),
                )
                inception_score: float = metrics_dict["inception_score_mean"]
                # Guard against NaN from degenerate (tiny) image sets.
                if math.isnan(inception_score):
                    inception_score = 0

                stats.extend(
                    [
                        Stat(MetricName("fid", perturbation=perturbation)).add(fid),
                        Stat(MetricName("inception_score", perturbation=perturbation)).add(inception_score),
                    ]
                )
                if compute_kid:
                    kid: float = metrics_dict["kernel_inception_distance_mean"]
                    stats.append(Stat(MetricName("kernel_inception_distance", perturbation=perturbation)).add(kid))
            except AssertionError as e:
                # Best effort: the libraries assert internally on too-small sets;
                # log and continue with the remaining perturbations.
                hlog(f"Error occurred when computing fidelity metrics for perturbation: {perturbation_name} Error: {e}")

            shutil.rmtree(generated_images_path)

        # Delete the gold images directory
        shutil.rmtree(gold_images_path)

        return MetricResult(aggregated_stats=stats, per_instance_stats=[])
|
|
File without changes
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def compute_fractal_dimension(image_path: str) -> float:
    """
    Compute the fractal coefficient of an image.
    From https://en.wikipedia.org/wiki/Minkowski–Bouligand_dimension, in fractal
    geometry, the Minkowski–Bouligand dimension, also known as Minkowski dimension
    or box-counting dimension, is a way of determining the fractal dimension of a
    set S in a Euclidean space Rn, or more generally in a metric space (X, d).

    Adapted from https://gist.github.com/viveksck/1110dfca01e4ec2c608515f0d5a5b1d1.

    :param image_path: Path to the image.
    """

    def _box_count(binary, k):
        # Sum the binary mask over a grid of k x k tiles (from https://github.com/rougier/numpy-100, #87).
        tiles = np.add.reduceat(
            np.add.reduceat(binary, np.arange(0, binary.shape[0], k), axis=0),
            np.arange(0, binary.shape[1], k),
            axis=1,
        )
        # A box counts only if it is neither empty (0) nor completely full (k*k).
        return len(np.where((tiles > 0) & (tiles < k * k))[0])

    def _fractal_dimension(gray, threshold=0.2):
        # Only defined for 2-D (grayscale) images.
        assert len(gray.shape) == 2

        # Binarize: foreground = pixels darker than the threshold.
        binary = gray < threshold

        # Largest power of 2 not exceeding the shorter image side...
        shortest_side = min(binary.shape)
        n = 2 ** np.floor(np.log(shortest_side) / np.log(2))
        # ...and its exponent.
        n = int(np.log(n) / np.log(2))

        # Box sizes from 2**n down to 2**2.
        sizes = 2 ** np.arange(n, 1, -1)
        counts = [_box_count(binary, size) for size in sizes]

        # The box-counting dimension is minus the slope of log(count) vs log(size).
        slope = np.polyfit(np.log(sizes), np.log(counts), 1)[0]
        return -slope

    # OpenCV is an optional dependency installed with the "heim" extra.
    try:
        import cv2
    except ModuleNotFoundError as e:
        handle_module_not_found_error(e, ["heim"])

    # Load as grayscale and normalize to [0, 1].
    grayscale = cv2.imread(image_path, 0) / 255.0  # type: ignore
    assert grayscale.min() >= 0 and grayscale.max() <= 1
    return _fractal_dimension(grayscale)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from .fractal_dimension_util import compute_fractal_dimension
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def fractal_dimension_test(image_filename: str, expected_fractal_dimension: float):
    """Assert that the fixture image under test_images/ has the expected dimension (2 d.p.)."""
    test_image_path: str = os.path.join(os.path.dirname(__file__), "test_images", image_filename)
    measured: float = compute_fractal_dimension(test_image_path)
    assert round(measured, 2) == expected_fractal_dimension
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Test cases are inspired by https://www.sciencedirect.com/science/article/pii/S0097849303001547
|
|
13
|
+
def test_compute_fractal_dimension_cloud():
    """Cloud fixture: measured D of 1.34, just above nature's reported 1.30-1.33."""
    # Clouds have a fractal dimension (D) of 1.30-1.33.
    fractal_dimension_test("cloud.png", 1.34)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_compute_fractal_dimension_sea_anemone():
    """Sea anemone fixture: measured D of 1.54, close to the reported 1.6."""
    # Sea anemones have a D of 1.6.
    fractal_dimension_test("sea_anemone.png", 1.54)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_compute_fractal_dimension_snowflake():
    """Snowflake fixture: measured D of 1.69, close to the reported 1.7."""
    # Snowflakes have a D of 1.7.
    fractal_dimension_test("snowflakes.png", 1.69)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_compute_fractal_dimension_convergence():
    """Pollock's Convergence (1952): measured D of 1.83, near his late-layer ~1.9."""
    # "Pollock continued to drip paint for a period lasting up to six months, depositing layer upon layer,
    # and gradually creating a highly dense fractal pattern. As a result, the D value of his paintings rose
    # gradually as they neared completion, starting in the range of 1.3–1.5 for the initial springboard layer
    # and reaching a final value as high as 1.9". Convergence was produced in 1952 by Jackson Pollock.
    fractal_dimension_test("convergence.png", 1.83)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from statistics import mean
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
from helm.common.request import RequestResult
|
|
6
|
+
from helm.benchmark.adaptation.request_state import RequestState
|
|
7
|
+
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
|
|
8
|
+
from helm.benchmark.metrics.statistic import Stat
|
|
9
|
+
from helm.benchmark.metrics.metric import Metric
|
|
10
|
+
from helm.benchmark.metrics.metric_name import MetricName
|
|
11
|
+
from helm.benchmark.metrics.metric_service import MetricService
|
|
12
|
+
from helm.common.multimodal_request_utils import gather_generated_image_locations
|
|
13
|
+
from .fractal_dimension.fractal_dimension_util import compute_fractal_dimension
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FractalDimensionMetric(Metric):
    """Scores generated images by how far their fractal dimension is from 1.4.

    From https://www.nature.com/articles/35065154, "participants in the perception study
    consistently preferred fractals with D values in the range of 1.3 to 1.5, irrespective
    of the pattern's origin. Significantly, many of the fractal patterns surrounding us in
    nature have D values in this range. Clouds have a value of 1.3."
    """

    # Aesthetically preferred fractal dimension (midpoint of the 1.3-1.5 range above).
    IDEAL_FRACTAL_DIMENSION: float = 1.4

    def __repr__(self):
        return "FractalDimensionMetric()"

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        assert request_state.result is not None
        result: RequestResult = request_state.result

        locations: List[str] = gather_generated_image_locations(result)
        if not locations:
            return []

        # One loss per image: absolute deviation from the ideal dimension.
        # Images whose dimension comes back NaN are dropped.
        losses: List[float] = []
        for location in locations:
            dimension = compute_fractal_dimension(location)
            if not math.isnan(dimension):
                losses.append(abs(dimension - self.IDEAL_FRACTAL_DIMENSION))

        if not losses:
            return []
        return [Stat(MetricName("fractal_dimension_loss")).add(mean(losses))]
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from statistics import mean
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
from helm.benchmark.adaptation.request_state import RequestState
|
|
5
|
+
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
|
|
6
|
+
from helm.benchmark.metrics.statistic import Stat
|
|
7
|
+
from helm.benchmark.metrics.metric import Metric
|
|
8
|
+
from helm.benchmark.metrics.metric_name import MetricName
|
|
9
|
+
from helm.benchmark.metrics.metric_service import MetricService
|
|
10
|
+
from helm.common.request import RequestResult
|
|
11
|
+
from helm.common.clip_score_request import CLIPScoreRequest, CLIPScoreResult
|
|
12
|
+
from helm.common.images_utils import is_blacked_out_image
|
|
13
|
+
from helm.common.multimodal_request_utils import gather_generated_image_locations
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GenderMetric(Metric):
    """Measures gender imbalance in generated photos.

    Following https://arxiv.org/abs/2202.04053, each image is classified as female or male
    by comparing its CLIP scores against the two prompts below; the reported stat is the
    absolute deviation of the female fraction from a balanced 0.5.
    """

    MALE_PROMPT: str = "a photo of a male"
    FEMALE_PROMPT: str = "a photo of a female"

    # A perfectly balanced set of images would be half female.
    IDEAL_FRAC: float = 0.5

    def __repr__(self):
        return "GenderMetric()"

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        assert request_state.result is not None
        result: RequestResult = request_state.result

        classifications: List[bool] = []
        for location in gather_generated_image_locations(result):
            # Skip missing images and images blacked out by the safety filter.
            if location is None or is_blacked_out_image(location):
                continue
            classifications.append(self._is_photo_of_female(metric_service, location))

        female_frac: float = mean(classifications) if len(classifications) > 0 else 0
        return [Stat(MetricName("gender_imbalance")).add(abs(female_frac - self.IDEAL_FRAC))]

    def _is_photo_of_female(self, metric_service: MetricService, image_location: str) -> bool:
        def clip_score_for(prompt: str) -> float:
            result: CLIPScoreResult = metric_service.compute_clip_score(CLIPScoreRequest(prompt, image_location))
            return result.score

        # The image is deemed female iff it aligns more strongly with the female prompt.
        female_clip_score: float = clip_score_for(self.FEMALE_PROMPT)
        male_clip_score: float = clip_score_for(self.MALE_PROMPT)
        return female_clip_score > male_clip_score
|