crfm-helm 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic; see the package's registry page for more details.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +138 -31
- crfm_helm-0.5.1.dist-info/RECORD +654 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +25 -3
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +41 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +213 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +392 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +575 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +41 -1
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +205 -35
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +163 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +757 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +94 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +3 -4
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +5 -3
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_image2structure.yaml +304 -0
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
- helm/benchmark/static/schema_vlm.yaml +823 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
- helm/benchmark/static_build/assets/index-878a1094.css +1 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +233 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +301 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +104 -73
- helm/clients/vertexai_client.py +400 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +111 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +54 -49
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +33 -3
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1159 -538
- helm/config/model_metadata.yaml +868 -41
- helm/config/tokenizer_configs.yaml +149 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_deployment_definition.py +0 -92
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
|
@@ -0,0 +1,575 @@
|
|
|
1
|
+
from typing import List, Dict, Optional, Callable, Tuple, Any, Set
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from torchvision import transforms, models
|
|
4
|
+
from skimage.metrics import structural_similarity as ssim
|
|
5
|
+
from nltk.tokenize.treebank import TreebankWordTokenizer
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import warnings
|
|
9
|
+
import numpy as np
|
|
10
|
+
import os
|
|
11
|
+
import tempfile
|
|
12
|
+
|
|
13
|
+
from helm.benchmark.metrics.copyright_metrics import _edit_similarity
|
|
14
|
+
from helm.benchmark.metrics.metric import Metric
|
|
15
|
+
from helm.benchmark.metrics.metric_service import MetricService
|
|
16
|
+
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
|
|
17
|
+
from helm.common.images_utils import open_image
|
|
18
|
+
from helm.common.gpu_utils import get_torch_device
|
|
19
|
+
from helm.common.cache import Cache
|
|
20
|
+
from helm.benchmark.adaptation.request_state import RequestState
|
|
21
|
+
from helm.common.media_object import MediaObject
|
|
22
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
23
|
+
from helm.common.hierarchical_logger import hlog
|
|
24
|
+
from helm.benchmark.metrics.metric_name import MetricName
|
|
25
|
+
from helm.benchmark.metrics.statistic import Stat
|
|
26
|
+
from helm.benchmark.metrics.vision_language.image_utils import (
|
|
27
|
+
preprocess_image,
|
|
28
|
+
pixel_similarity,
|
|
29
|
+
sift_similarity,
|
|
30
|
+
)
|
|
31
|
+
from helm.benchmark.metrics.vision_language.emd_utils import compute_emd_recursive, get_most_frequent_color
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
|
35
|
+
from PIL import Image
|
|
36
|
+
import imagehash
|
|
37
|
+
except ModuleNotFoundError as e:
|
|
38
|
+
handle_module_not_found_error(e, suggestions=["image2structure"])
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def pad(small_image: Image.Image, large_image: Image.Image, axis: int) -> Image.Image:
    """Return a new RGB image whose size along ``axis`` matches ``large_image``.

    ``small_image`` is pasted at the top-left corner of a white canvas; the size along the
    other axis is kept from ``small_image``.

    Args:
        small_image: Image to extend.
        large_image: Image whose size along ``axis`` is copied.
        axis: Index into PIL's ``size`` tuple, i.e. 0 for width and 1 for height.

    Returns:
        The padded image.
    """
    target_size = [*small_image.size]
    target_size[axis] = large_image.size[axis]
    canvas: Image.Image = Image.new("RGB", (target_size[0], target_size[1]), (255, 255, 255))
    canvas.paste(small_image, (0, 0))
    return canvas
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class CompilationError(Exception):
    """Exception signalling that a completion failed to compile.

    Not raised in this module's visible code; presumably raised by the compiler annotators — verify.
    """
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass
class AnnotatedMetric:
    """A metric function paired with its name and the kind of prepared input it consumes."""

    # Metric name; used as the key of `AnnotatedImageMetrics.metrics` and as the reported stat name.
    name: str
    # Called as `function(generated, reference)`; its result is recorded as the metric value.
    function: Callable
    # Which prepared (generated, reference) pair is passed to `function`:
    # one of "image_PIL", "image_np", "image_np_gray" or "text_str".
    input_type: str
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class AnnotatedImageMetrics(Metric):
    """Abstract base class for metrics on images rendered from a model's text output.

    The model's completion (e.g. LaTeX, HTML) is expected to have already been compiled into an
    image by a compiler annotator named ``f"{generation_type}_compiler"``. This class reads that
    annotation and compares the compiled image (and/or the annotated text) against the first
    reference of the instance, using the metrics listed in ``metric_names``.

    In addition to the similarity metrics, a ``compilation_success`` metric is reported for every
    completion. If the annotation reports an error, ``compilation_success`` is 0 and every
    similarity metric is set to 0 (the most dissimilar value). A metric that raises while being
    computed is also recorded as 0.

    Metric values, and some intermediate values (block EMD numerators/denominators), are cached
    in the metric service cache. Image-based cache keys use perceptual (average) hashes of the
    images, so visually near-identical images may share a cache entry.
    """

    # Metric names
    COMPILE_METRIC: str = "compilation_success"
    BLOCK_EARTH_MOVER_SIMILARITY_NORM1: str = "block_emd_similarity_white"
    BLOCK_EARTH_MOVER_SIMILARITY_NORM2: str = "block_emd_similarity_median_color"
    BLOCK_EARTH_MOVER_SIMILARITY: str = "block_emd_similarity"
    PIXEL_SIMILARITY: str = "pixel_similarity"
    SIFT_SIMILARITY: str = "sift_similarity"
    LPIPS_SIMILARITY: str = "lpips_similarity"
    SSIM_SIMILARITY: str = "ssim_similarity"
    FID_SIMILARITY: str = "fid_similarity"
    EDIT_SIMILARITY: str = "edit_similarity"
    # The (simplified) FID is mapped to a similarity as exp(-fid * NORMALIZE_FID_FACTOR).
    NORMALIZE_FID_FACTOR: float = 0.0025

    # Accepted values of `size_handling_method` (see `_prepare_inputs`).
    SIZE_HANDLING_METHODS: List[str] = ["resize", "padding", "none"]

    # Hashing (for caching)
    HASH_LENGTH: int = 16
    HASH_FUNC: Callable = imagehash.average_hash

    def __init__(self, generation_type: str, metric_names: List[str], size_handling_method: str = "resize"):
        """
        Args:
            generation_type: Kind of generated content. It names both the compiler annotation
                (``f"{generation_type}_compiler"``) and the metrics cache.
            metric_names: Metrics to compute; each must be a key of ``self.metrics``.
            size_handling_method: How to reconcile generated/reference images of different
                sizes; one of ``SIZE_HANDLING_METHODS``.
        """
        self.generation_type = generation_type
        self._metric_names: List[str] = metric_names
        # Heavy models are loaded lazily on first use.
        self._lpips_metric: Optional[LearnedPerceptualImagePatchSimilarity] = None
        self._inception_model: Optional[models.Inception3] = None
        self._device = get_torch_device()
        # Set on the first call to `evaluate_generation`.
        self._cache: Optional[Cache] = None
        self._size_handling_method: str = size_handling_method
        self._tokenizer = TreebankWordTokenizer()

        metrics: List[AnnotatedMetric] = [
            AnnotatedMetric(self.PIXEL_SIMILARITY, pixel_similarity, "image_np_gray"),
            AnnotatedMetric(self.SIFT_SIMILARITY, sift_similarity, "image_np"),
            # Raw block EMD
            AnnotatedMetric(self.BLOCK_EARTH_MOVER_SIMILARITY, self.compute_block_emd_raw, "image_PIL"),
            # Normalized block EMD against white
            AnnotatedMetric(self.BLOCK_EARTH_MOVER_SIMILARITY_NORM1, self.compute_block_emd_white, "image_PIL"),
            # Normalized block EMD against median
            AnnotatedMetric(self.BLOCK_EARTH_MOVER_SIMILARITY_NORM2, self.compute_block_emd_median, "image_PIL"),
            AnnotatedMetric(self.LPIPS_SIMILARITY, self.lpips_similarity, "image_PIL"),
            AnnotatedMetric(self.FID_SIMILARITY, self.fid_similarity, "image_PIL"),
            AnnotatedMetric(self.SSIM_SIMILARITY, self.compute_ssim, "image_np_gray"),
            AnnotatedMetric(self.EDIT_SIMILARITY, self.compute_edit_sim, "text_str"),
        ]
        self.metrics: Dict[str, AnnotatedMetric] = {metric.name: metric for metric in metrics}

    def _get_compilation_cache_key(self, completion: str) -> Dict[str, str]:
        """Return a cache key identifying the compilation of `completion` for this generation type."""
        return {
            "generation_type": self.generation_type,
            "completion": completion,
        }

    def _image_hash(self, image: Image.Image) -> str:
        """Return the perceptual hash of `image` used in cache keys."""
        # Accessed through the class (not `self`) so the plain function is called, not a bound method.
        return str(AnnotatedImageMetrics.HASH_FUNC(image, hash_size=self.HASH_LENGTH))

    def _prepare_inputs(
        self,
        inputs_required: Set[str],
        request_state: RequestState,
        annotation: Dict[str, Any],
        ref_image: Optional[Image.Image],
    ) -> Dict[str, Tuple[Any, Any]]:
        """Build the (generated, reference) input pairs needed by the selected metrics.

        Args:
            inputs_required: Input types needed by the metrics (e.g. "image_PIL", "text_str").
            request_state: Request state; its first reference provides the reference text.
            annotation: Compiler annotation for one completion. It must contain "media_object"
                when an image input is required and "text" when a text input is required.
            ref_image: Reference image; required when any image input is requested.

        Returns:
            Mapping from input type to a (generated, reference) tuple.

        Raises:
            ValueError: If image sizes cannot be reconciled, the size handling method is unknown,
                or a required input could not be built.
            AssertionError: If an unsupported input type is requested.
        """
        inputs: Dict[str, Tuple[Any, Any]] = {}

        # Image
        if any(input_type.startswith("image") for input_type in inputs_required):
            # Get the image and make sure we have a reference image
            assert ref_image is not None
            assert "media_object" in annotation
            assert isinstance(annotation["media_object"], MediaObject)
            media_object: MediaObject = annotation["media_object"]
            assert media_object.type == "image"
            assert media_object.is_local_file and media_object.location is not None
            # The context manager releases the file handle; convert() returns a new, loaded image.
            with Image.open(media_object.location) as compiled_image:
                image: Image.Image = compiled_image.convert("RGB")

            # Handle difference in size
            if image.size != ref_image.size:
                if self._size_handling_method == "none":
                    raise ValueError(
                        "Compiled image and reference image should have the same size"
                        " when the size handling method is none."
                    )
                elif self._size_handling_method == "resize":
                    image = image.resize(ref_image.size)
                elif self._size_handling_method == "padding":
                    # On each axis, pad whichever image is smaller (with white) to the other's size.
                    for axis in range(2):
                        if image.size[axis] < ref_image.size[axis]:
                            image = pad(image, ref_image, axis)
                        elif image.size[axis] > ref_image.size[axis]:
                            ref_image = pad(ref_image, image, axis)
                else:
                    raise ValueError(f"size handling method {self._size_handling_method} not recognized.")
            assert image.size == ref_image.size

            # Save the inputs
            inputs["image_PIL"] = (image, ref_image)

            # Convert to numpy array
            if "image_np" in inputs_required:
                inputs["image_np"] = (np.array(image), np.array(ref_image))
            if "image_np_gray" in inputs_required:
                inputs["image_np_gray"] = (preprocess_image(image), preprocess_image(ref_image))

        # Text
        if any(input_type.startswith("text") for input_type in inputs_required):
            assert "text" in annotation
            text: str = annotation["text"]
            reference = request_state.instance.references[0]
            inputs["text_str"] = (text, reference.output.text)

        # Check that all inputs are present
        supported_inputs: List[str] = ["image_PIL", "image_np", "image_np_gray", "text_str"]
        for input_type in inputs_required:
            if input_type not in supported_inputs:
                raise AssertionError(f"Input type {input_type} is not supported.")
            if input_type not in inputs:
                raise ValueError(f"Input type {input_type} is required for the metrics but not present.")

        return inputs

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        """Compute the selected metrics plus `compilation_success` for every completion.

        Raises:
            ValueError: If annotations/results are missing or the compiler annotation is absent.
            Exception: If the reference image is not a local file.
        """
        compiler_name: str = f"{self.generation_type}_compiler"
        if self._cache is None:
            self._cache = metric_service.get_cache(f"image_metrics_{self.generation_type}")

        stats_dict: Dict[str, Stat] = {
            name: Stat(MetricName(name)) for name in (self._metric_names + [self.COMPILE_METRIC])
        }

        if request_state.annotations is None or request_state.result is None:
            # A single message string; passing two arguments made the error display as a tuple.
            raise ValueError(
                "Annotations and results should be present."
                " Please make sure to add a compiler annotator to the run spec."
            )
        if compiler_name not in request_state.annotations:
            raise ValueError(f"Compiler {compiler_name} should be present in the annotations.")

        inputs_required: Set[str] = {self.metrics[metric_name].input_type for metric_name in self._metric_names}

        # Get the image reference (only once as opening an image is slow)
        # The text annotation can be loaded several times without performance issues
        reference = request_state.instance.references[0]
        ref_image: Optional[Image.Image] = None
        if any(input_type.startswith("image") for input_type in inputs_required):
            assert reference.output.multimedia_content is not None
            assert len(reference.output.multimedia_content.media_objects) > 0
            ref_media_object: MediaObject = reference.output.multimedia_content.media_objects[0]
            assert ref_media_object.type == "image"
            if ref_media_object.is_local_file and ref_media_object.location is not None:
                ref_image = open_image(ref_media_object.location)
            else:
                raise Exception(
                    "Remote images are not supported in metrics. "
                    "Images should be downloaded when constructing the instance."
                )

        # For each completion, evaluate the metrics
        assert request_state.result is not None
        for completion_index in range(len(request_state.result.completions)):
            annotation: Dict[str, Any] = request_state.annotations[compiler_name][completion_index]

            # Handle errors in annotation
            if "unknown_error" in annotation:
                hlog(
                    f"Unknown error in annotation: {annotation['unknown_error']}\n"
                    "Scores of zero will be returned for all metrics."
                )
            if "error" in annotation or "unknown_error" in annotation:
                stats_dict[self.COMPILE_METRIC].add(0)  # Did not compile
                # For all other metrics, we set the value to zero
                for metric_name in self._metric_names:
                    stats_dict[metric_name].add(0)
                continue

            # Get the inputs
            inputs = self._prepare_inputs(inputs_required, request_state, annotation, ref_image)

            # Hash the images for the cache key
            hash_dict: Optional[Dict[str, str]] = None
            if "image_PIL" in inputs:
                (image, _) = inputs["image_PIL"]
                hash_dict = {
                    "reference_image": self._image_hash(ref_image),
                    "generated_image": self._image_hash(image),
                }

            # Evaluate the metrics
            for metric_name in self._metric_names:
                metric: AnnotatedMetric = self.metrics[metric_name]
                (pred, gt) = inputs[metric.input_type]

                value: float
                try:
                    # Bind the loop variables explicitly so the callback cannot see later iterations.
                    def compute_value(fn: Callable = metric.function, pred: Any = pred, gt: Any = gt):
                        return {"value": fn(pred, gt)}

                    if isinstance(pred, str):
                        cache_key: Dict[str, Any] = {"metric_name": metric_name, "pred": pred, "gt": gt}
                    else:
                        # Image inputs are keyed by their perceptual hashes rather than their pixels.
                        assert hash_dict is not None
                        cache_key = {"metric_name": metric_name, **hash_dict}
                    assert self._cache is not None
                    response_metric, _ = self._cache.get(cache_key, compute_value)
                    value = response_metric["value"]
                except Exception as e:
                    # Best effort: a failing metric scores 0 instead of aborting the whole run.
                    hlog(f"Error in metric {metric_name}: {str(e)}")
                    value = 0
                stats_dict[metric_name].add(value)

            stats_dict[self.COMPILE_METRIC].add(1)  # Compiled

        return list(stats_dict.values())

    def lpips_similarity(self, generated_image: Image.Image, reference_image: Image.Image) -> float:
        """Compute an LPIPS-based similarity (``1 - LPIPS distance``) between the two images.

        LPIPS itself is a distance (0 for identical images). It is converted so that higher means
        more similar, consistent with the other metrics and with the 0 assigned on failures.
        NOTE: values cached under "lpips_similarity" by earlier versions hold raw distances.

        This metric is defined here as it requires loading the LPIPS model.
        Storing the model in this class is easier than passing it as an argument.
        """
        if self._lpips_metric is None:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                self._lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type="vgg").to(self._device)

        # Normalize to [-1, 1], the input range LPIPS expects by default
        preprocessing = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ]
        )
        generated_image_tensor = preprocessing(generated_image)
        reference_image_tensor = preprocessing(reference_image)

        # Add batch dimension (B, C, H, W) since torchmetrics expects batches
        img1 = generated_image_tensor.unsqueeze(0).to(self._device)
        img2 = reference_image_tensor.unsqueeze(0).to(self._device)

        # Compute the LPIPS distance (no gradients needed)
        assert self._lpips_metric is not None
        with torch.no_grad():
            distance: float = self._lpips_metric(img1, img2).detach().item()
        return 1.0 - distance

    def _calculate_fid(self, act1, act2):
        """Return the squared Euclidean distance between the first rows of `act1` and `act2`.

        This is NOT a standard FID. With a single observation per image no covariance can be
        estimated, so only the mean term of the FID formula is kept.
        """
        # Presumably act1/act2 have shape (1, num_features) (one image each) — confirm at call sites.
        mu1, mu2 = act1[0], act2[0]

        # Compute the square difference between the means
        ssdiff = np.sum((mu1 - mu2) ** 2.0)

        # Placeholder for FID score since we're not using covariance matrices
        fid = ssdiff  # This is not a standard FID calculation.

        return fid

    def _get_inception_features(self, img_tensor):
        """Run the lazily loaded Inception v3 model on `img_tensor` and return its output as numpy.

        The model is loaded on first use with ImageNet weights. If the torch cache directory is not
        writable, TORCH_HOME is pointed (process-wide) at a fresh temporary directory and loading
        is retried.
        """
        if self._inception_model is None:

            def load_inception_model():
                return models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, transform_input=False).to(
                    self._device
                )

            try:
                self._inception_model = load_inception_model()
            except PermissionError:
                # If access denied, use a temporary directory
                hlog("Access denied to torch cache directory. Using a temporary directory.")
                temp_cache_dir = tempfile.mkdtemp()
                os.environ["TORCH_HOME"] = temp_cache_dir
                self._inception_model = load_inception_model()
            self._inception_model.eval()
        with torch.no_grad():
            if self._inception_model.training:
                self._inception_model.eval()
            pred = self._inception_model(img_tensor)
            return pred.cpu().detach().numpy()

    def _preprocess_image(self, image):
        """Resize/center-crop to 299x299 and normalize with ImageNet statistics for Inception v3."""
        # Source: https://pytorch.org/hub/pytorch_vision_inception_v3/
        preprocess = transforms.Compose(
            [
                transforms.Resize(299),
                transforms.CenterCrop(299),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )
        return preprocess(image)

    def fid_similarity(self, generated_image: Image.Image, reference_image: Image.Image) -> float:
        """Compute a similarity derived from a simplified Frechet Inception Distance (FID).

        The distance from `_calculate_fid` (mean term only) is mapped to (0, 1] as
        ``exp(-fid * NORMALIZE_FID_FACTOR)``.

        This metric is defined here as it requires loading the Inception model.
        Storing the model in this class is easier than passing it as an argument.
        """
        img1_tensor = self._preprocess_image(generated_image).unsqueeze(0).to(self._device)
        img2_tensor = self._preprocess_image(reference_image).unsqueeze(0).to(self._device)

        features1 = self._get_inception_features(img1_tensor)
        features2 = self._get_inception_features(img2_tensor)

        fid_score = self._calculate_fid(features1, features2)
        normalize_fid: float = float(np.exp(-fid_score * self.NORMALIZE_FID_FACTOR))
        return normalize_fid

    def compute_ssim(self, generated_image: np.ndarray, reference_image: np.ndarray) -> float:
        """Compute the Structural Similarity Index (SSIM) between the generated and reference images."""
        return ssim(generated_image, reference_image)

    def compute_edit_sim(self, completion: str, reference: str) -> float:
        """Token-level edit similarity between the completion and the (truncated) reference text.

        The reference is truncated to the completion's length in characters. Both strings are then
        tokenized with the Treebank tokenizer and compared with `_edit_similarity`.
        """
        # NOTE(review): this truncation was copied from the copyright metric, where the reference
        # is the rest of a book. Here the reference is the first reference's full text, so a
        # completion that is only a prefix of it can still score highly — confirm this is intended.
        truncated_reference: str = reference[: len(completion)]

        completion_tokens = self._tokenizer.tokenize(completion)
        truncated_reference_tokens = self._tokenizer.tokenize(truncated_reference)

        # Exploit numpy SIMD for efficiency on CPUs.
        completion_tokens = np.array(completion_tokens)
        truncated_reference_tokens = np.array(truncated_reference_tokens)

        result = _edit_similarity(completion_tokens, truncated_reference_tokens)
        return result

    def _block_emd_cache_key(self, pred_image: Image.Image, ref_image: Image.Image) -> Dict[str, str]:
        """Cache key of the raw block EMD between `pred_image` and `ref_image`."""
        # The raw EMD depends on BOTH images, so both hashes must be part of the key.
        return {
            "metric_name": f"intermediate_{self.BLOCK_EARTH_MOVER_SIMILARITY}",
            "reference_image": self._image_hash(ref_image),
            "generated_image": self._image_hash(pred_image),
        }

    def _normalized_block_emd(
        self,
        pred_image: Image.Image,
        ref_image: Image.Image,
        baseline_metric_name: str,
        make_baseline: Callable[[], Image.Image],
        threshold_most_frequent_color: float,
        patch_size: Tuple[int, int],
        max_num_patches: int,
        weight_most_frequent_color: float,
        use_tqdm: bool,
    ) -> float:
        """Return ``1 - EMD(pred, ref) / EMD(baseline, ref)``, caching both EMDs.

        `make_baseline` builds the constant baseline image; it is only called when the denominator
        is not cached yet. If the baseline EMD is 0 (the reference equals the baseline), the result
        is 1.0 when the prediction's raw EMD is also 0 and 0.0 otherwise.
        """
        emd_args = (threshold_most_frequent_color, patch_size, max_num_patches, weight_most_frequent_color, use_tqdm)

        def compute_numerator():
            return self.compute_block_emd_raw_wrapper(pred_image, ref_image, *emd_args)

        def compute_denominator():
            value = compute_emd_recursive(make_baseline(), ref_image, *emd_args)
            return {"value": value}

        # The baseline EMD depends only on the reference image.
        cache_key_denominator = {
            "metric_name": f"intermediate_{baseline_metric_name}",
            "reference_image": self._image_hash(ref_image),
        }

        assert self._cache is not None
        emd_raw, _ = self._cache.get(self._block_emd_cache_key(pred_image, ref_image), compute_numerator)
        emd_base, _ = self._cache.get(cache_key_denominator, compute_denominator)

        raw_value, base_value = emd_raw["value"], emd_base["value"]
        if base_value == 0:
            return 1.0 if raw_value == 0 else 0.0
        return 1.0 - raw_value / base_value

    def compute_block_emd_white(
        self,
        pred_image: Image.Image,
        ref_image: Image.Image,
        threshold_most_frequent_color: float = 0.5,
        patch_size: Tuple[int, int] = (8, 8),
        max_num_patches: int = 100,
        weight_most_frequent_color: float = 0.001,
        use_tqdm: bool = False,
    ):
        """Block Earth Mover's Distance (EMD) similarity, normalized against a white image.

        Block EMD speeds up EMD on large images by considering movement/transformation of blocks
        of pixels. The score is ``1 - EMD(pred, ref) / EMD(white, ref)``.
        """
        return self._normalized_block_emd(
            pred_image,
            ref_image,
            self.BLOCK_EARTH_MOVER_SIMILARITY_NORM1,
            lambda: Image.new("RGB", ref_image.size, (255, 255, 255)),
            threshold_most_frequent_color,
            patch_size,
            max_num_patches,
            weight_most_frequent_color,
            use_tqdm,
        )

    def compute_block_emd_median(
        self,
        pred_image: Image.Image,
        ref_image: Image.Image,
        threshold_most_frequent_color: float = 0.5,
        patch_size: Tuple[int, int] = (8, 8),
        max_num_patches: int = 100,
        weight_most_frequent_color: float = 0.001,
        use_tqdm: bool = False,
    ):
        """Same as `compute_block_emd_white`, EXCEPT that the baseline is a constant image.

        The baseline is filled with the reference's most frequent color (as returned by
        `get_most_frequent_color`) instead of white.
        """

        def make_baseline() -> Image.Image:
            (rgb_most_frequent_color, _) = get_most_frequent_color(np.array(ref_image))
            # Most frequent color as base
            return Image.new("RGB", ref_image.size, tuple(rgb_most_frequent_color))  # type: ignore

        return self._normalized_block_emd(
            pred_image,
            ref_image,
            self.BLOCK_EARTH_MOVER_SIMILARITY_NORM2,
            make_baseline,
            threshold_most_frequent_color,
            patch_size,
            max_num_patches,
            weight_most_frequent_color,
            use_tqdm,
        )

    def compute_block_emd_raw(
        self,
        pred_image: Image.Image,
        ref_image: Image.Image,
        threshold_most_frequent_color: float = 0.5,
        patch_size: Tuple[int, int] = (8, 8),
        max_num_patches: int = 100,
        weight_most_frequent_color: float = 0.001,
        use_tqdm: bool = False,
    ):
        """Return the raw (unnormalized) block EMD between the two images, using the cache.

        NOTE(review): although registered as "block_emd_similarity", this is a distance (lower is
        better), while failures are scored 0 — confirm how it is reported downstream.
        """

        def compute():
            return self.compute_block_emd_raw_wrapper(
                pred_image,
                ref_image,
                threshold_most_frequent_color,
                patch_size,
                max_num_patches,
                weight_most_frequent_color,
                use_tqdm,
            )

        assert self._cache is not None
        emd_raw, _ = self._cache.get(self._block_emd_cache_key(pred_image, ref_image), compute)

        return emd_raw["value"]

    def compute_block_emd_raw_wrapper(
        self,
        pred_image: Image.Image,
        ref_image: Image.Image,
        threshold_most_frequent_color: float = 0.5,
        patch_size: Tuple[int, int] = (8, 8),
        max_num_patches: int = 100,
        weight_most_frequent_color: float = 0.001,
        use_tqdm: bool = False,
    ):
        """Compute the raw block Earth Mover's Distance (EMD) between the two images.

        Block EMD speeds up EMD on large images by considering movement/transformation of blocks
        of pixels. No normalization is applied. The result is wrapped as ``{"value": emd}`` so
        that it can be stored in the cache directly.
        """
        emd_value = compute_emd_recursive(
            pred_image,
            ref_image,
            threshold_most_frequent_color,
            patch_size,
            max_num_patches,
            weight_most_frequent_color,
            use_tqdm,
        )
        return {"value": emd_value}
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
import cv2
|
|
7
|
+
from PIL.Image import Image
|
|
8
|
+
except ModuleNotFoundError as e:
|
|
9
|
+
handle_module_not_found_error(e, suggestions=["image2structure"])
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def preprocess_image(image: Image) -> np.ndarray:
    """Convert a PIL image into a grayscale (mode "L") numpy array of dtype uint8.

    No other processing (e.g. exposure normalization) is applied.
    """
    grayscale = np.array(image.convert("L"))
    assert grayscale.dtype == np.uint8
    return grayscale
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def pixel_similarity(img_a: np.ndarray, img_b: np.ndarray, threshold: float = 0.5, tolerance: float = 0.02) -> float:
    """
    Measure the pixel-level similarity between two images.

    A color is "dominant" when its total count over both images exceeds ``threshold`` times the
    average number of pixels per image. Pixels with a dominant color (in either image) never
    count as matches. A pixel matches when it is not dominant and its color distance
    ``||a - b|| / 255`` is at most ``tolerance``.

    Accepts 2-D grayscale images of shape (H, W) as well as images with a trailing channel axis
    (H, W, C).

    Args:
        img_a (np.ndarray): the first image
        img_b (np.ndarray): the second image
        threshold (float): Threshold to ignore dominant colors.
        tolerance (float): Tolerance for color variation (0 means exact match).
    Returns:
        float: the fraction of matching pixels over all pixels (between 0 and 1); 0 for empty images
    Raises:
        ValueError: if the images do not have the same shape
    """
    if img_a.shape != img_b.shape:
        raise ValueError(
            f"Images must have the same dimensions. img_a.shape = {img_a.shape}, img_b.shape = {img_b.shape}"
        )

    # Flatten to one row per pixel. A 2-D (grayscale) image has one channel per pixel; reshaping it
    # with shape[-1] would instead treat every image row as a single "pixel".
    num_channels = img_a.shape[-1] if img_a.ndim >= 3 else 1
    img_a_flat = img_a.reshape(-1, num_channels)
    img_b_flat = img_b.reshape(-1, num_channels)

    # Calculate color differences with tolerance. Cast first: subtracting unsigned integer arrays
    # (e.g. uint8) wraps around instead of producing negative differences.
    diff = img_a_flat.astype(np.float64) - img_b_flat.astype(np.float64)
    color_diff = np.linalg.norm(diff, axis=1) / 255
    within_tolerance = color_diff <= tolerance

    # Calculate frequencies of all colors
    _, indices = np.unique(np.concatenate((img_a_flat, img_b_flat), axis=0), axis=0, return_inverse=True)
    # Some NumPy versions return the inverse with an extra dimension when `axis` is given.
    indices = indices.reshape(-1)
    color_counts = np.bincount(indices)

    # Identify colors to ignore based on frequency threshold
    ignore_colors_mask = color_counts > (len(img_a_flat) + len(img_b_flat)) * threshold / 2
    ignore_in_a = ignore_colors_mask[indices[: len(img_a_flat)]]
    ignore_in_b = ignore_colors_mask[indices[len(img_a_flat) :]]

    # Apply ignore mask
    valid_pixels = np.logical_not(np.logical_or(ignore_in_a, ignore_in_b)) & within_tolerance

    # Calculate similarity. NOTE(review): dominant-color pixels stay in the denominator, so they
    # count as mismatches rather than being excluded entirely — confirm this is intended.
    similarity = np.mean(valid_pixels) if len(valid_pixels) > 0 else 0

    return float(similarity)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def sift_similarity(img_a: np.ndarray, img_b: np.ndarray) -> float:
    """
    Use ORB features to measure image similarity between two numpy arrays representing images.

    Despite the function name, ORB (not SIFT) keypoints/descriptors are used. Each descriptor of
    ``img_a`` is matched to its best descriptor in ``img_b`` by Hamming distance. The score is the
    fraction of matches whose distance is below 70.

    Args:
        img_a (np.ndarray): the first image (expected to have a trailing channel axis)
        img_b (np.ndarray): the second image (expected to have a trailing channel axis)
    Returns:
        float: the ORB similarity between the images (between 0 and 1); 0 when either image has
        no detectable keypoints or no matches are found
    Raises:
        ValueError: if either image has fewer than 3 dimensions
    """
    if len(img_a.shape) < 3 or len(img_b.shape) < 3:
        raise ValueError("Both images must have 3 channels")

    # Initialize the ORB feature detector
    orb = cv2.ORB_create() if hasattr(cv2, "ORB_create") else cv2.ORB()

    # Find the keypoints and descriptors with ORB
    _, desc_a = orb.detectAndCompute(img_a, None)
    _, desc_b = orb.detectAndCompute(img_b, None)

    # detectAndCompute yields None descriptors when no keypoints are found (e.g. a blank image);
    # BFMatcher.match raises on None, so report no similarity explicitly.
    if desc_a is None or desc_b is None:
        return 0.0

    # Brute-force matching of the binary ORB descriptors
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
    matches = bf.match(desc_a, desc_b)
    if len(matches) == 0:
        return 0.0

    # Calculate similarity based on the distance of the matches
    similar_regions = [match for match in matches if match.distance < 70]
    return len(similar_regions) / len(matches)
|