crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546)
  1. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +37 -2
  5. helm/benchmark/adaptation/adapters/adapter.py +4 -42
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
  9. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
  10. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
  11. helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
  12. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  13. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  14. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
  15. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
  16. helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
  17. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  18. helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
  19. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
  20. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
  21. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  22. helm/benchmark/adaptation/prompt.py +7 -1
  23. helm/benchmark/adaptation/request_state.py +6 -1
  24. helm/benchmark/adaptation/scenario_state.py +6 -2
  25. helm/benchmark/annotation/annotator.py +43 -0
  26. helm/benchmark/annotation/annotator_factory.py +61 -0
  27. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  28. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  29. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  30. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  31. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  32. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  33. helm/benchmark/annotation_executor.py +124 -0
  34. helm/benchmark/augmentations/cleva_perturbation.py +7 -14
  35. helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
  36. helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
  37. helm/benchmark/augmentations/data_augmenter.py +0 -2
  38. helm/benchmark/augmentations/dialect_perturbation.py +2 -2
  39. helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
  40. helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
  41. helm/benchmark/augmentations/gender_perturbation.py +3 -3
  42. helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
  43. helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
  44. helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
  45. helm/benchmark/augmentations/person_name_perturbation.py +0 -7
  46. helm/benchmark/augmentations/perturbation.py +20 -7
  47. helm/benchmark/augmentations/perturbation_description.py +1 -1
  48. helm/benchmark/augmentations/space_perturbation.py +2 -2
  49. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  50. helm/benchmark/augmentations/synonym_perturbation.py +2 -2
  51. helm/benchmark/augmentations/test_perturbation.py +11 -7
  52. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  53. helm/benchmark/augmentations/typos_perturbation.py +2 -2
  54. helm/benchmark/config_registry.py +38 -0
  55. helm/benchmark/executor.py +46 -16
  56. helm/benchmark/huggingface_registration.py +37 -7
  57. helm/benchmark/metrics/basic_metrics.py +172 -641
  58. helm/benchmark/metrics/bbq_metrics.py +3 -4
  59. helm/benchmark/metrics/bias_metrics.py +6 -6
  60. helm/benchmark/metrics/classification_metrics.py +11 -8
  61. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  62. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  63. helm/benchmark/metrics/code_metrics.py +4 -3
  64. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  65. helm/benchmark/metrics/common_metric_specs.py +167 -0
  66. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  67. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  68. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  69. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  70. helm/benchmark/metrics/disinformation_metrics.py +6 -112
  71. helm/benchmark/metrics/dry_run_metrics.py +5 -3
  72. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  73. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  74. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  75. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  76. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  77. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  78. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  79. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  80. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  81. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  82. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  83. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  84. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  85. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  86. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  87. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  88. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  89. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  90. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  91. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  92. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  93. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  94. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  95. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  96. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  97. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  98. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  99. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  100. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  101. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  102. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  103. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  104. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  105. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  106. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  107. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  108. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  109. helm/benchmark/metrics/machine_translation_metrics.py +5 -5
  110. helm/benchmark/metrics/metric.py +93 -172
  111. helm/benchmark/metrics/metric_name.py +0 -1
  112. helm/benchmark/metrics/metric_service.py +16 -0
  113. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  114. helm/benchmark/metrics/ranking_metrics.py +6 -7
  115. helm/benchmark/metrics/reference_metric.py +148 -0
  116. helm/benchmark/metrics/summac/model_summac.py +0 -2
  117. helm/benchmark/metrics/summarization_metrics.py +8 -8
  118. helm/benchmark/metrics/test_classification_metrics.py +9 -6
  119. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  120. helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
  121. helm/benchmark/metrics/test_metric.py +2 -2
  122. helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
  123. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
  124. helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
  125. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
  126. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
  127. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  128. helm/benchmark/metrics/toxicity_utils.py +23 -0
  129. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  130. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  131. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  132. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  133. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  134. helm/benchmark/model_deployment_registry.py +164 -41
  135. helm/benchmark/model_metadata_registry.py +181 -35
  136. helm/benchmark/multi_gpu_runner.py +133 -0
  137. helm/benchmark/presentation/contamination.py +3 -3
  138. helm/benchmark/presentation/create_plots.py +8 -7
  139. helm/benchmark/presentation/run_display.py +50 -17
  140. helm/benchmark/presentation/schema.py +28 -46
  141. helm/benchmark/presentation/summarize.py +213 -96
  142. helm/benchmark/presentation/table.py +8 -8
  143. helm/benchmark/presentation/test_contamination.py +2 -2
  144. helm/benchmark/presentation/test_run_entry.py +14 -9
  145. helm/benchmark/presentation/test_summarize.py +5 -0
  146. helm/benchmark/run.py +66 -54
  147. helm/benchmark/run_expander.py +342 -31
  148. helm/benchmark/run_spec.py +93 -0
  149. helm/benchmark/run_spec_factory.py +162 -0
  150. helm/benchmark/run_specs/__init__.py +0 -0
  151. helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
  152. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  153. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  154. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  155. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  156. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  157. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  158. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  159. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  160. helm/benchmark/runner.py +116 -69
  161. helm/benchmark/runner_config_registry.py +21 -0
  162. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  163. helm/benchmark/scenarios/bold_scenario.py +2 -2
  164. helm/benchmark/scenarios/cleva_scenario.py +43 -46
  165. helm/benchmark/scenarios/code_scenario.py +3 -2
  166. helm/benchmark/scenarios/commonsense_scenario.py +171 -191
  167. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  168. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  169. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  170. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  171. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  172. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  173. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  174. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  175. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  176. helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
  177. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  178. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  179. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  180. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  181. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  182. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  183. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  184. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  185. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  186. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  187. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  188. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  189. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  190. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  191. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  192. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  193. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  194. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  195. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  196. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  197. helm/benchmark/scenarios/legalbench_scenario.py +123 -0
  198. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  199. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  200. helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
  201. helm/benchmark/scenarios/math_scenario.py +19 -2
  202. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  203. helm/benchmark/scenarios/numeracy_scenario.py +3 -3
  204. helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
  205. helm/benchmark/scenarios/raft_scenario.py +2 -6
  206. helm/benchmark/scenarios/scenario.py +14 -2
  207. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  208. helm/benchmark/scenarios/test_math_scenario.py +22 -0
  209. helm/benchmark/scenarios/test_scenario.py +6 -3
  210. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  211. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  212. helm/benchmark/scenarios/the_pile_scenario.py +6 -7
  213. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  214. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  215. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  216. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  217. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  218. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  219. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  220. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  221. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  222. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  223. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  224. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  225. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  226. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  227. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  228. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  229. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  230. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  231. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  232. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  233. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  234. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  235. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  236. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  237. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
  238. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  239. helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
  240. helm/benchmark/server.py +59 -2
  241. helm/benchmark/slurm_jobs.py +12 -0
  242. helm/benchmark/slurm_runner.py +79 -51
  243. helm/benchmark/static/benchmarking.js +3 -4
  244. helm/benchmark/static/contamination.yaml +1 -1
  245. helm/benchmark/static/images/organizations/together.png +0 -0
  246. helm/benchmark/static/json-urls.js +4 -0
  247. helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
  248. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  249. helm/benchmark/static/schema_lite.yaml +824 -0
  250. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  251. helm/benchmark/static/schema_unitxt.yaml +428 -0
  252. helm/benchmark/static/schema_vlm.yaml +576 -0
  253. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  254. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  255. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  256. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  257. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  258. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  259. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  260. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  261. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  262. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  263. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  264. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  265. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  266. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  267. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  268. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  269. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  270. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  271. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  272. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  273. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  274. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  275. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  276. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  277. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  278. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  279. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  280. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  281. helm/benchmark/static_build/config.js +4 -0
  282. helm/benchmark/static_build/index.html +20 -0
  283. helm/benchmark/test_data_preprocessor.py +3 -3
  284. helm/benchmark/test_model_deployment_definition.py +90 -0
  285. helm/benchmark/test_run_expander.py +1 -1
  286. helm/benchmark/tokenizer_config_registry.py +10 -14
  287. helm/benchmark/window_services/ai21_window_service.py +22 -33
  288. helm/benchmark/window_services/cohere_window_service.py +1 -63
  289. helm/benchmark/window_services/default_window_service.py +2 -35
  290. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  291. helm/benchmark/window_services/ice_window_service.py +0 -34
  292. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  293. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  294. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  295. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  296. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  297. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  298. helm/benchmark/window_services/local_window_service.py +21 -4
  299. helm/benchmark/window_services/no_decoding_window_service.py +32 -0
  300. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  301. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  302. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  303. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  304. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  305. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  306. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  307. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  308. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  309. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  310. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  311. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  312. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  313. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  314. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  315. helm/benchmark/window_services/test_utils.py +3 -2
  316. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  317. helm/benchmark/window_services/window_service.py +42 -0
  318. helm/benchmark/window_services/window_service_factory.py +24 -269
  319. helm/benchmark/window_services/yalm_window_service.py +0 -27
  320. helm/clients/__init__.py +0 -0
  321. helm/{proxy/clients → clients}/ai21_client.py +5 -12
  322. helm/clients/aleph_alpha_client.py +112 -0
  323. helm/{proxy/clients → clients}/anthropic_client.py +213 -24
  324. helm/clients/auto_client.py +215 -0
  325. helm/clients/bedrock_client.py +128 -0
  326. helm/clients/bedrock_utils.py +72 -0
  327. helm/{proxy/clients → clients}/client.py +67 -55
  328. helm/clients/clip_score_client.py +49 -0
  329. helm/clients/clip_scorers/__init__.py +0 -0
  330. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  331. helm/clients/clip_scorers/clip_scorer.py +50 -0
  332. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  333. helm/{proxy/clients → clients}/cohere_client.py +6 -17
  334. helm/clients/gcs_client.py +82 -0
  335. helm/{proxy/clients → clients}/google_client.py +7 -8
  336. helm/clients/google_translate_client.py +35 -0
  337. helm/{proxy/clients → clients}/http_model_client.py +6 -10
  338. helm/{proxy/clients → clients}/huggingface_client.py +134 -92
  339. helm/clients/image_generation/__init__.py +0 -0
  340. helm/clients/image_generation/adobe_vision_client.py +78 -0
  341. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  342. helm/clients/image_generation/cogview2/__init__.py +0 -0
  343. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  344. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  345. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  346. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  347. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  348. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  349. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  350. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  351. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  352. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  353. helm/clients/image_generation/cogview2_client.py +191 -0
  354. helm/clients/image_generation/dalle2_client.py +192 -0
  355. helm/clients/image_generation/dalle3_client.py +108 -0
  356. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  357. helm/clients/image_generation/dalle_mini/data.py +442 -0
  358. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  359. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  360. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  361. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  362. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  363. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  364. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  365. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  366. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  367. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  368. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  369. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  370. helm/clients/image_generation/dalle_mini_client.py +190 -0
  371. helm/clients/image_generation/deep_floyd_client.py +78 -0
  372. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  373. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  374. helm/clients/image_generation/lexica_client.py +86 -0
  375. helm/clients/image_generation/mindalle/__init__.py +0 -0
  376. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  377. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  378. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  379. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  380. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  381. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  382. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  383. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  384. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  385. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  386. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  387. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  388. helm/clients/image_generation/mindalle_client.py +115 -0
  389. helm/clients/image_generation/nudity_check_client.py +64 -0
  390. helm/clients/image_generation/together_image_generation_client.py +111 -0
  391. helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
  392. helm/{proxy/clients → clients}/megatron_client.py +13 -7
  393. helm/clients/mistral_client.py +134 -0
  394. helm/clients/moderation_api_client.py +109 -0
  395. helm/clients/open_lm_client.py +43 -0
  396. helm/clients/openai_client.py +302 -0
  397. helm/{proxy/clients → clients}/palmyra_client.py +15 -12
  398. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  399. helm/clients/simple_client.py +64 -0
  400. helm/{proxy/clients → clients}/test_auto_client.py +15 -15
  401. helm/clients/test_client.py +100 -0
  402. helm/clients/test_huggingface_client.py +70 -0
  403. helm/clients/test_simple_client.py +19 -0
  404. helm/{proxy/clients → clients}/test_together_client.py +23 -12
  405. helm/{proxy/clients → clients}/together_client.py +18 -71
  406. helm/clients/vertexai_client.py +391 -0
  407. helm/clients/vision_language/__init__.py +0 -0
  408. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  409. helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
  410. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  411. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  412. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  413. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  414. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  415. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  416. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  417. helm/clients/vision_language/open_flamingo_client.py +155 -0
  418. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  419. helm/clients/vllm_client.py +46 -0
  420. helm/common/cache.py +24 -179
  421. helm/common/cache_backend_config.py +47 -0
  422. helm/common/clip_score_request.py +41 -0
  423. helm/common/concurrency.py +32 -0
  424. helm/common/credentials_utils.py +28 -0
  425. helm/common/file_caches/__init__.py +0 -0
  426. helm/common/file_caches/file_cache.py +16 -0
  427. helm/common/file_caches/local_file_cache.py +61 -0
  428. helm/common/file_caches/test_local_file_cache.py +25 -0
  429. helm/common/file_upload_request.py +27 -0
  430. helm/common/general.py +29 -10
  431. helm/common/image_generation_parameters.py +25 -0
  432. helm/common/images_utils.py +24 -1
  433. helm/common/key_value_store.py +113 -0
  434. helm/common/media_object.py +13 -0
  435. helm/common/moderations_api_request.py +71 -0
  436. helm/common/mongo_key_value_store.py +88 -0
  437. helm/common/multimodal_request_utils.py +31 -0
  438. helm/common/nudity_check_request.py +29 -0
  439. helm/common/object_spec.py +2 -2
  440. helm/common/request.py +36 -27
  441. helm/common/test_general.py +6 -0
  442. helm/common/tokenization_request.py +6 -3
  443. helm/config/__init__.py +0 -0
  444. helm/config/model_deployments.yaml +1942 -0
  445. helm/config/model_metadata.yaml +2201 -0
  446. helm/config/tokenizer_configs.yaml +362 -0
  447. helm/proxy/accounts.py +31 -4
  448. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  449. helm/proxy/critique/model_critique_client.py +13 -5
  450. helm/proxy/example_queries.py +29 -17
  451. helm/proxy/retry.py +8 -2
  452. helm/proxy/server.py +77 -5
  453. helm/proxy/services/remote_service.py +31 -0
  454. helm/proxy/services/server_service.py +103 -20
  455. helm/proxy/services/service.py +34 -2
  456. helm/proxy/services/test_remote_service.py +7 -6
  457. helm/proxy/services/test_service.py +27 -18
  458. helm/proxy/test_accounts.py +32 -0
  459. helm/proxy/token_counters/auto_token_counter.py +37 -37
  460. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  461. helm/proxy/token_counters/token_counter.py +3 -5
  462. helm/py.typed +0 -0
  463. helm/tokenizers/__init__.py +0 -0
  464. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  465. helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
  466. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
  467. helm/tokenizers/auto_tokenizer.py +93 -0
  468. helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
  469. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  470. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  471. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
  472. helm/tokenizers/simple_tokenizer.py +33 -0
  473. helm/tokenizers/test_anthropic_tokenizer.py +82 -0
  474. helm/tokenizers/test_huggingface_tokenizer.py +136 -0
  475. helm/tokenizers/test_simple_tokenizer.py +33 -0
  476. helm/tokenizers/vertexai_tokenizer.py +97 -0
  477. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  478. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  479. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  480. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  481. crfm_helm-0.3.0.dist-info/RECORD +0 -396
  482. helm/benchmark/vlm_run_specs.py +0 -71
  483. helm/benchmark/window_services/anthropic_window_service.py +0 -68
  484. helm/benchmark/window_services/bloom_window_service.py +0 -35
  485. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  486. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  487. helm/benchmark/window_services/gptj_window_service.py +0 -38
  488. helm/benchmark/window_services/gptneox_window_service.py +0 -41
  489. helm/benchmark/window_services/http_model_window_service.py +0 -28
  490. helm/benchmark/window_services/huggingface_window_service.py +0 -59
  491. helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
  492. helm/benchmark/window_services/llama_window_service.py +0 -28
  493. helm/benchmark/window_services/luminous_window_service.py +0 -67
  494. helm/benchmark/window_services/megatron_window_service.py +0 -10
  495. helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
  496. helm/benchmark/window_services/openai_window_service.py +0 -13
  497. helm/benchmark/window_services/opt_window_service.py +0 -35
  498. helm/benchmark/window_services/palmyra_window_service.py +0 -45
  499. helm/benchmark/window_services/remote_window_service.py +0 -48
  500. helm/benchmark/window_services/santacoder_window_service.py +0 -27
  501. helm/benchmark/window_services/starcoder_window_service.py +0 -27
  502. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  503. helm/benchmark/window_services/t511b_window_service.py +0 -30
  504. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  505. helm/benchmark/window_services/ul2_window_service.py +0 -30
  506. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  507. helm/benchmark/window_services/wider_openai_window_service.py +0 -52
  508. helm/proxy/clients/aleph_alpha_client.py +0 -99
  509. helm/proxy/clients/auto_client.py +0 -461
  510. helm/proxy/clients/goose_ai_client.py +0 -100
  511. helm/proxy/clients/microsoft_client.py +0 -182
  512. helm/proxy/clients/openai_client.py +0 -206
  513. helm/proxy/clients/remote_model_registry.py +0 -28
  514. helm/proxy/clients/simple_client.py +0 -61
  515. helm/proxy/clients/test_anthropic_client.py +0 -63
  516. helm/proxy/clients/test_client.py +0 -31
  517. helm/proxy/clients/test_huggingface_client.py +0 -87
  518. helm/proxy/models.py +0 -963
  519. helm/proxy/test_models.py +0 -27
  520. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  521. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  522. helm/proxy/token_counters/free_token_counter.py +0 -12
  523. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  524. helm/proxy/token_counters/openai_token_counter.py +0 -22
  525. helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
  526. helm/proxy/token_counters/test_openai_token_counter.py +0 -79
  527. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  528. helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
  529. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  530. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  531. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  532. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  533. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  534. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  535. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  536. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  537. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  538. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  539. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  540. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  541. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  542. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  543. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  544. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  545. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  546. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
@@ -0,0 +1,81 @@
1
+ import re
2
+ from typing import Dict, List
3
+
4
+ from datasets import load_dataset
5
+ import evaluate
6
+
7
+ from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats
8
+ from helm.benchmark.adaptation.scenario_state import ScenarioState
9
+ from helm.benchmark.metrics.metric_name import MetricName
10
+ from helm.benchmark.metrics.metric_service import MetricService
11
+ from helm.benchmark.metrics.statistic import Stat
12
+
13
+
14
class UnitxtMetric(MetricInterface):
    """Score HELM predictions with the ``unitxt/metric`` evaluator.

    References are not taken from the HELM instances. They are re-read from the ``unitxt/data`` dataset using
    the split name and row index encoded in each instance id.
    """

    # Instance ids are parsed as "<lowercase split name><row index>", e.g. "test12".
    ID_PATTERN = re.compile("([a-z]+)([0-9]+)")

    def __init__(self, **kwargs):
        super().__init__()
        # The keyword arguments are joined into a comma-separated "key=value" string used as the dataset config name.
        config_name = ",".join([f"{key}={value}" for key, value in kwargs.items()])
        self.dataset = load_dataset("unitxt/data", config_name, trust_remote_code=True)

    def evaluate(
        self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
    ) -> MetricResult:
        """Compute per-instance and aggregate unitxt scores for ``scenario_state``.

        Only ``scenario_state`` is read; ``metric_service``, ``eval_cache_path`` and ``parallelism`` are unused.
        Each request state must carry exactly one completion.
        """
        # Build two parallel lists: the model output and the unitxt reference row its instance id points to.
        predictions: List[str] = []
        references: List = []
        for state in scenario_state.request_states:
            instance_id = state.instance.id
            assert instance_id
            parsed = UnitxtMetric.ID_PATTERN.match(instance_id)
            assert parsed
            split_name, row_digits = parsed.groups()
            references.append(self.dataset[split_name][int(row_digits)])
            result = state.result
            assert result
            assert len(result.completions) == 1
            predictions.append(result.completions[0].text)

        # `evaluate` here is the module, not this method: method names are not visible inside method bodies.
        metric = evaluate.load("unitxt/metric")
        evaluate_results: List[Dict] = metric.compute(predictions=predictions, references=references)

        # These summary keys are skipped; every other entry becomes a Stat.
        ignored_keys = ("score", "score_name")

        per_instance_stats: List[PerInstanceStats] = []
        for state, result_dict in zip(scenario_state.request_states, evaluate_results):
            instance = state.instance
            stats: List[Stat] = [
                Stat(
                    MetricName(
                        name=name,
                        split=instance.split,
                        sub_split=instance.sub_split,
                        perturbation=instance.perturbation,
                    )
                ).add(value)
                for name, value in result_dict["score"]["instance"].items()
                if name not in ignored_keys
            ]
            assert instance.id
            per_instance_stats.append(
                PerInstanceStats(
                    instance_id=instance.id,
                    perturbation=instance.perturbation,
                    train_trial_index=state.train_trial_index,
                    stats=stats,
                )
            )

        # Aggregate scores are read from the "global" section of the last result entry.
        aggregated_stats: List[Stat] = []
        if len(evaluate_results) > 0:
            global_scores = evaluate_results[-1]["score"]["global"]
            aggregated_stats = [
                Stat(MetricName(name=name)).add(value)
                for name, value in global_scores.items()
                if name not in ignored_keys
            ]
        return MetricResult(aggregated_stats=aggregated_stats, per_instance_stats=per_instance_stats)
File without changes
@@ -0,0 +1,341 @@
1
+ from typing import List, Tuple
2
+ from tqdm import tqdm
3
+
4
+ import numpy as np
5
+ import math
6
+
7
+ from helm.common.optional_dependencies import handle_module_not_found_error
8
+
9
+ try:
10
+ import cv2
11
+ from PIL import Image
12
+ except ModuleNotFoundError as e:
13
+ handle_module_not_found_error(e, suggestions=["images"])
14
+
15
+
16
def to_gray(img: np.ndarray) -> np.ndarray:
    """Project RGB values onto a single luminance channel (Rec. 601 luma coefficients 0.299/0.587/0.114).

    The last axis of ``img`` must have length 3. It is replaced by an axis of length 1, so an
    (H, W, 3) image becomes (H, W, 1) and a single (3,) color becomes (1,).
    """
    luma_column = np.array([0.299, 0.587, 0.114]).reshape(3, 1)
    return img @ luma_column
18
+
19
+
20
def get_most_frequent_color(img: np.ndarray) -> Tuple[np.ndarray, float]:
    """Return the dominant color of ``img`` together with its share of the pixels.

    Args:
        img (np.array): Image array laid out as (height, width, channels).

    Returns:
        Tuple[np.array, float]: The most common color (one value per channel) and the fraction of pixels,
        in [0, 1], that have exactly that color. Ties go to the color that sorts first
        lexicographically, because ``np.unique`` returns sorted rows.
    """
    assert len(img.shape) == 3, "Input image must be a 3D numpy array"

    # One row per pixel, one column per channel.
    flat_pixels = img.reshape(-1, img.shape[2])
    colors, occurrences = np.unique(flat_pixels, axis=0, return_counts=True)

    best = int(np.argmax(occurrences))
    return colors[best], occurrences[best] / flat_pixels.shape[0]
48
+
49
+
50
def img_to_sig_patches(
    img: np.ndarray,
    rgb_most_frequent_color: np.ndarray,
    patch_size: Tuple[int, int],
    weight_most_frequent_color: float = 0.01,
):
    """
    Convert an RGB image to a signature for cv2.EMD, processing the image in patches.

    The image is padded at the bottom/right with ``rgb_most_frequent_color`` up to a multiple of
    ``patch_size``, scaled to [0, 1], converted to grayscale and cut into non-overlapping patches.

    Args:
    - img: A 3D numpy array representing an RGB image (height, width, channels).
    - rgb_most_frequent_color: The most frequent color in the image (one value per channel, 0-255 scale).
    - patch_size: Tuple (patch_height, patch_width).
    - weight_most_frequent_color: The weight given to patches made only of the most frequent color
      (other patches get weight 1 before normalization).

    Returns:
    - A float32 numpy array of shape (num_patches, patch_size[0] * patch_size[1] + 3). Each row holds:
      column 0 the normalized patch weight (weights sum to 1), then the row-major grayscale pixel values
      of the patch, then the last two columns with the (row, col) patch position normalized to [0, 1).
    """
    assert len(img.shape) == 3, "Input image must be a 3D numpy array"

    # Work on a float32 copy so the caller's array is never modified.
    img = np.array(img, dtype=np.float32)
    height, width, num_channels = img.shape

    # Amount of padding needed to reach a multiple of the patch size on each axis.
    pad_height = (-height) % patch_size[0]
    pad_width = (-width) % patch_size[1]
    if pad_height > 0 or pad_width > 0:
        # np.pad's `constant_values` are given per *axis* (before/after), not per channel. Passing the color
        # there would fill the bottom rows with the red value and the right columns with the green value.
        # Instead, build a canvas filled with the full color and copy the image into its top-left corner.
        padded = np.empty((height + pad_height, width + pad_width, num_channels), dtype=np.float32)
        padded[...] = np.asarray(rgb_most_frequent_color, dtype=np.float32)[:num_channels]
        padded[:height, :width] = img
        img = padded
    img /= 255.0  # Normalize colors to [0, 1]

    # Collapse the color channels to a single grayscale channel: shape (H, W, 1).
    img = to_gray(img)

    num_patch_rows = img.shape[0] // patch_size[0]
    num_patch_cols = img.shape[1] // patch_size[1]

    # Cut into (num_patches, patch_height, patch_width, 1), patches ordered row-major over the patch grid.
    patches = (
        img.reshape((num_patch_rows, patch_size[0], num_patch_cols, patch_size[1], img.shape[2]))
        .transpose(0, 2, 1, 3, 4)
        .reshape(-1, *patch_size, img.shape[2])
    )

    # (row, col) index of each patch in the grid, normalized to [0, 1).
    patch_positions = np.mgrid[0:num_patch_rows, 0:num_patch_cols].transpose(1, 2, 0).reshape(-1, 2)
    patch_positions = patch_positions / np.array([num_patch_rows, num_patch_cols])

    # A patch weighs 1 if any pixel differs from the most frequent color, weight_most_frequent_color otherwise.
    # NOTE(review): this uses exact float equality between values produced by different computation paths
    # (pixels are scaled before the luma projection, the reference color after it). Confirm that uniform
    # patches really compare equal.
    flattened_patches = patches.reshape(patches.shape[0], -1)
    gray_most_frequent_color: float = float(to_gray(rgb_most_frequent_color).squeeze() / 255.0)
    weight = weight_most_frequent_color + (1 - weight_most_frequent_color) * np.any(
        flattened_patches != gray_most_frequent_color, axis=1, keepdims=True
    ).astype(np.float32)
    weight /= np.sum(weight)

    # Columns: [weight | pixel values | normalized position].
    sig = np.hstack((weight, flattened_patches, patch_positions))

    return sig.astype(np.float32)
124
+
125
+
126
def pad(small_image: Image.Image, large_image: Image.Image, axis: int) -> Image.Image:
    """Grow ``small_image`` along ``axis`` to match the size of ``large_image``.

    ``axis`` indexes PIL's ``(width, height)`` size tuple. The original pixels are anchored at the
    top-left corner and the added area is filled with white on an RGB canvas.
    """
    target_size: List[int] = list(small_image.size)
    target_size[axis] = large_image.size[axis]

    canvas: Image.Image = Image.new("RGB", (target_size[0], target_size[1]), (255, 255, 255))
    canvas.paste(small_image, (0, 0))
    return canvas
134
+
135
+
136
def reshape_sub_sig_batch(
    sub_sigs: np.ndarray,
    patch_size: Tuple[int, int],
    gray_most_frequent_color: float,
    weight_most_frequent_color: float = 0.01,
) -> np.ndarray:
    """
    Reshape a patch-level signature into a batch of pixel-level signatures, one per patch.

    Args:
    - sub_sigs: A numpy array of shape (num_patches, patch_size[0] * patch_size[1] + 1). Column 0 is the
      patch weight; the remaining columns are the grayscale pixel values of the patch (the spatial columns
      of the full signature must already have been stripped).
    - patch_size: Tuple (patch_height, patch_width). Pixel values are assumed to be flattened row-major from
      (patch_height, patch_width) patches, as img_to_sig_patches produces.
    - gray_most_frequent_color: The most frequent (grayscale) color in the image.
      Pixels equal to it get the reduced weight ``weight_most_frequent_color``.
    - weight_most_frequent_color: The weight assigned to pixels of the most frequent color (others get 1).

    Returns:
    - A numpy array of shape (num_patches, patch_size[0] * patch_size[1], 4). For each pixel it holds
      [weight, gray value, x, y]. Weights are normalized to sum to 1 within each patch; x (column) and
      y (row) are the pixel's position inside the patch, normalized to [0, 1).
    """
    num_patches = sub_sigs.shape[0]
    patch_height, patch_width = patch_size
    flat_patch_size = patch_height * patch_width
    assert sub_sigs.shape[1] == flat_patch_size + 1, f"Expected {flat_patch_size + 1} columns, got {sub_sigs.shape[1]}."

    # Everything after the weight column is pixel data. Counting from column 1 keeps this correct
    # for 1x1 patches and for an empty batch.
    num_channels: int = (sub_sigs.shape[1] - 1) // flat_patch_size
    assert num_channels == 1, "Only grayscale images are supported for now."
    sub_sigs_reshaped = sub_sigs[:, 1:].reshape(num_patches, flat_patch_size, num_channels)

    # Pixel k sits at row k // patch_width and column k % patch_width, so x spans the width and y the height.
    # meshgrid's default 'xy' indexing yields arrays of shape (patch_height, patch_width), whose ravel order
    # matches that flattening (this also holds for non-square patches).
    x = np.arange(patch_width) / patch_width
    y = np.arange(patch_height) / patch_height
    x, y = np.meshgrid(x, y)
    spatial_info = np.stack((x.ravel(), y.ravel()), axis=1)  # Shape: (flat_patch_size, 2)

    # Repeat spatial_info for each patch
    spatial_info_repeated = np.repeat(
        spatial_info[np.newaxis, :, :], num_patches, axis=0
    )  # Shape: (num_patches, flat_patch_size, 2)

    # A pixel weighs 1 if its color differs from the most frequent color, weight_most_frequent_color otherwise.
    # The result is scaled by the patch weight and then normalized per patch.
    local_weights = weight_most_frequent_color + (1 - weight_most_frequent_color) * (
        sub_sigs_reshaped != gray_most_frequent_color
    ).astype(np.float32)
    global_weights = sub_sigs[:, 0:1]
    local_weights *= global_weights.reshape(-1, 1, 1)
    local_weights /= np.sum(local_weights, axis=1, keepdims=True)

    # Concatenate weights, values and spatial information: shape (num_patches, flat_patch_size, 4).
    return np.concatenate((local_weights, sub_sigs_reshaped, spatial_info_repeated), axis=2)
195
+
196
+
197
def compute_cost_matrix_on_sig(
    sig1: np.ndarray,
    sig2: np.ndarray,
    gray_most_frequent_color: float,
    patch_size: Tuple[int, int],
    dim: Tuple[int, int],
    weight_most_frequent_color: float = 0.01,
    use_tqdm: bool = True,
) -> np.ndarray:
    """
    Compute the patch-to-patch cost matrix for the EMD between two patch signatures.

    Each entry C[i, j] is the pixel-level EMD (cv2.EMD with L1 distance) between patch i of ``sig1`` and
    patch j of ``sig2``. To this is added the L1 distance between their normalized patch positions, scaled by
    sqrt(patch_height * patch_width) / (dim[0] + dim[1]).

    Args:
    - sig1: A numpy array of shape (num_patches, patch_size[0] * patch_size[1] + 3) representing the first
      signature: [weight | pixel values | 2 position columns].
    - sig2: A numpy array with the same shape as ``sig1`` representing the second signature.
    - gray_most_frequent_color: The most frequent (grayscale) color; pixels equal to it get a reduced weight.
    - patch_size: Tuple (patch_height, patch_width).
    - dim: The two image dimensions, used only through their sum to scale the positional cost.
    - weight_most_frequent_color: The weight assigned to pixels of the most frequent color.
    - use_tqdm: Boolean indicating whether to display a progress bar.

    Returns:
    - A float32 numpy array of shape (sig1.shape[0], sig2.shape[0]) representing the cost matrix.
    """
    assert sig1.shape == sig2.shape

    # Build the pixel-level signatures of every patch once, dropping the two position columns.
    sig1_reshaped = reshape_sub_sig_batch(
        sig1[:, :-2], patch_size, gray_most_frequent_color, weight_most_frequent_color
    ).astype(np.float32)
    sig2_reshaped = reshape_sub_sig_batch(
        sig2[:, :-2], patch_size, gray_most_frequent_color, weight_most_frequent_color
    ).astype(np.float32)

    cost_matrix = np.zeros((sig1.shape[0], sig2.shape[0]))
    multiplier: float = (patch_size[0] * patch_size[1]) ** 0.5 / (dim[0] + dim[1])
    for i in tqdm(range(sig1.shape[0]), disable=not use_tqdm):
        # Row-i data does not depend on j, so read it once per row.
        pos_sig1 = sig1[i, -2:]
        sub_sig1 = sig1_reshaped[i]
        for j in range(sig2.shape[0]):
            emd_value, _, _ = cv2.EMD(sub_sig1, sig2_reshaped[j], cv2.DIST_L1)
            cost_matrix[i, j] = emd_value + np.linalg.norm(pos_sig1 - sig2[j, -2:], 1) * multiplier  # Use L1
    return cost_matrix.astype(np.float32)
241
+
242
+
243
def compute_emd_recursive(
    img1_PIL: Image.Image,
    img2_PIL: Image.Image,
    threshold_most_frequent_color: float = 0.5,
    patch_size: Tuple[int, int] = (8, 8),
    max_num_patches: int = 100,
    weight_most_frequent_color: float = 0.001,
    use_tqdm: bool = False,
):
    """
    Compute the Earth Mover's Distance between two images using a two-level (patch, then pixel) approach.

    Both images are cut into patches. A cost matrix C is built such that C[i, j] is the cost of moving
    patch i of img1 to patch j of img2: the pixel-level EMD between the two patches plus a positional term.
    The final distance is the EMD between the two patch signatures under C.

    Moving a patch to another patch has a cost that is not proportional to the number of pixels,
    because it corresponds to moving an entire part of the image to another part.

    Args:
    - img1_PIL: A PIL Image representing the first image.
    - img2_PIL: A PIL Image of the same size representing the second image. This should be the reference
      if there is one, because it is used to determine the most frequent color.
    - threshold_most_frequent_color: If the most frequent color covers more than this fraction of img2,
      patches that are uniformly that color in both images are ignored.
    - patch_size: Tuple (patch_height, patch_width).
    - max_num_patches: The maximum number of patches to use for the EMD computation, to bound the
      computation time. The images are resized (to a multiple of the patch size) if necessary.
    - weight_most_frequent_color: The weight assigned to the most frequent color in the patches.
      Should be in (0, 1], usually low because the most frequent color carries little information.
    - use_tqdm: Boolean indicating whether to display a progress bar.

    Returns:
    - A float representing the Earth Mover's Distance between the images. It is 0.0 when every patch
      is ignored because both images are uniformly the most frequent color.
    """
    assert img1_PIL.size == img2_PIL.size
    assert patch_size[0] > 0 and patch_size[1] > 0
    assert 0 < threshold_most_frequent_color <= 1
    assert max_num_patches > 0
    assert 0 < weight_most_frequent_color <= 1

    # patch_size is (height, width), following the numpy (rows, cols) layout used by img_to_sig_patches,
    # whereas PIL's Image.size is (width, height). Keep the two conventions apart explicitly.
    patch_height, patch_width = patch_size
    width, height = img1_PIL.size

    # Resize the images so that there are not too many patches, keeping the aspect ratio as well as
    # possible and resizing to a multiple of the patch size.
    num_patches = math.ceil(width / patch_width) * math.ceil(height / patch_height)
    if num_patches > max_num_patches:
        ideal_divider = (num_patches / max_num_patches) ** 0.5
        # Cap the number of patch columns, otherwise very wide images would get zero patch rows (height 0).
        num_patches_width = min(math.ceil((width / patch_width) / ideal_divider), max_num_patches)
        # Choose the number of rows so that num_patches_width * num_patches_height <= max_num_patches.
        num_patches_height = max(1, max_num_patches // num_patches_width)
        new_size = (num_patches_width * patch_width, num_patches_height * patch_height)
        img1_PIL = img1_PIL.resize(new_size)
        img2_PIL = img2_PIL.resize(new_size)

    # Convert the images to numpy arrays
    img1_np = np.array(img1_PIL)
    img2_np = np.array(img2_PIL)

    # Get the patch-signature of the images, of shape (num_patches, patch_size[0] * patch_size[1] + 3).
    # Each row is a patch, and the columns are:
    # - index 0: weight of the patch
    # - index 1 - 1 + patch_size[0] * patch_size[1]: color values of the patch
    # - index -2, -1: position of the patch
    (rgb_most_frequent_color, frequency) = get_most_frequent_color(img2_np)
    gray_most_frequent_color = float(to_gray(rgb_most_frequent_color).squeeze() / 255.0)
    sig1 = img_to_sig_patches(img1_np, rgb_most_frequent_color, patch_size, weight_most_frequent_color)
    sig2 = img_to_sig_patches(img2_np, rgb_most_frequent_color, patch_size, weight_most_frequent_color)

    if frequency > threshold_most_frequent_color:
        # Ignore patches that are uniformly the most frequent color in both images.
        # NOTE(review): exact float equality against a separately computed gray value; confirm it matches.
        mask1 = np.any(sig1[:, 1:-2] != gray_most_frequent_color, axis=1)
        mask2 = np.any(sig2[:, 1:-2] != gray_most_frequent_color, axis=1)
        mask = np.logical_or(mask1, mask2)
        if not np.any(mask):
            # Every patch of both images is the most frequent color, so nothing has to move. Computing the EMD
            # on the resulting empty signatures would otherwise fail.
            return 0.0
        sig1 = sig1[mask]
        sig2 = sig2[mask]

    # Give both signatures the same normalized weights (the element-wise max of the two).
    weight1 = sig1[:, 0]
    weight2 = sig2[:, 0]
    weights = np.maximum(weight1, weight2)
    weights /= np.sum(weights)
    sig1[:, 0] = weights
    sig2[:, 0] = weights

    # Compute EMD
    cost = compute_cost_matrix_on_sig(
        sig1=sig1,
        sig2=sig2,
        gray_most_frequent_color=gray_most_frequent_color,
        patch_size=patch_size,
        dim=img1_PIL.size,
        weight_most_frequent_color=weight_most_frequent_color,
        use_tqdm=use_tqdm,
    )
    emd_value, _, _ = cv2.EMD(sig1, sig2, cv2.DIST_USER, cost)
    return emd_value