crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Dict, List
|
|
3
|
+
|
|
4
|
+
from datasets import load_dataset
|
|
5
|
+
import evaluate
|
|
6
|
+
|
|
7
|
+
from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats
|
|
8
|
+
from helm.benchmark.adaptation.scenario_state import ScenarioState
|
|
9
|
+
from helm.benchmark.metrics.metric_name import MetricName
|
|
10
|
+
from helm.benchmark.metrics.metric_service import MetricService
|
|
11
|
+
from helm.benchmark.metrics.statistic import Stat
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class UnitxtMetric(MetricInterface):
    """Score completions with the Unitxt metric loaded through the `evaluate` library.

    The Unitxt dataset is loaded once in the constructor. `evaluate` then looks up the reference
    row of every request in that dataset by parsing the request's instance id.
    """

    # An instance id must start with a lowercase split name directly followed by a row index.
    ID_PATTERN = re.compile("([a-z]+)([0-9]+)")

    def __init__(self, **kwargs):
        """Load the "unitxt/data" dataset configured by `kwargs`.

        The keyword arguments are serialized as comma-separated `key=value` pairs and handed to
        `load_dataset` as its second positional argument.
        """
        super().__init__()
        recipe = ",".join(f"{key}={value}" for key, value in kwargs.items())
        self.dataset = load_dataset("unitxt/data", recipe, trust_remote_code=True)

    def evaluate(
        self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
    ) -> MetricResult:
        """Compute Unitxt scores for every request state and for the run as a whole.

        Each request must carry an instance id matching `ID_PATTERN` and exactly one completion.
        Per-instance scores are read from `["score"]["instance"]` of each result, and the aggregated
        scores from `["score"]["global"]` of the last result. The "score" and "score_name" entries
        are skipped in both.
        """
        skipped_keys = ("score", "score_name")

        # Collect predictions and their references in the same order as the request states
        predictions: List[str] = []
        references: List = []
        for request_state in scenario_state.request_states:
            instance_id = request_state.instance.id
            assert instance_id
            parsed_id = UnitxtMetric.ID_PATTERN.match(instance_id)
            assert parsed_id
            split_name = parsed_id.group(1)
            row_number = int(parsed_id.group(2))
            references.append(self.dataset[split_name][row_number])
            result = request_state.result
            assert result
            assert len(result.completions) == 1
            predictions.append(result.completions[0].text)

        # Run the Unitxt metric over all predictions at once
        unitxt_metric = evaluate.load("unitxt/metric")
        evaluate_results: List[Dict] = unitxt_metric.compute(predictions=predictions, references=references)

        # One PerInstanceStats entry per (request state, result) pair
        per_instance_stats: List[PerInstanceStats] = []
        for request_state, evaluate_result in zip(scenario_state.request_states, evaluate_results):
            instance = request_state.instance
            instance_scores = evaluate_result["score"]["instance"]
            instance_stats: List[Stat] = [
                Stat(
                    MetricName(
                        name=metric_name,
                        split=instance.split,
                        sub_split=instance.sub_split,
                        perturbation=instance.perturbation,
                    )
                ).add(metric_score)
                for metric_name, metric_score in instance_scores.items()
                if metric_name not in skipped_keys
            ]
            assert instance.id
            per_instance_stats.append(
                PerInstanceStats(
                    instance_id=instance.id,
                    perturbation=instance.perturbation,
                    train_trial_index=request_state.train_trial_index,
                    stats=instance_stats,
                )
            )

        # The global scores are taken from the last result only
        aggregated_stats: List[Stat] = []
        if len(evaluate_results) > 0:
            global_scores = evaluate_results[-1]["score"]["global"]
            aggregated_stats = [
                Stat(MetricName(name=metric_name)).add(metric_score)
                for metric_name, metric_score in global_scores.items()
                if metric_name not in skipped_keys
            ]
        return MetricResult(aggregated_stats=aggregated_stats, per_instance_stats=per_instance_stats)
|
|
File without changes
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
from typing import List, Tuple
|
|
2
|
+
from tqdm import tqdm
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import math
|
|
6
|
+
|
|
7
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import cv2
|
|
11
|
+
from PIL import Image
|
|
12
|
+
except ModuleNotFoundError as e:
|
|
13
|
+
handle_module_not_found_error(e, suggestions=["images"])
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def to_gray(img: np.ndarray) -> np.ndarray:
    """Collapse the last (length-3) axis of `img` into one channel by a weighted sum.

    The three values are weighted 0.299, 0.587 and 0.114 respectively. Because the weights form
    a (3, 1) column, an (H, W, 3) input yields (H, W, 1) and a length-3 vector yields a length-1 array.
    """
    channel_weights = np.array([[0.299], [0.587], [0.114]])
    return img @ channel_weights
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_most_frequent_color(img: np.ndarray) -> Tuple[np.ndarray, float]:
    """Find the color that occurs most often in an image.

    Args:
        img (np.array): Input image array of shape (height, width, channels).

    Returns:
        Tuple[np.array, float]: The most frequent color and the fraction (between 0 and 1)
            of all pixels that have exactly this color.
    """
    assert img.ndim == 3, "Input image must be a 3D numpy array"

    # One row per pixel, one column per channel
    flat_pixels = img.reshape(-1, img.shape[2])

    # Distinct colors together with how many pixels carry each of them
    colors, occurrences = np.unique(flat_pixels, axis=0, return_counts=True)
    winner = occurrences.argmax()

    return colors[winner], occurrences[winner] / flat_pixels.shape[0]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def img_to_sig_patches(
    img: np.ndarray,
    rgb_most_frequent_color: np.ndarray,
    patch_size: Tuple[int, int],
    weight_most_frequent_color: float = 0.01,
) -> np.ndarray:
    """
    Convert an RGB image to a signature for cv2.EMD, processing the image in patches.

    The image is padded with the most frequent color up to a multiple of the patch size,
    scaled to [0, 1], converted to grayscale and cut into non-overlapping patches.

    Args:
    - img: A 3D numpy array representing an RGB image (height, width, channels).
    - rgb_most_frequent_color: The most frequent color in the image (one value per channel, 0-255 scale).
    - patch_size: Tuple indicating the height and width of the patches.
    - weight_most_frequent_color: The weight assigned to patches made only of the most frequent color.

    Returns:
    - A float32 numpy array suitable for cv2.EMD of shape (num_patches, patch_size[0] * patch_size[1] + 3).
        Column 0 is the normalized weight of the patch, the following patch_size[0] * patch_size[1]
        columns are its grayscale pixel values and the last two columns are its normalized (row, column)
        position in the grid of patches.
    """
    assert len(img.shape) == 3, "Input image must be a 3D numpy array"

    # Work on a float32 copy so that the in-place scaling below never touches the caller's array
    img = np.array(img, dtype=np.float32)

    # Number of rows / columns missing to reach a multiple of the patch size
    pad_height = (-img.shape[0]) % patch_size[0]
    pad_width = (-img.shape[1]) % patch_size[1]

    if pad_height > 0 or pad_width > 0:
        # np.pad's `constant_values` is interpreted per *axis*, not per channel, so it cannot express
        # "fill with this RGB color". Fill a canvas with the color (broadcast over the channel axis)
        # and copy the image into its top-left corner instead.
        height, width, channels = img.shape
        fill_color = np.asarray(rgb_most_frequent_color, dtype=np.float32)[:channels]
        padded = np.empty((height + pad_height, width + pad_width, channels), dtype=np.float32)
        padded[...] = fill_color
        padded[:height, :width] = img
        img = padded
    img /= 255.0  # Normalize colors to [0, 1]

    # Collapse color dimensions to grayscale
    img = to_gray(img)

    # Reshape image into patches and flatten the color dimensions within each patch
    patches = (
        img.reshape(
            (img.shape[0] // patch_size[0], patch_size[0], img.shape[1] // patch_size[1], patch_size[1], img.shape[2])
        )
        .transpose(0, 2, 1, 3, 4)
        .reshape(-1, *patch_size, img.shape[2])
    )

    # Calculate patch positions (row, column) in the grid of patches
    patch_positions = (
        np.mgrid[0 : img.shape[0] // patch_size[0], 0 : img.shape[1] // patch_size[1]].transpose(1, 2, 0).reshape(-1, 2)
    )

    # Normalize positions
    patch_positions = patch_positions / np.array([img.shape[0] // patch_size[0], img.shape[1] // patch_size[1]])

    # Compute the weight of each patch
    # The weight of a patch is 1 if any of its pixels differs from the most frequent color,
    # weight_most_frequent_color otherwise
    flattened_patches = patches.reshape(patches.shape[0], -1)
    gray_most_frequent_color: float = float(to_gray(rgb_most_frequent_color).squeeze() / 255.0)
    weight = weight_most_frequent_color + (1 - weight_most_frequent_color) * np.any(
        flattened_patches != gray_most_frequent_color, axis=1, keepdims=True
    ).astype(np.float32)
    weight /= np.sum(weight)

    # Concatenate weights, flattened patches and normalized positions
    sig = np.hstack((weight, flattened_patches, patch_positions))

    return sig.astype(np.float32)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def pad(small_image: Image.Image, large_image: Image.Image, axis: int) -> Image.Image:
    """Return a copy of the small image extended along `axis` to the large image's size on that axis.

    The result is a new white RGB image with the small image pasted at its top-left corner.
    The size on the other axis is left as in the small image.
    """
    target_size: List[int] = list(small_image.size)
    target_size[axis] = large_image.size[axis]
    canvas_size: Tuple[int, int] = tuple(target_size)  # type: ignore
    canvas: Image.Image = Image.new("RGB", canvas_size, (255, 255, 255))
    canvas.paste(small_image, (0, 0))
    return canvas
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def reshape_sub_sig_batch(
    sub_sigs: np.ndarray,
    patch_size: Tuple[int, int],
    gray_most_frequent_color: float,
    weight_most_frequent_color: float = 0.01,
) -> np.ndarray:
    """
    Turn patch-level signatures into one pixel-level signature per patch.

    Basically goes from a signature on the patch level to a batch of signatures on the pixel level.

    Args:
    - sub_sigs: A numpy array of shape (num_patches, patch_size[0] * patch_size[1] + 1): the patch weight
        in column 0 followed by the grayscale pixel values of the patch, flattened row by row
        (the spatial info of the patch should have been stripped).
    - patch_size: Tuple indicating the height and width of the patches.
    - gray_most_frequent_color: The most frequent color in the image.
        This is used to reduce the weight assigned to the most frequent color in the patches.
    - weight_most_frequent_color: The weight assigned to the most frequent color in the patches.

    Returns:
    - A numpy array of shape (num_patches, patch_size[0] * patch_size[1], 4) representing the sub-signatures of
        each patch (pixel-level signatures). For each pixel the 4 values are: its weight (normalized so that
        the weights of a patch sum to 1), its gray value, its normalized column and its normalized row
        inside the patch.
    """
    # Ensure sub_sigs has the correct shape
    num_patches = sub_sigs.shape[0]
    patch_height, patch_width = patch_size
    flat_patch_size = patch_height * patch_width
    assert sub_sigs.shape[1] == flat_patch_size + 1, f"Expected {flat_patch_size + 1} columns, got {sub_sigs.shape[1]}."

    # The column check above already implies a single (grayscale) value per pixel. Deriving the channel
    # count from the array size would divide by zero on an empty batch and round incorrectly for
    # patches of one or two pixels.
    num_channels = 1
    sub_sigs_reshaped = sub_sigs[:, 1:].reshape(num_patches, flat_patch_size, num_channels)

    # Generate spatial information.
    # Pixels are flattened row by row, so "ij" indexing is needed for the coordinates to follow the same
    # order when the patch is not square. Columns come first to keep the layout used for square patches.
    rows, cols = np.meshgrid(
        np.arange(patch_height) / patch_height, np.arange(patch_width) / patch_width, indexing="ij"
    )
    spatial_info = np.stack((cols.ravel(), rows.ravel()), axis=1)  # Shape: (flat_patch_size, 2)

    # Repeat spatial_info for each patch
    spatial_info_repeated = np.repeat(
        spatial_info[np.newaxis, :, :], num_patches, axis=0
    )  # Shape: (num_patches, flat_patch_size, 2)

    # The weight of each point is 1 if the color is not the most frequent color, weight_most_frequent_color otherwise
    # The weight of a pixel is the product of the weight of the patch and the weight of the pixel in the patch
    local_weights = weight_most_frequent_color + (1 - weight_most_frequent_color) * (
        sub_sigs_reshaped != gray_most_frequent_color
    ).astype(np.float32)
    global_weights = sub_sigs[:, 0:1]
    local_weights *= global_weights.reshape(-1, 1, 1)
    local_weights /= np.sum(local_weights, axis=1, keepdims=True)

    # Concatenate sub_sigs with weights and spatial information
    sub_sigs_with_spatial_info = np.concatenate(
        (local_weights, sub_sigs_reshaped, spatial_info_repeated), axis=2
    )  # Shape: (num_patches, flat_patch_size, 4)

    return sub_sigs_with_spatial_info
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def compute_cost_matrix_on_sig(
    sig1: np.ndarray,
    sig2: np.ndarray,
    gray_most_frequent_color: float,
    patch_size: Tuple[int, int],
    dim: Tuple[int, int],
    weight_most_frequent_color: float = 0.01,
    use_tqdm: bool = True,
) -> np.ndarray:
    """
    Build the patch-to-patch cost matrix used as ground distance for the EMD between two signatures.

    Entry (i, j) is the pixel-level EMD (L1 ground distance) between patch i of sig1 and patch j of sig2,
    plus the L1 distance between the two patch positions scaled by
    sqrt(patch_size[0] * patch_size[1]) / (dim[0] + dim[1]).

    Args:
    - sig1: A numpy array of shape (num_patches, patch_size[0] * patch_size[1] + 3) representing the first
        signature (weight, pixel values, then the 2 position columns).
    - sig2: A numpy array with the same shape as sig1 representing the second signature.
    - gray_most_frequent_color: The most frequent color in the images, forwarded to the pixel-level
        signatures so that pixels of this color get a reduced weight.
    - patch_size: Tuple indicating the height and width of the patches.
    - dim: The two dimensions of the image; only their sum is used, to scale the position term.
    - weight_most_frequent_color: The weight assigned to the most frequent color in the patches.
    - use_tqdm: Boolean indicating whether to display a progress bar.

    Returns:
    - A float32 numpy array of shape (num_patches, num_patches) representing the cost matrix.
    """
    assert sig1.shape == sig2.shape

    # Build the pixel-level signatures once up front; the last two columns (patch position) are dropped
    pixel_sigs1 = reshape_sub_sig_batch(
        sig1[:, :-2], patch_size, gray_most_frequent_color, weight_most_frequent_color
    ).astype(np.float32)
    pixel_sigs2 = reshape_sub_sig_batch(
        sig2[:, :-2], patch_size, gray_most_frequent_color, weight_most_frequent_color
    ).astype(np.float32)

    num_rows, num_cols = sig1.shape[0], sig2.shape[0]
    costs = np.zeros((num_rows, num_cols))
    multiplier: float = (patch_size[0] * patch_size[1]) ** 0.5 / (dim[0] + dim[1])
    for row in tqdm(range(num_rows), disable=not use_tqdm):
        # Everything that depends on the patch of sig1 only is fetched once per row
        row_position = sig1[row, -2:]
        row_pixels = pixel_sigs1[row]
        for col in range(num_cols):
            content_cost, _, _ = cv2.EMD(row_pixels, pixel_sigs2[col], cv2.DIST_L1)
            position_cost = np.linalg.norm(row_position - sig2[col, -2:], 1) * multiplier  # Use L1
            costs[row, col] = content_cost + position_cost
    return costs.astype(np.float32)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def compute_emd_recursive(
    img1_PIL: Image.Image,
    img2_PIL: Image.Image,
    threshold_most_frequent_color: float = 0.5,
    patch_size: Tuple[int, int] = (8, 8),
    max_num_patches: int = 100,
    weight_most_frequent_color: float = 0.001,
    use_tqdm: bool = False,
) -> float:
    """
    Compute the Earth Mover's Distance between two images using a recursive approach.
    Both images are discretized into patches, and the EMD is computed on the patches.
    This is done by computing a cost matrix C such that C[i, j] is the cost of moving
    the patch i of img1 to the patch j of img2.

    Moving a patch to another patch has a cost that is not proportional to the number of pixels
    as this corresponds to moving an entire part of the image to another part.

    Args:
    - img1_PIL: A PIL Image representing the first image.
    - img2_PIL: A PIL Image representing the second image (should be the reference if there is one
        as it is used to determine the most frequent color).
    - threshold_most_frequent_color: The frequency above which the most frequent color is treated as background.
        Constant patches equal to the most frequent color are ignored if the frequency is above this threshold.
    - patch_size: Tuple indicating the height and width of the patches.
    - max_num_patches: The maximum number of patches to use for the EMD computation.
        This is done to avoid having a too long computation time. The images will be resized if necessary.
    - weight_most_frequent_color: The weight assigned to the most frequent color in the patches.
        Should be between 0 and 1 (usually low as the most frequent color does not carry much information).
    - use_tqdm: Boolean indicating whether to display a progress bar.

    Returns:
    - A float representing the Earth Mover's Distance between the images.
        0.0 is returned when no patch is left to compare, i.e. when every patch of both images
        is constant and equal to the most frequent color.
    """
    assert img1_PIL.size == img2_PIL.size
    assert patch_size[0] > 0 and patch_size[1] > 0
    assert 0 < threshold_most_frequent_color <= 1
    assert max_num_patches > 0
    assert 0 < weight_most_frequent_color <= 1

    # PIL sizes are (width, height) whereas patch_size is (height, width): name both explicitly
    # so that each image dimension is always paired with the matching patch dimension.
    width, height = img1_PIL.size
    patch_height, patch_width = patch_size

    # Resize the images so that there are not too many patches
    # Try to maintain the aspect ratio and resize to a multiple of the patch size
    num_patches = math.ceil(width / patch_width) * math.ceil(height / patch_height)
    if num_patches > max_num_patches:
        ideal_divider = (num_patches / max_num_patches) ** 0.5
        # Clamp so that at least one row of patches still fits within max_num_patches
        num_patches_width = min(max_num_patches, math.ceil((width / patch_width) / ideal_divider))
        # Chooses a number of rows such that:
        # - num_patches_width * num_patches_height <= max_num_patches
        # - the ratio is as unchanged as possible:
        #   (original_width / round_width) / (original_height / round_height) is close to 1
        # and never 0, which would ask for an image of height 0 for very wide images.
        num_patches_height = max(1, max_num_patches // num_patches_width)
        new_size = (num_patches_width * patch_width, num_patches_height * patch_height)
        # Resize the images
        img1_PIL = img1_PIL.resize(new_size)
        img2_PIL = img2_PIL.resize(new_size)

    # Convert the images to numpy arrays
    img1_np = np.array(img1_PIL)
    img2_np = np.array(img2_PIL)

    # Get the patch-signature of the images.
    # This is of shape (num_patches, patch_size[0] * patch_size[1] + 3)
    # Each row is a patch, and the columns are:
    # - index 0: weight of the patch
    # - index 1 - 1 + patch_size[0] * patch_size[1]: color values of the patch
    # - index -2, -1: position of the patch
    (rgb_most_frequent_color, frequency) = get_most_frequent_color(img2_np)
    gray_most_frequent_color = float(to_gray(rgb_most_frequent_color).squeeze() / 255.0)
    sig1 = img_to_sig_patches(img1_np, rgb_most_frequent_color, patch_size, weight_most_frequent_color)
    sig2 = img_to_sig_patches(img2_np, rgb_most_frequent_color, patch_size, weight_most_frequent_color)

    if frequency > threshold_most_frequent_color:
        # Ignore patches that are constant equal to the most frequent color in both images
        mask1 = np.any(sig1[:, 1:-2] != gray_most_frequent_color, axis=1)
        mask2 = np.any(sig2[:, 1:-2] != gray_most_frequent_color, axis=1)
        mask = np.logical_or(mask1, mask2)
        sig1 = sig1[mask]
        sig2 = sig2[mask]

    # Both images consist only of the most frequent color: there is nothing to move
    if sig1.shape[0] == 0:
        return 0.0

    # Give both signatures the same normalized weights so that the two total masses match
    weight1 = sig1[:, 0]
    weight2 = sig2[:, 0]
    weights = np.maximum(weight1, weight2)
    weights /= np.sum(weights)
    sig1[:, 0] = weights
    sig2[:, 0] = weights

    # Compute EMD
    cost = compute_cost_matrix_on_sig(
        sig1=sig1,
        sig2=sig2,
        gray_most_frequent_color=gray_most_frequent_color,
        patch_size=patch_size,
        dim=img1_PIL.size,
        weight_most_frequent_color=weight_most_frequent_color,
        use_tqdm=use_tqdm,
    )
    emd_value, _, _ = cv2.EMD(sig1, sig2, cv2.DIST_USER, cost)
    return emd_value
|