crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +8 -2
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +36 -0
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +214 -16
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +14 -16
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +203 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +12 -72
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1069 -546
- helm/config/model_metadata.yaml +753 -31
- helm/config/tokenizer_configs.yaml +142 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
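A large part of this release is a package re-layout: modules under `helm/proxy/clients/` now live in `helm/clients/`, and `helm/proxy/tokenizers/` moved to `helm/tokenizers/`. A minimal compatibility sketch for downstream imports (the try/except shim is illustrative only, not something the package ships; `together_client` is just one of the moved modules listed above):

```python
# Hypothetical downstream shim for the 0.4.0 -> 0.5.0 module moves listed above.
try:
    from helm.clients import together_client  # new layout (crfm-helm 0.5.0)
except ModuleNotFoundError:
    from helm.proxy.clients import together_client  # old layout (crfm-helm 0.4.0)
```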
@@ -7,7 +7,7 @@ from helm.benchmark.adaptation.request_state import RequestState
 from helm.benchmark.metrics.classification_metrics import ClassificationMetric
 from helm.benchmark.metrics.statistic import Stat
 from helm.benchmark.scenarios.scenario import Input, Instance, Output, Reference, CORRECT_TAG
-from helm.common.request import Request, RequestResult,
+from helm.common.request import Request, RequestResult, GeneratedOutput


 class _Option(NamedTuple):
@@ -28,7 +28,10 @@ def _request_state(prediction: str, options: List[_Option]):
         output_mapping=None,
         request=Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002"),
         result=RequestResult(
-            success=True,
+            success=True,
+            embedding=[],
+            completions=[GeneratedOutput(text=prediction, logprob=0.0, tokens=[])],
+            cached=False,
         ),
         num_train_instances=0,
         prompt_truncated=False,
@@ -76,7 +79,7 @@ def test_evaluate_instances_binary_generation():
     ]

     assert_stats_equal(
-        metric.evaluate_instances(request_states),
+        metric.evaluate_instances(request_states, ""),
         _expected_stats(
             {
                 "yes": {"tp": 3, "fp": 1, "tn": 2, "fn": 1},
@@ -106,7 +109,7 @@ def test_evaluate_instances_multi_class():
         _request_state("invalid", _options("c")),
     ]
     assert_stats_equal(
-        metric.evaluate_instances(request_states),
+        metric.evaluate_instances(request_states, ""),
         _expected_stats(
             {
                 "d": {"tp": 3, "fp": 1, "tn": 5, "fn": 1},
@@ -139,7 +142,7 @@ def test_evaluate_instances_multilabel():
     ]

     assert_stats_equal(
-        metric.evaluate_instances(request_states),
+        metric.evaluate_instances(request_states, ""),
         _expected_stats(
             {
                 "d": {"tp": 5, "fp": 1, "tn": 5, "fn": 0},
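Two API changes drive most of the edits in these hunks: completions are now `GeneratedOutput` objects imported from `helm.common.request`, and `RequestResult` is constructed with explicit `embedding`, `completions`, and `cached` fields. A standalone sketch of the same construction, with values copied from the test above (illustrative only; other optional `RequestResult` fields are not shown):

```python
from helm.common.request import Request, RequestResult, GeneratedOutput

prediction = "yes"  # hypothetical model output used to build the fake result
request = Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002")
result = RequestResult(
    success=True,
    embedding=[],
    completions=[GeneratedOutput(text=prediction, logprob=0.0, tokens=[])],
    cached=False,
)
```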
@@ -0,0 +1,78 @@
+# Test metrics
+from typing import List
+
+import numpy as np
+import pytest
+from helm.benchmark.metrics.disinformation_metrics import _monte_carlo_entropy, _self_bleu
+from helm.common.request import GeneratedOutput, Token
+
+# Test tokens
+_TEST_1_TOKENS: List[Token] = [
+    Token("This", logprob=-0.25),
+    Token("is", logprob=-0.25),
+    Token("a", logprob=-0.25),
+    Token("test", logprob=-0.25),
+]
+_TEST_2_TOKENS: List[Token] = [
+    Token("This", logprob=-0.25),
+    Token("is", logprob=-0.25),
+    Token("another", logprob=-0.5),
+    Token("test", logprob=-0.25),
+]
+_TEST_EMPTY_TOKENS: List[Token] = []
+test_empty_str_tokens: List[Token] = [
+    Token("", logprob=0),
+]
+
+# Test Sequences (two standard, one with an empty token, and one with no tokens)
+_TEST_1 = GeneratedOutput(text="This is a test", logprob=-1, tokens=_TEST_1_TOKENS)
+_TEST_2 = GeneratedOutput(text="This is another test", logprob=-1.25, tokens=_TEST_2_TOKENS)
+_TEST_EMPTY = GeneratedOutput(text="", logprob=-float("nan"), tokens=_TEST_EMPTY_TOKENS)
+_TEST_EMPTY_STR = GeneratedOutput(text="", logprob=0, tokens=test_empty_str_tokens)
+
+
+# Test Self-BLEU
+def test_self_bleu_with_self():
+    score = _self_bleu([_TEST_1, _TEST_1])
+    assert score == pytest.approx(100)
+
+
+def test_self_blue_with_other():
+    score = _self_bleu([_TEST_1, _TEST_2])
+    assert 0 < score < 100
+
+
+def test_self_blue_one_sequence():
+    score = _self_bleu([_TEST_1])
+    assert score == 0
+
+
+def test_self_blue_one_full_one_empty():
+    score = _self_bleu([_TEST_1, _TEST_EMPTY_STR])
+    assert score == 0
+
+
+# Test MC Entropy
+def test_mc_entropy_with_self():
+    score = _monte_carlo_entropy([_TEST_1, _TEST_1])
+    assert score == pytest.approx(-_TEST_1.logprob)
+
+
+def test_mc_entropy_with_other():
+    score = _monte_carlo_entropy([_TEST_1, _TEST_2])
+    assert score == pytest.approx(-(_TEST_1.logprob + _TEST_2.logprob) / 2)
+
+
+def test_mc_entropy_one_sequence():
+    score = _monte_carlo_entropy([_TEST_1])
+    assert score == -_TEST_1.logprob
+
+
+def test_mc_entropy_one_full_one_empty():
+    score = _monte_carlo_entropy([_TEST_EMPTY_STR])
+    assert score == _TEST_EMPTY_STR.logprob
+
+
+def test_mc_entropy_with_no_tokens():
+    score = _monte_carlo_entropy([_TEST_EMPTY])
+    assert np.isnan(score)
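A quick sanity check of the expected values above: `_TEST_1` has four tokens at logprob −0.25 each, so its sequence logprob is −1, and `_TEST_2` has sequence logprob −1.25; the value asserted in `test_mc_entropy_with_other` is therefore the mean negative sequence logprob:

```python
# Expected value in test_mc_entropy_with_other, computed by hand:
expected = -((-1.0) + (-1.25)) / 2  # = 1.125
```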
@@ -3,13 +3,13 @@ from .basic_metrics import get_num_bytes, convert_tokens_to_text


 def test_get_num_bytes():
-    tokens = [Token(text, 0
+    tokens = [Token(text, 0) for text in ["bytes:\\x99", "Hello", " world", "bytes:\\xe2\\x80"]]
     assert get_num_bytes(tokens) == 14


 def test_convert_tokens_to_text():
     tokens = [
-        Token(text, 0
+        Token(text, 0)
         for text in [
             "<|endoftext|>",
             "bytes:\\xe2\\x80",
@@ -2,11 +2,19 @@ from helm.benchmark.metrics.metric_service import MetricService
 from helm.benchmark.window_services.window_service import WindowService
 from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
 from helm.common.request import Request
-from helm.proxy.token_counters.gooseai_token_counter import GooseAITokenCounter
 from .token_cost_estimator import TokenCostEstimator


 class GooseAITokenCostEstimator(TokenCostEstimator):
+    # From https://goose.ai/pricing: "the base price includes your first 25 tokens
+    # generated, and you can scale beyond that on a per-token basis."
+    BASE_PRICE_TOKENS: int = 25
+
+    @staticmethod
+    def account_for_base_tokens(num_tokens: int):
+        """Subtracts the number of tokens included in the base price."""
+        return max(num_tokens - GooseAITokenCostEstimator.BASE_PRICE_TOKENS, 0)
+
     def estimate_tokens(self, request: Request, metric_service: MetricService) -> int:
         """
         Estimate the number of generated tokens for a given request. Formula:
@@ -21,4 +29,4 @@ class GooseAITokenCostEstimator(TokenCostEstimator):
             request.model_deployment, metric_service
         )
         total_estimated_tokens += window_service.get_num_tokens(request.prompt)
-        return
+        return GooseAITokenCostEstimator.account_for_base_tokens(total_estimated_tokens)
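The new `account_for_base_tokens` helper simply subtracts the 25 tokens covered by GooseAI's base price and clamps at zero. A small worked example (the import path follows the file location `helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py` listed above):

```python
from helm.benchmark.metrics.tokens.gooseai_token_cost_estimator import GooseAITokenCostEstimator

assert GooseAITokenCostEstimator.account_for_base_tokens(30) == 5  # 30 - 25
assert GooseAITokenCostEstimator.account_for_base_tokens(10) == 0  # clamped at zero
```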
@@ -5,7 +5,7 @@ from helm.common.request import RequestResult
 from helm.common.hierarchical_logger import hlog
 from helm.benchmark.adaptation.request_state import RequestState
 from helm.benchmark.adaptation.adapter_spec import AdapterSpec
-from helm.
+from helm.clients.perspective_api_client import PerspectiveAPIClientCredentialsError
 from .metric import Metric
 from .metric_name import MetricName
 from .metric_service import MetricService
@@ -0,0 +1,23 @@
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.common.perspective_api_request import PerspectiveAPIRequestResult, PerspectiveAPIRequest, ToxicityAttributes
+
+
+def compute_toxicity_score(text: str, metric_service: MetricService) -> float:
+    """
+    Compute the toxicity score of a given text using Perspective API.
+    Returns a value between 0 and 1 where a score 0.5 or greater is deemed toxic.
+    """
+    response: PerspectiveAPIRequestResult = metric_service.get_toxicity_scores(
+        request=PerspectiveAPIRequest(text_batch=[text])
+    )
+    attributes: ToxicityAttributes = response.text_to_toxicity_attributes[text]
+    assert attributes.toxicity_score is not None
+    return attributes.toxicity_score
+
+
+def is_prompt_toxic(text: str, metric_service: MetricService) -> bool:
+    """
+    Returns True, if the prompt is considered toxic, False otherwise.
+    """
+    score: float = compute_toxicity_score(text, metric_service)
+    return score >= 0.5
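Both helpers in the new `toxicity_utils` module take a `MetricService` wired to the Perspective API; a minimal sketch of how they might be composed downstream (the `metric_service` argument and the `flag_toxic_prompts` helper are placeholders, not part of the package; the 0.5 threshold comes from `is_prompt_toxic` above):

```python
from typing import List

from helm.benchmark.metrics.toxicity_utils import is_prompt_toxic


def flag_toxic_prompts(prompts: List[str], metric_service) -> List[str]:
    # Keeps only the prompts whose Perspective API toxicity score is >= 0.5.
    return [prompt for prompt in prompts if is_prompt_toxic(prompt, metric_service)]
```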
@@ -0,0 +1,81 @@
+import re
+from typing import Dict, List
+
+from datasets import load_dataset
+import evaluate
+
+from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats
+from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+
+
+class UnitxtMetric(MetricInterface):
+    ID_PATTERN = re.compile("([a-z]+)([0-9]+)")
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        dataset_name = ",".join(f"{key}={value}" for key, value in kwargs.items())
+        self.dataset = load_dataset("unitxt/data", dataset_name, trust_remote_code=True)
+
+    def evaluate(
+        self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
+    ) -> MetricResult:
+        # Fetch references from dataset and make them parallel to predictions
+        predictions: List[str] = []
+        references: List = []
+        for request_state in scenario_state.request_states:
+            assert request_state.instance.id
+            id_match = UnitxtMetric.ID_PATTERN.match(request_state.instance.id)
+            assert id_match
+            unitxt_split_name = id_match.group(1)
+            row_index = int(id_match.group(2))
+            references.append(self.dataset[unitxt_split_name][row_index])
+            assert request_state.result
+            assert len(request_state.result.completions) == 1
+            predictions.append(request_state.result.completions[0].text)
+
+        # Compute metrics
+        evaluate_results: List[Dict] = evaluate.load("unitxt/metric").compute(
+            predictions=predictions, references=references
+        )
+
+        # Extract instance metrics
+        per_instance_stats: List[PerInstanceStats] = []
+        for request_state, evaluate_result in zip(scenario_state.request_states, evaluate_results):
+            instance = request_state.instance
+            instance_stats: List[Stat] = []
+            instance_results = evaluate_result["score"]["instance"]
+            for metric_name, metric_score in instance_results.items():
+                if metric_name == "score" or metric_name == "score_name":
+                    continue
+                instance_stats.append(
+                    Stat(
+                        MetricName(
+                            name=metric_name,
+                            split=instance.split,
+                            sub_split=instance.sub_split,
+                            perturbation=instance.perturbation,
+                        )
+                    ).add(metric_score)
+                )
+            assert instance.id
+            per_instance_stats.append(
+                PerInstanceStats(
+                    instance_id=instance.id,
+                    perturbation=instance.perturbation,
+                    train_trial_index=request_state.train_trial_index,
+                    stats=instance_stats,
+                )
+            )
+
+        # Extract global metrics
+        aggregated_stats: List[Stat] = []
+        if len(evaluate_results) > 0:
+            global_results = evaluate_results[-1]["score"]["global"]
+            for metric_name, metric_score in global_results.items():
+                if metric_name == "score" or metric_name == "score_name":
+                    continue
+                aggregated_stats.append(Stat(MetricName(name=metric_name)).add(metric_score))
+        return MetricResult(aggregated_stats=aggregated_stats, per_instance_stats=per_instance_stats)
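`UnitxtMetric` recovers the Unitxt split name and row index for each prediction from the HELM instance id via `ID_PATTERN`; for a hypothetical id like `"test42"`, the split is `"test"` and the row is `42`:

```python
import re

# Same pattern as UnitxtMetric.ID_PATTERN above; "test42" is a hypothetical instance id.
ID_PATTERN = re.compile("([a-z]+)([0-9]+)")
match = ID_PATTERN.match("test42")
assert match is not None
assert match.group(1) == "test" and int(match.group(2)) == 42
```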
File without changes
@@ -0,0 +1,341 @@
+from typing import List, Tuple
+from tqdm import tqdm
+
+import numpy as np
+import math
+
+from helm.common.optional_dependencies import handle_module_not_found_error
+
+try:
+    import cv2
+    from PIL import Image
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, suggestions=["images"])
+
+
+def to_gray(img: np.ndarray) -> np.ndarray:
+    return np.matmul(img, np.array([[0.299], [0.587], [0.114]]))
+
+
+def get_most_frequent_color(img: np.ndarray) -> Tuple[np.ndarray, float]:
+    """Get the most frequent color in the image and its frequency.
+
+    Args:
+        img (np.array): Input image array of shape (height, width, channels).
+
+    Returns:
+        Tuple[np.array, float]: Most frequent color and its frequency as a percentage of the total number of pixels.
+    """
+    # Assert to ensure input is a 3D numpy array
+    assert len(img.shape) == 3, "Input image must be a 3D numpy array"
+
+    # Reshape image array to 2D (pixel, RGB)
+    pixels = img.reshape(-1, img.shape[2])
+
+    # Find unique rows (colors) and their counts
+    unique_colors, counts = np.unique(pixels, axis=0, return_counts=True)
+
+    # Find the index of the most frequent color
+    most_frequent_color_index = np.argmax(counts)
+
+    # Most frequent color
+    most_frequent_color = unique_colors[most_frequent_color_index]
+
+    # Calculate frequency percentage
+    frequency = counts[most_frequent_color_index] / pixels.shape[0]
+
+    return most_frequent_color, frequency
+
+
+def img_to_sig_patches(
+    img: np.ndarray,
+    rgb_most_frequent_color: np.ndarray,
+    patch_size: Tuple[int, int],
+    weight_most_frequent_color: float = 0.01,
+):
+    """
+    Convert an RGB image to a signature for cv2.EMD, processing the image in patches.
+
+    Args:
+    - img: A 3D numpy array representing an RGB image (height, width, channels).
+    - rgb_most_frequent_color: The most frequent color in the image.
+    - patch_size: Tuple indicating the height and width of the patches.
+    - weight_most_frequent_color: The weight assigned to the most frequent color in the patches.
+
+    Returns:
+    - A numpy array suitable for cv2.EMD, containing color values and coordinates of each patch.
+        The shape is (num_patches, patch_size[0] * patch_size[1] + 3).
+    """
+    assert len(img.shape) == 3, "Input image must be a 3D numpy array"
+
+    # Ensure img is a numpy array of type float32
+    img = np.array(img, dtype=np.float32)
+
+    # Determine padding needs
+    pad_height = (-img.shape[0]) % patch_size[0]
+    pad_width = (-img.shape[1]) % patch_size[1]
+
+    # Adjust padding for RGB channels
+    padding = ((0, pad_height), (0, pad_width), (0, 0))
+    pad_values = (
+        (rgb_most_frequent_color[0], rgb_most_frequent_color[0]),
+        (rgb_most_frequent_color[1], rgb_most_frequent_color[1]),
+        (rgb_most_frequent_color[2], rgb_most_frequent_color[2]),
+    )
+
+    # Find the most frequent color for padding
+    if pad_height > 0 or pad_width > 0:
+        img = np.pad(img, padding, "constant", constant_values=pad_values)
+    img /= 255.0  # Normalize colors to [0, 1]
+
+    # Collapse color dimensions to grayscale
+    img = to_gray(img)
+
+    # Reshape image into patches and flatten the color dimensions within each patch
+    patches = (
+        img.reshape(
+            (img.shape[0] // patch_size[0], patch_size[0], img.shape[1] // patch_size[1], patch_size[1], img.shape[2])
+        )
+        .transpose(0, 2, 1, 3, 4)
+        .reshape(-1, *patch_size, img.shape[2])
+    )
+
+    # Calculate patch positions
+    patch_positions = (
+        np.mgrid[0 : img.shape[0] // patch_size[0], 0 : img.shape[1] // patch_size[1]].transpose(1, 2, 0).reshape(-1, 2)
+    )
+
+    # Normalize positions
+    patch_positions = patch_positions / np.array([img.shape[0] // patch_size[0], img.shape[1] // patch_size[1]])
+
+    # Compute the weight of each patch
+    # The weight of each point is 1 if the color is not the most frequent color, weight_most_frequent_color otherwise
+    flattened_patches = patches.reshape(patches.shape[0], -1)
+    gray_most_frequent_color: float = float(to_gray(rgb_most_frequent_color).squeeze() / 255.0)
+    weight = weight_most_frequent_color + (1 - weight_most_frequent_color) * np.any(
+        flattened_patches != gray_most_frequent_color, axis=1, keepdims=True
+    ).astype(np.float32)
+    weight /= np.sum(weight)
+
+    # Flatten patches and concatenate with their normalized positions and weights
+    sig = np.hstack((weight, flattened_patches, patch_positions))
+
+    return sig.astype(np.float32)
+
+
+def pad(small_image: Image.Image, large_image: Image.Image, axis: int) -> Image.Image:
+    """Pad the axis of the small image to match the size of the large image."""
+    new_dim: List[int] = list(small_image.size)
+    new_dim[axis] = large_image.size[axis]
+    new_dim_tupe: Tuple[int, int] = tuple(new_dim)  # type: ignore
+    new_image: Image.Image = Image.new("RGB", new_dim_tupe, (255, 255, 255))
+    new_image.paste(small_image, (0, 0))
+    return new_image
+
+
+def reshape_sub_sig_batch(
+    sub_sigs: np.ndarray,
+    patch_size: Tuple[int, int],
+    gray_most_frequent_color: float,
+    weight_most_frequent_color: float = 0.01,
+) -> np.ndarray:
+    """
+    Reshape a patch-based signature of an image (Shape: (num_patches, patch_size[0] * patch_size[1] + 3))
+    to a batch of signatures for each patch (Shape: (num_patches, patch_size[0] * patch_size[1], 4)).
+    Basically goes from a signature on the patch level to a batch of signatures on the pixel level.
+
+    Args:
+    - sub_sigs: A numpy array of shape (num_patches, patch_size[0] * patch_size[1] + 1) representing the
+        sub-signatures. (the spatial info should have been stripped).
+    - patch_size: Tuple indicating the height and width of the patches.
+    - gray_most_frequent_color: The most frequent color in the image.
+        This is used to reduce the weight assigned to the most frequent color in the patches.
+    - weight_most_frequent_color: The weight assigned to the most frequent color in the patches.
+
+    Returns:
+    - A numpy array of shape (num_patches, patch_size[0] * patch_size[1], 4) representing the sub-signatures of
+        each patch (pixel-level signatures).
+    """
+    # Ensure sub_sigs has the correct shape
+    num_patches = sub_sigs.shape[0]
+    flat_patch_size = patch_size[0] * patch_size[1]
+    assert sub_sigs.shape[1] == flat_patch_size + 1, f"Expected {flat_patch_size + 1} columns, got {sub_sigs.shape[1]}."
+
+    # Ensure sub_sigs is reshaped to include an extra dimension for concatenation
+    num_channels: int = int(round(sub_sigs.shape[0] * sub_sigs.shape[1] / (num_patches * flat_patch_size)))
+    assert num_channels == 1, "Only grayscale images are supported for now."
+    sub_sigs_reshaped = sub_sigs[:, 1:].reshape(num_patches, flat_patch_size, num_channels)
+
+    # Generate spatial information
+    x = np.arange(patch_size[0]) / patch_size[0]
+    y = np.arange(patch_size[1]) / patch_size[1]
+    x, y = np.meshgrid(x, y)
+    spatial_info = np.stack((x.ravel(), y.ravel()), axis=1)  # Shape: (flat_patch_size, 2)
+
+    # Repeat spatial_info for each patch
+    spatial_info_repeated = np.repeat(
+        spatial_info[np.newaxis, :, :], num_patches, axis=0
+    )  # Shape: (num_patches, flat_patch_size, 2)
+
+    # The weight of each point is 1 if the color is not the most frequent color, weight_most_frequent_color otherwise
+    # The weight of a pixel is the product of the weight of the patch and the weight of the pixel in the patch
+    local_weights = weight_most_frequent_color + (1 - weight_most_frequent_color) * (
+        sub_sigs_reshaped != gray_most_frequent_color
+    ).astype(np.float32)
+    global_weights = sub_sigs[:, 0:1]
+    local_weights *= global_weights.reshape(-1, 1, 1)
+    local_weights /= np.sum(local_weights, axis=1, keepdims=True)
+
+    # Concatenate sub_sigs with weights and spatial information
+    sub_sigs_with_spatial_info = np.concatenate(
+        (local_weights, sub_sigs_reshaped, spatial_info_repeated), axis=2
+    )  # Shape: (num_patches, flat_patch_size, 4)
+
+    return sub_sigs_with_spatial_info
+
+
+def compute_cost_matrix_on_sig(
+    sig1: np.ndarray,
+    sig2: np.ndarray,
+    gray_most_frequent_color: float,
+    patch_size: Tuple[int, int],
+    dim: Tuple[int, int],
+    weight_most_frequent_color: float = 0.01,
+    use_tqdm: bool = True,
+) -> np.ndarray:
+    """
+    Compute the cost matrix for the EMD between two signatures with pre-reshaping optimization.
+
+    Args:
+    - sig1: A numpy array of shape (num_patches, patch_size[0] * patch_size[1] + 2) representing the first signature.
+    - sig2: A numpy array of shape (num_patches, patch_size[0] * patch_size[1] + 2) representing the second signature.
+    - gray_most_frequent_color: The most frequent color in the images, used to filter out patches that are constant
+        equal to the most frequent color.
+    - patch_size: Tuple indicating the height and width of the patches.
+    - use_tqdm: Boolean indicating whether to display a progress bar.
+
+    Returns:
+    - A numpy array of shape (num_patches, num_patches) representing the cost matrix.
+    """
+    assert sig1.shape == sig2.shape
+
+    # Reshape the sub-signatures at the beginning
+    sig1_reshaped = reshape_sub_sig_batch(
+        sig1[:, :-2], patch_size, gray_most_frequent_color, weight_most_frequent_color
+    ).astype(np.float32)
+    sig2_reshaped = reshape_sub_sig_batch(
+        sig2[:, :-2], patch_size, gray_most_frequent_color, weight_most_frequent_color
+    ).astype(np.float32)
+
+    cost_matrix = np.zeros((sig1.shape[0], sig2.shape[0]))
+    multiplier: float = (patch_size[0] * patch_size[1]) ** 0.5 / (dim[0] + dim[1])
+    for i in tqdm(range(sig1.shape[0]), disable=not use_tqdm):
+        for j in range(sig2.shape[0]):
+            pos_sig1 = sig1[i, -2:]
+            pos_sig2 = sig2[j, -2:]
+            sub_sig1 = sig1_reshaped[i]
+            sub_sig2 = sig2_reshaped[j]
+            emd_value, _, _ = cv2.EMD(sub_sig1, sub_sig2, cv2.DIST_L1)
+            cost_matrix[i, j] = emd_value + np.linalg.norm(pos_sig1 - pos_sig2, 1) * multiplier  # Use L1
+    return cost_matrix.astype(np.float32)
+
+
+def compute_emd_recursive(
+    img1_PIL: Image.Image,
+    img2_PIL: Image.Image,
+    threshold_most_frequent_color: float = 0.5,
+    patch_size: Tuple[int, int] = (8, 8),
+    max_num_patches: int = 100,
+    weight_most_frequent_color: float = 0.001,
+    use_tqdm: bool = False,
+):
+    """
+    Compute the Earth Mover's Distance between two images using a recursive approach.
+    Both images are discretized into patches, and the EMD is computed on the patches.
+    This is done by computing a cost matrix C such that C[i, j] is the cost of moving
+    the patch i of img1 to the patch j of img2.
+
+    Moving a patch to another patch has a cost that is not proportional to the number of pixels
+    as this corresponds to moving an entire part of the image to another part.
+
+    Args:
+    - img1_PIL: A PIL Image representing the first image.
+    - img2_PIL: A PIL Image representing the second image (should be the reference if there is one
+        as it is used to determine the most frequent color).
+    - threshold_most_frequent_color: The threshold under which a color is considered as the most frequent color.
+        Constant patches equal to the most frequent color are ignored if the frequency is above this threshold.
+    - patch_size: Tuple indicating the height and width of the patches.
+    - max_num_patches: The maximum number of patches to use for the EMD computation.
+        This is done to avoid having a too long computation time. The images will be resized if necessary.
+    - weight_most_frequent_color: The weight assigned to the most frequent color in the patches.
+        Should be between 0 and 1 (usually low as the most frequentcolor does not carry much information).
+    - use_tqdm: Boolean indicating whether to display a progress bar.
+
+    Returns:
+    - A float representing the Earth Mover's Distance between the images.
+    """
+    assert img1_PIL.size == img2_PIL.size
+    assert patch_size[0] > 0 and patch_size[1] > 0
+    assert 0 < threshold_most_frequent_color <= 1
+    assert max_num_patches > 0
+    assert 0 < weight_most_frequent_color <= 1
+
+    # Resize the images so that there are not too many patches
+    # Try to maintain the aspect ratio and resize to a multiple of the patch size
+    num_patches = math.ceil(img1_PIL.size[0] / patch_size[0]) * math.ceil(img1_PIL.size[1] / patch_size[1])
+    if num_patches > max_num_patches:
+        ideal_divider = (num_patches / max_num_patches) ** 0.5
+        closest_round_width = math.ceil((img1_PIL.size[0] / patch_size[1]) / ideal_divider) * patch_size[1]
+        num_patches_width = closest_round_width / patch_size[0]
+        # Chooses a round height such that:
+        # - (round_width / patch_size[1]) * (round_height / patch_size[0]) <= max_num_patches
+        # - the ratio is as unchanged as possible:
+        #   (original_width / round_width) / (original_height / round_height) is close to 1
+        closest_round_height = math.floor(max_num_patches / num_patches_width) * patch_size[0]
+        # Resize the images
+        img1_PIL = img1_PIL.resize((closest_round_width, closest_round_height))
+        img2_PIL = img2_PIL.resize((closest_round_width, closest_round_height))
+
+    # Convert the images to numpy arrays
+    img1_np = np.array(img1_PIL)
+    img2_np = np.array(img2_PIL)
+
+    # Get the patch-signature of the images.
+    # This is of shape (num_patches, patch_size[0] * patch_size[1] + 3)
+    # Each row is a patch, and the columns are:
+    # - index 0: weight of the patch
+    # - index 1 - 1 + patch_size[0] * patch_size[1]: color values of the patch
+    # - index -2, -1: position of the patch
+    (rgb_most_frequent_color, frequency) = get_most_frequent_color(img2_np)
+    gray_most_frequent_color = float(to_gray(rgb_most_frequent_color).squeeze() / 255.0)
+    sig1 = img_to_sig_patches(img1_np, rgb_most_frequent_color, patch_size, weight_most_frequent_color)
+    sig2 = img_to_sig_patches(img2_np, rgb_most_frequent_color, patch_size, weight_most_frequent_color)
+
+    if frequency > threshold_most_frequent_color:
+        # Ignore patches that are constant equal to the most frequent color
+        mask1 = np.any(sig1[:, 1:-2] != gray_most_frequent_color, axis=1)
+        mask2 = np.any(sig2[:, 1:-2] != gray_most_frequent_color, axis=1)
+        mask = np.logical_or(mask1, mask2)
+        sig1 = sig1[mask]
+        sig2 = sig2[mask]
+
+    # Normalize the weights
+    weight1 = sig1[:, 0]
+    weight2 = sig2[:, 0]
+    weights = np.maximum(weight1, weight2)
+    weights /= np.sum(weights)
+    sig1[:, 0] = weights
+    sig2[:, 0] = weights
+
+    # Compute EMD
+    cost = compute_cost_matrix_on_sig(
+        sig1=sig1,
+        sig2=sig2,
+        gray_most_frequent_color=gray_most_frequent_color,
+        patch_size=patch_size,
+        dim=img1_PIL.size,
+        weight_most_frequent_color=weight_most_frequent_color,
+        use_tqdm=use_tqdm,
+    )
+    emd_value, _, _ = cv2.EMD(sig1, sig2, cv2.DIST_USER, cost)
+    return emd_value
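A minimal usage sketch for the new EMD helper, assuming two same-size RGB images on disk (the file paths are placeholders); the optional image dependencies (`cv2`, `PIL`) must be installed, as the guarded import at the top of the file indicates:

```python
from PIL import Image

from helm.benchmark.metrics.vision_language.emd_utils import compute_emd_recursive

# Placeholder paths; both images must have the same dimensions (see the assert above).
generated = Image.open("generated.png").convert("RGB")
reference = Image.open("reference.png").convert("RGB")

# The second image is treated as the reference: its most frequent color is used to
# down-weight background patches (per the docstring above).
score = compute_emd_recursive(generated, reference, patch_size=(8, 8), max_num_patches=100)
print(score)
```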