crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546)
  1. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +37 -2
  5. helm/benchmark/adaptation/adapters/adapter.py +4 -42
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
  9. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
  10. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
  11. helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
  12. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  13. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  14. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
  15. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
  16. helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
  17. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  18. helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
  19. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
  20. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
  21. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  22. helm/benchmark/adaptation/prompt.py +7 -1
  23. helm/benchmark/adaptation/request_state.py +6 -1
  24. helm/benchmark/adaptation/scenario_state.py +6 -2
  25. helm/benchmark/annotation/annotator.py +43 -0
  26. helm/benchmark/annotation/annotator_factory.py +61 -0
  27. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  28. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  29. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  30. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  31. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  32. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  33. helm/benchmark/annotation_executor.py +124 -0
  34. helm/benchmark/augmentations/cleva_perturbation.py +7 -14
  35. helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
  36. helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
  37. helm/benchmark/augmentations/data_augmenter.py +0 -2
  38. helm/benchmark/augmentations/dialect_perturbation.py +2 -2
  39. helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
  40. helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
  41. helm/benchmark/augmentations/gender_perturbation.py +3 -3
  42. helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
  43. helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
  44. helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
  45. helm/benchmark/augmentations/person_name_perturbation.py +0 -7
  46. helm/benchmark/augmentations/perturbation.py +20 -7
  47. helm/benchmark/augmentations/perturbation_description.py +1 -1
  48. helm/benchmark/augmentations/space_perturbation.py +2 -2
  49. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  50. helm/benchmark/augmentations/synonym_perturbation.py +2 -2
  51. helm/benchmark/augmentations/test_perturbation.py +11 -7
  52. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  53. helm/benchmark/augmentations/typos_perturbation.py +2 -2
  54. helm/benchmark/config_registry.py +38 -0
  55. helm/benchmark/executor.py +46 -16
  56. helm/benchmark/huggingface_registration.py +37 -7
  57. helm/benchmark/metrics/basic_metrics.py +172 -641
  58. helm/benchmark/metrics/bbq_metrics.py +3 -4
  59. helm/benchmark/metrics/bias_metrics.py +6 -6
  60. helm/benchmark/metrics/classification_metrics.py +11 -8
  61. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  62. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  63. helm/benchmark/metrics/code_metrics.py +4 -3
  64. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  65. helm/benchmark/metrics/common_metric_specs.py +167 -0
  66. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  67. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  68. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  69. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  70. helm/benchmark/metrics/disinformation_metrics.py +6 -112
  71. helm/benchmark/metrics/dry_run_metrics.py +5 -3
  72. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  73. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  74. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  75. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  76. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  77. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  78. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  79. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  80. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  81. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  82. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  83. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  84. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  85. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  86. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  87. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  88. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  89. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  90. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  91. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  92. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  93. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  94. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  95. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  96. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  97. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  98. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  99. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  100. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  101. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  102. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  103. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  104. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  105. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  106. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  107. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  108. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  109. helm/benchmark/metrics/machine_translation_metrics.py +5 -5
  110. helm/benchmark/metrics/metric.py +93 -172
  111. helm/benchmark/metrics/metric_name.py +0 -1
  112. helm/benchmark/metrics/metric_service.py +16 -0
  113. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  114. helm/benchmark/metrics/ranking_metrics.py +6 -7
  115. helm/benchmark/metrics/reference_metric.py +148 -0
  116. helm/benchmark/metrics/summac/model_summac.py +0 -2
  117. helm/benchmark/metrics/summarization_metrics.py +8 -8
  118. helm/benchmark/metrics/test_classification_metrics.py +9 -6
  119. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  120. helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
  121. helm/benchmark/metrics/test_metric.py +2 -2
  122. helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
  123. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
  124. helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
  125. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
  126. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
  127. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  128. helm/benchmark/metrics/toxicity_utils.py +23 -0
  129. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  130. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  131. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  132. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  133. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  134. helm/benchmark/model_deployment_registry.py +164 -41
  135. helm/benchmark/model_metadata_registry.py +181 -35
  136. helm/benchmark/multi_gpu_runner.py +133 -0
  137. helm/benchmark/presentation/contamination.py +3 -3
  138. helm/benchmark/presentation/create_plots.py +8 -7
  139. helm/benchmark/presentation/run_display.py +50 -17
  140. helm/benchmark/presentation/schema.py +28 -46
  141. helm/benchmark/presentation/summarize.py +213 -96
  142. helm/benchmark/presentation/table.py +8 -8
  143. helm/benchmark/presentation/test_contamination.py +2 -2
  144. helm/benchmark/presentation/test_run_entry.py +14 -9
  145. helm/benchmark/presentation/test_summarize.py +5 -0
  146. helm/benchmark/run.py +66 -54
  147. helm/benchmark/run_expander.py +342 -31
  148. helm/benchmark/run_spec.py +93 -0
  149. helm/benchmark/run_spec_factory.py +162 -0
  150. helm/benchmark/run_specs/__init__.py +0 -0
  151. helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
  152. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  153. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  154. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  155. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  156. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  157. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  158. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  159. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  160. helm/benchmark/runner.py +116 -69
  161. helm/benchmark/runner_config_registry.py +21 -0
  162. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  163. helm/benchmark/scenarios/bold_scenario.py +2 -2
  164. helm/benchmark/scenarios/cleva_scenario.py +43 -46
  165. helm/benchmark/scenarios/code_scenario.py +3 -2
  166. helm/benchmark/scenarios/commonsense_scenario.py +171 -191
  167. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  168. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  169. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  170. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  171. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  172. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  173. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  174. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  175. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  176. helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
  177. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  178. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  179. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  180. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  181. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  182. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  183. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  184. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  185. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  186. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  187. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  188. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  189. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  190. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  191. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  192. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  193. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  194. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  195. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  196. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  197. helm/benchmark/scenarios/legalbench_scenario.py +123 -0
  198. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  199. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  200. helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
  201. helm/benchmark/scenarios/math_scenario.py +19 -2
  202. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  203. helm/benchmark/scenarios/numeracy_scenario.py +3 -3
  204. helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
  205. helm/benchmark/scenarios/raft_scenario.py +2 -6
  206. helm/benchmark/scenarios/scenario.py +14 -2
  207. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  208. helm/benchmark/scenarios/test_math_scenario.py +22 -0
  209. helm/benchmark/scenarios/test_scenario.py +6 -3
  210. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  211. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  212. helm/benchmark/scenarios/the_pile_scenario.py +6 -7
  213. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  214. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  215. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  216. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  217. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  218. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  219. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  220. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  221. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  222. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  223. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  224. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  225. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  226. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  227. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  228. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  229. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  230. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  231. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  232. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  233. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  234. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  235. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  236. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  237. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
  238. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  239. helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
  240. helm/benchmark/server.py +59 -2
  241. helm/benchmark/slurm_jobs.py +12 -0
  242. helm/benchmark/slurm_runner.py +79 -51
  243. helm/benchmark/static/benchmarking.js +3 -4
  244. helm/benchmark/static/contamination.yaml +1 -1
  245. helm/benchmark/static/images/organizations/together.png +0 -0
  246. helm/benchmark/static/json-urls.js +4 -0
  247. helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
  248. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  249. helm/benchmark/static/schema_lite.yaml +824 -0
  250. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  251. helm/benchmark/static/schema_unitxt.yaml +428 -0
  252. helm/benchmark/static/schema_vlm.yaml +576 -0
  253. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  254. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  255. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  256. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  257. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  258. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  259. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  260. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  261. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  262. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  263. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  264. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  265. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  266. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  267. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  268. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  269. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  270. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  271. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  272. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  273. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  274. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  275. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  276. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  277. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  278. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  279. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  280. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  281. helm/benchmark/static_build/config.js +4 -0
  282. helm/benchmark/static_build/index.html +20 -0
  283. helm/benchmark/test_data_preprocessor.py +3 -3
  284. helm/benchmark/test_model_deployment_definition.py +90 -0
  285. helm/benchmark/test_run_expander.py +1 -1
  286. helm/benchmark/tokenizer_config_registry.py +10 -14
  287. helm/benchmark/window_services/ai21_window_service.py +22 -33
  288. helm/benchmark/window_services/cohere_window_service.py +1 -63
  289. helm/benchmark/window_services/default_window_service.py +2 -35
  290. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  291. helm/benchmark/window_services/ice_window_service.py +0 -34
  292. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  293. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  294. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  295. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  296. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  297. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  298. helm/benchmark/window_services/local_window_service.py +21 -4
  299. helm/benchmark/window_services/no_decoding_window_service.py +32 -0
  300. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  301. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  302. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  303. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  304. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  305. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  306. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  307. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  308. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  309. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  310. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  311. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  312. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  313. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  314. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  315. helm/benchmark/window_services/test_utils.py +3 -2
  316. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  317. helm/benchmark/window_services/window_service.py +42 -0
  318. helm/benchmark/window_services/window_service_factory.py +24 -269
  319. helm/benchmark/window_services/yalm_window_service.py +0 -27
  320. helm/clients/__init__.py +0 -0
  321. helm/{proxy/clients → clients}/ai21_client.py +5 -12
  322. helm/clients/aleph_alpha_client.py +112 -0
  323. helm/{proxy/clients → clients}/anthropic_client.py +213 -24
  324. helm/clients/auto_client.py +215 -0
  325. helm/clients/bedrock_client.py +128 -0
  326. helm/clients/bedrock_utils.py +72 -0
  327. helm/{proxy/clients → clients}/client.py +67 -55
  328. helm/clients/clip_score_client.py +49 -0
  329. helm/clients/clip_scorers/__init__.py +0 -0
  330. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  331. helm/clients/clip_scorers/clip_scorer.py +50 -0
  332. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  333. helm/{proxy/clients → clients}/cohere_client.py +6 -17
  334. helm/clients/gcs_client.py +82 -0
  335. helm/{proxy/clients → clients}/google_client.py +7 -8
  336. helm/clients/google_translate_client.py +35 -0
  337. helm/{proxy/clients → clients}/http_model_client.py +6 -10
  338. helm/{proxy/clients → clients}/huggingface_client.py +134 -92
  339. helm/clients/image_generation/__init__.py +0 -0
  340. helm/clients/image_generation/adobe_vision_client.py +78 -0
  341. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  342. helm/clients/image_generation/cogview2/__init__.py +0 -0
  343. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  344. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  345. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  346. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  347. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  348. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  349. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  350. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  351. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  352. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  353. helm/clients/image_generation/cogview2_client.py +191 -0
  354. helm/clients/image_generation/dalle2_client.py +192 -0
  355. helm/clients/image_generation/dalle3_client.py +108 -0
  356. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  357. helm/clients/image_generation/dalle_mini/data.py +442 -0
  358. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  359. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  360. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  361. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  362. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  363. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  364. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  365. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  366. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  367. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  368. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  369. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  370. helm/clients/image_generation/dalle_mini_client.py +190 -0
  371. helm/clients/image_generation/deep_floyd_client.py +78 -0
  372. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  373. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  374. helm/clients/image_generation/lexica_client.py +86 -0
  375. helm/clients/image_generation/mindalle/__init__.py +0 -0
  376. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  377. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  378. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  379. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  380. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  381. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  382. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  383. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  384. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  385. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  386. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  387. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  388. helm/clients/image_generation/mindalle_client.py +115 -0
  389. helm/clients/image_generation/nudity_check_client.py +64 -0
  390. helm/clients/image_generation/together_image_generation_client.py +111 -0
  391. helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
  392. helm/{proxy/clients → clients}/megatron_client.py +13 -7
  393. helm/clients/mistral_client.py +134 -0
  394. helm/clients/moderation_api_client.py +109 -0
  395. helm/clients/open_lm_client.py +43 -0
  396. helm/clients/openai_client.py +302 -0
  397. helm/{proxy/clients → clients}/palmyra_client.py +15 -12
  398. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  399. helm/clients/simple_client.py +64 -0
  400. helm/{proxy/clients → clients}/test_auto_client.py +15 -15
  401. helm/clients/test_client.py +100 -0
  402. helm/clients/test_huggingface_client.py +70 -0
  403. helm/clients/test_simple_client.py +19 -0
  404. helm/{proxy/clients → clients}/test_together_client.py +23 -12
  405. helm/{proxy/clients → clients}/together_client.py +18 -71
  406. helm/clients/vertexai_client.py +391 -0
  407. helm/clients/vision_language/__init__.py +0 -0
  408. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  409. helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
  410. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  411. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  412. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  413. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  414. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  415. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  416. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  417. helm/clients/vision_language/open_flamingo_client.py +155 -0
  418. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  419. helm/clients/vllm_client.py +46 -0
  420. helm/common/cache.py +24 -179
  421. helm/common/cache_backend_config.py +47 -0
  422. helm/common/clip_score_request.py +41 -0
  423. helm/common/concurrency.py +32 -0
  424. helm/common/credentials_utils.py +28 -0
  425. helm/common/file_caches/__init__.py +0 -0
  426. helm/common/file_caches/file_cache.py +16 -0
  427. helm/common/file_caches/local_file_cache.py +61 -0
  428. helm/common/file_caches/test_local_file_cache.py +25 -0
  429. helm/common/file_upload_request.py +27 -0
  430. helm/common/general.py +29 -10
  431. helm/common/image_generation_parameters.py +25 -0
  432. helm/common/images_utils.py +24 -1
  433. helm/common/key_value_store.py +113 -0
  434. helm/common/media_object.py +13 -0
  435. helm/common/moderations_api_request.py +71 -0
  436. helm/common/mongo_key_value_store.py +88 -0
  437. helm/common/multimodal_request_utils.py +31 -0
  438. helm/common/nudity_check_request.py +29 -0
  439. helm/common/object_spec.py +2 -2
  440. helm/common/request.py +36 -27
  441. helm/common/test_general.py +6 -0
  442. helm/common/tokenization_request.py +6 -3
  443. helm/config/__init__.py +0 -0
  444. helm/config/model_deployments.yaml +1942 -0
  445. helm/config/model_metadata.yaml +2201 -0
  446. helm/config/tokenizer_configs.yaml +362 -0
  447. helm/proxy/accounts.py +31 -4
  448. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  449. helm/proxy/critique/model_critique_client.py +13 -5
  450. helm/proxy/example_queries.py +29 -17
  451. helm/proxy/retry.py +8 -2
  452. helm/proxy/server.py +77 -5
  453. helm/proxy/services/remote_service.py +31 -0
  454. helm/proxy/services/server_service.py +103 -20
  455. helm/proxy/services/service.py +34 -2
  456. helm/proxy/services/test_remote_service.py +7 -6
  457. helm/proxy/services/test_service.py +27 -18
  458. helm/proxy/test_accounts.py +32 -0
  459. helm/proxy/token_counters/auto_token_counter.py +37 -37
  460. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  461. helm/proxy/token_counters/token_counter.py +3 -5
  462. helm/py.typed +0 -0
  463. helm/tokenizers/__init__.py +0 -0
  464. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  465. helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
  466. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
  467. helm/tokenizers/auto_tokenizer.py +93 -0
  468. helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
  469. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  470. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  471. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
  472. helm/tokenizers/simple_tokenizer.py +33 -0
  473. helm/tokenizers/test_anthropic_tokenizer.py +82 -0
  474. helm/tokenizers/test_huggingface_tokenizer.py +136 -0
  475. helm/tokenizers/test_simple_tokenizer.py +33 -0
  476. helm/tokenizers/vertexai_tokenizer.py +97 -0
  477. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  478. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  479. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  480. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  481. crfm_helm-0.3.0.dist-info/RECORD +0 -396
  482. helm/benchmark/vlm_run_specs.py +0 -71
  483. helm/benchmark/window_services/anthropic_window_service.py +0 -68
  484. helm/benchmark/window_services/bloom_window_service.py +0 -35
  485. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  486. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  487. helm/benchmark/window_services/gptj_window_service.py +0 -38
  488. helm/benchmark/window_services/gptneox_window_service.py +0 -41
  489. helm/benchmark/window_services/http_model_window_service.py +0 -28
  490. helm/benchmark/window_services/huggingface_window_service.py +0 -59
  491. helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
  492. helm/benchmark/window_services/llama_window_service.py +0 -28
  493. helm/benchmark/window_services/luminous_window_service.py +0 -67
  494. helm/benchmark/window_services/megatron_window_service.py +0 -10
  495. helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
  496. helm/benchmark/window_services/openai_window_service.py +0 -13
  497. helm/benchmark/window_services/opt_window_service.py +0 -35
  498. helm/benchmark/window_services/palmyra_window_service.py +0 -45
  499. helm/benchmark/window_services/remote_window_service.py +0 -48
  500. helm/benchmark/window_services/santacoder_window_service.py +0 -27
  501. helm/benchmark/window_services/starcoder_window_service.py +0 -27
  502. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  503. helm/benchmark/window_services/t511b_window_service.py +0 -30
  504. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  505. helm/benchmark/window_services/ul2_window_service.py +0 -30
  506. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  507. helm/benchmark/window_services/wider_openai_window_service.py +0 -52
  508. helm/proxy/clients/aleph_alpha_client.py +0 -99
  509. helm/proxy/clients/auto_client.py +0 -461
  510. helm/proxy/clients/goose_ai_client.py +0 -100
  511. helm/proxy/clients/microsoft_client.py +0 -182
  512. helm/proxy/clients/openai_client.py +0 -206
  513. helm/proxy/clients/remote_model_registry.py +0 -28
  514. helm/proxy/clients/simple_client.py +0 -61
  515. helm/proxy/clients/test_anthropic_client.py +0 -63
  516. helm/proxy/clients/test_client.py +0 -31
  517. helm/proxy/clients/test_huggingface_client.py +0 -87
  518. helm/proxy/models.py +0 -963
  519. helm/proxy/test_models.py +0 -27
  520. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  521. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  522. helm/proxy/token_counters/free_token_counter.py +0 -12
  523. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  524. helm/proxy/token_counters/openai_token_counter.py +0 -22
  525. helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
  526. helm/proxy/token_counters/test_openai_token_counter.py +0 -79
  527. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  528. helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
  529. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  530. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  531. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  532. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  533. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  534. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  535. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  536. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  537. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  538. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  539. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  540. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  541. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  542. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  543. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  544. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  545. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  546. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
@@ -0,0 +1,190 @@
1
+ from typing import Any, Dict, List
2
+
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ from helm.common.cache import CacheConfig, Cache
7
+ from helm.common.file_caches.file_cache import FileCache
8
+ from helm.common.hierarchical_logger import hlog, htrack_block
9
+ from helm.common.optional_dependencies import handle_module_not_found_error
10
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
11
+ from helm.common.tokenization_request import (
12
+ DecodeRequest,
13
+ DecodeRequestResult,
14
+ TokenizationRequest,
15
+ TokenizationRequestResult,
16
+ )
17
+ from helm.clients.client import Client, CachingClient
18
+ from .image_generation_client_utils import get_single_image_multimedia_object
19
+
20
+
21
class DALLEMiniClient(Client):
    """
    Client for DALL-E mini / DALL-E mega text-to-image generation.

    Source: https://github.com/borisdayma/dalle-mini, https://github.com/patil-suraj/vqgan-jax
    """

    # VQGAN checkpoint used to decode the generated image token sequences into pixels.
    VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
    VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

    def __init__(self, cache_config: CacheConfig, file_cache: FileCache):
        self._cache = Cache(cache_config)
        self._file_cache: FileCache = file_cache

        # Maps a model engine name (e.g., "dalle-mini") to its loaded
        # [model, params, processor, vqgan, vqgan_params] bundle.
        self._model_engine_to_model: Dict[str, List[Any]] = {}

    def _get_model(self, model_engine: str) -> List[Any]:
        """
        Initialize the model based on the model name.
        Cache the model, so it doesn't get reinitialized for a new request.

        Returns the [model, params, processor, vqgan, vqgan_params] bundle
        for `model_engine`. Raises ValueError for an unknown engine.
        """
        try:
            import jax.numpy as jnp
            from flax.jax_utils import replicate

            from helm.clients.image_generation.dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
            from helm.clients.image_generation.dalle_mini import DalleBart, DalleBartProcessor
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        if model_engine not in self._model_engine_to_model:
            model_name: str
            if model_engine == "dalle-mini":
                model_name = "dalle-mini/dalle-mini/mini-1:v0"
            elif model_engine == "dalle-mega":
                model_name = "dalle-mini/dalle-mini/mega-1-fp16:latest"
            else:
                raise ValueError(f"Unhandled model: {model_engine}")

            # `_do_init=False` defers parameter initialization: the params come back
            # separately and are replicated across local devices for pmap inference.
            model, params = DalleBart.from_pretrained(model_name, revision=None, dtype=jnp.float16, _do_init=False)
            processor = DalleBartProcessor.from_pretrained(model_name, revision=None)
            vqgan, vqgan_params = VQModel.from_pretrained(
                self.VQGAN_REPO, revision=self.VQGAN_COMMIT_ID, _do_init=False
            )
            params = replicate(params)
            vqgan_params = replicate(vqgan_params)
            self._model_engine_to_model[model_engine] = [model, params, processor, vqgan, vqgan_params]
        return self._model_engine_to_model[model_engine]

    def make_request(self, request: Request) -> RequestResult:
        """
        Generate `request.num_completions` images for `request.prompt`, writing each
        image to the file cache and returning completions that reference the files.
        Results are cached; generation errors surface as a failed RequestResult.
        """
        try:
            import jax
            from flax.training.common_utils import shard_prng_key
            from flax.jax_utils import replicate
            from PIL import Image
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        # Sampling hyperparameters; None means the model's generate() defaults apply.
        raw_request = {
            "prompt": request.prompt,
            "top_k": None,
            "top_p": None,
            "temperature": None,
            "condition_scale": 10.0,
        }

        try:

            def _inference(
                model, params, vqgan, vqgan_params, tokenized_prompt, subkey, top_k, top_p, temperature, condition_scale
            ):
                """Run one pmap'd generate + VQGAN decode pass; returns images in [0, 1]."""

                # Sampling hyperparameters are static (broadcast, not sharded) arguments.
                @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
                def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
                    return model.generate(
                        **tokenized_prompt,
                        prng_key=key,
                        params=params,
                        top_k=top_k,
                        top_p=top_p,
                        temperature=temperature,
                        condition_scale=condition_scale,
                    )

                @partial(jax.pmap, axis_name="batch")
                def p_decode(indices, params):
                    return vqgan.decode_code(indices, params=params)

                # generate images
                encoded_images = p_generate(
                    tokenized_prompt,
                    shard_prng_key(subkey),
                    params,
                    top_k,
                    top_p,
                    temperature,
                    condition_scale,
                )
                # remove BOS
                encoded_images = encoded_images.sequences[..., 1:]
                # decode images
                decoded_images = p_decode(encoded_images, vqgan_params)
                decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
                return decoded_images

            def do_it() -> Dict[str, Any]:
                prompt: str = request.prompt

                with htrack_block(f"Generating images for prompt: {prompt}"):
                    model, params, processor, vqgan, vqgan_params = self._get_model(request.model_engine)
                    tokenized_prompts = processor([prompt])
                    tokenized_prompt = replicate(tokenized_prompts)

                    # Fixed: annotate with the image class `Image.Image`, not the
                    # PIL `Image` module (a module is not a valid type parameter).
                    images: List[Image.Image] = []
                    # Fixed seed + deterministic key splitting keeps generation reproducible.
                    key = jax.random.PRNGKey(0)
                    for _ in range(request.num_completions):
                        key, subkey = jax.random.split(key)
                        image = _inference(
                            model,
                            params,
                            vqgan,
                            vqgan_params,
                            tokenized_prompt,
                            subkey,
                            raw_request["top_k"],
                            raw_request["top_p"],
                            raw_request["temperature"],
                            raw_request["condition_scale"],
                        )[0]
                        # Convert float [0, 1] array to an 8-bit RGB PIL image.
                        image = Image.fromarray(np.asarray(image * 255, dtype=np.uint8))
                        images.append(image)

                    assert (
                        len(images) == request.num_completions
                    ), f"Expected {request.num_completions} images, but got {len(images)}"

                    result = {"file_locations": []}
                    for image in images:
                        # Write out the image to a file and save the path
                        # NOTE(review): the diffusers client calls
                        # `generate_unique_new_file_path()` here — confirm which
                        # FileCache method name is current.
                        file_location: str = self._file_cache.get_unique_file_location()
                        image.save(file_location)
                        hlog(f"Image saved at {file_location}.")
                        result["file_locations"].append(file_location)
                    return result

            # Include the model name and number of completions in the cache key
            cache_key = CachingClient.make_cache_key(
                {"model": request.model_engine, "n": request.num_completions, **raw_request}, request
            )
            results, cached = self._cache.get(cache_key, wrap_request_time(do_it))
        except RuntimeError as e:
            error: str = f"DALLEMiniClient error: {e}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

        completions: List[GeneratedOutput] = [
            GeneratedOutput(
                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_location)
            )
            for file_location in results["file_locations"]
        ]
        return RequestResult(
            success=True,
            cached=cached,
            request_time=results["request_time"],
            completions=completions,
            embedding=[],
        )

    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
        raise NotImplementedError("This client does not support tokenizing.")

    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
        raise NotImplementedError("This client does not support decoding.")
@@ -0,0 +1,78 @@
1
+ from typing import List, Dict
2
+
3
+ from helm.common.cache import Cache, CacheConfig
4
+ from helm.common.request import Request, RequestResult, GeneratedOutput
5
+ from helm.common.tokenization_request import (
6
+ TokenizationRequest,
7
+ TokenizationRequestResult,
8
+ DecodeRequest,
9
+ DecodeRequestResult,
10
+ )
11
+ from helm.clients.client import Client, CachingClient
12
+ from .image_generation_client_utils import get_single_image_multimedia_object
13
+
14
+
15
class DeepFloydClient(Client):
    """
    Client for [DeepFloyd image generation models](https://huggingface.co/docs/diffusers/v0.16.0/api/pipelines/ifs).
    We rely on offline eval for now due to conflicting dependencies (e.g., Transformers).

    Because inference happens offline, `make_request` only reads results that were
    previously uploaded to the cache; it never runs the model itself.
    """

    SUPPORTED_MODELS: List[str] = ["IF-I-M-v1.0", "IF-I-L-v1.0", "IF-I-XL-v1.0"]

    @staticmethod
    def convert_to_raw_request(request: Request) -> Dict:
        """Build the canonical raw-request dict used as the cache lookup key."""
        # Use default hyperparameters for everything else
        raw_request: Dict = {
            "model": request.model_engine,
            "n": request.num_completions,
            "prompt": request.prompt,
            "request_type": "image-model-inference",
        }
        if request.random is not None:
            raw_request["random"] = request.random
        return raw_request

    def __init__(self, cache_config: CacheConfig):
        self._cache = Cache(cache_config)
        # Fixed: removed unused `_promptist_model` / `_promptist_tokenizer`
        # attributes that were copied from the diffusers client but never read here.

    def make_request(self, request: Request) -> RequestResult:
        """
        Look up the offline-generated images for `request` in the cache.

        Raises ValueError for an unsupported model. Returns a failed
        RequestResult when no result has been uploaded for this request.
        """
        if request.model_engine not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {request.model_engine}")

        raw_request = DeepFloydClient.convert_to_raw_request(request)
        cache_key = CachingClient.make_cache_key(raw_request, request)

        try:

            def fail():
                # Offline eval: a cache miss means the result was never uploaded,
                # so there is no fallback computation — surface it as an error.
                raise RuntimeError(
                    f"The result has not been uploaded to the cache for the following request: {cache_key}"
                )

            response, cached = self._cache.get(cache_key, fail)
        except RuntimeError as e:
            error: str = f"DeepFloyd Client error: {e}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

        completions: List[GeneratedOutput] = [
            GeneratedOutput(
                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_path)
            )
            for file_path in response["images"]
        ]
        return RequestResult(
            success=True,
            cached=cached,
            request_time=response["total_inference_time"],
            completions=completions,
            embedding=[],
        )

    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
        raise NotImplementedError("This client does not support tokenizing.")

    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
        raise NotImplementedError("This client does not support decoding.")
@@ -0,0 +1,249 @@
1
+ from threading import Lock
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import torch
6
+
7
+ from helm.common.cache import CacheConfig, Cache
8
+ from helm.common.file_caches.file_cache import FileCache
9
+ from helm.common.gpu_utils import get_torch_device_name, is_cuda_available
10
+ from helm.common.hierarchical_logger import hlog, htrack_block
11
+ from helm.common.optional_dependencies import handle_module_not_found_error
12
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
13
+ from helm.common.tokenization_request import (
14
+ DecodeRequest,
15
+ DecodeRequestResult,
16
+ TokenizationRequest,
17
+ TokenizationRequestResult,
18
+ )
19
+ from helm.clients.client import Client, CachingClient
20
+ from .image_generation_client_utils import get_single_image_multimedia_object
21
+
22
+
23
# Process-wide cache of loaded diffusion pipelines, keyed by model engine name.
# Guarded by a lock because clients may be used from multiple threads.
_models_lock: Lock = Lock()
_models: Dict[str, Any] = {}


class HuggingFaceDiffusersClient(Client):
    """
    Client that runs Hugging Face `diffusers` text-to-image pipelines locally.

    Generated images are written to a file cache and returned as completions
    whose multimodal content references the saved files.
    """

    def __init__(self, hf_auth_token: str, cache_config: CacheConfig, file_cache: FileCache):
        self._hf_auth_token: str = hf_auth_token
        self._cache = Cache(cache_config)
        self._file_cache: FileCache = file_cache

        # Lazily loaded by `_generate_promptist_prompt` (only used for the
        # "promptist-stable-diffusion-v1-4" engine).
        self._promptist_model = None
        self._promptist_tokenizer = None

    def _get_diffuser(self, request: Request):
        """
        Initialize the Diffusion Pipeline based on the model name.
        Cache the model, so it doesn't get reinitialized for a new request.
        """
        try:
            from diffusers import DiffusionPipeline
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        global _models_lock
        global _models

        # Hold the lock for the whole check-and-load so two threads never
        # load the same (large) pipeline concurrently.
        with _models_lock:
            model_engine: str = request.model_engine

            if model_engine not in _models:
                # Map HELM engine names to Hugging Face Hub model IDs.
                huggingface_model_name: str
                if model_engine in ["stable-diffusion-v1-4", "promptist-stable-diffusion-v1-4"]:
                    # Promptist rewrites prompts but still renders with SD v1.4.
                    huggingface_model_name = "CompVis/stable-diffusion-v1-4"
                elif model_engine == "stable-diffusion-v1-5":
                    huggingface_model_name = "runwayml/stable-diffusion-v1-5"
                elif model_engine == "stable-diffusion-v2-base":
                    huggingface_model_name = "stabilityai/stable-diffusion-2-base"
                elif model_engine == "stable-diffusion-v2-1-base":
                    huggingface_model_name = "stabilityai/stable-diffusion-2-1-base"
                elif model_engine == "dreamlike-diffusion-v1-0":
                    huggingface_model_name = "dreamlike-art/dreamlike-diffusion-1.0"
                elif model_engine == "dreamlike-photoreal-v2-0":
                    huggingface_model_name = "dreamlike-art/dreamlike-photoreal-2.0"
                elif model_engine == "openjourney-v1-0":
                    huggingface_model_name = "prompthero/openjourney"
                elif model_engine == "openjourney-v2-0":
                    huggingface_model_name = "prompthero/openjourney-v2"
                elif model_engine == "redshift-diffusion":
                    huggingface_model_name = "nitrosocke/redshift-diffusion"
                elif "stable-diffusion-safe" in model_engine:
                    # All four safety strengths (weak/medium/strong/max) share one checkpoint;
                    # the strength is applied via SafetyConfig in `make_request`.
                    huggingface_model_name = "AIML-TUDA/stable-diffusion-safe"
                elif model_engine == "vintedois-diffusion-v0-1":
                    huggingface_model_name = "22h/vintedois-diffusion-v0-1"
                elif model_engine == "SSD-1B":
                    huggingface_model_name = "segmind/SSD-1B"
                else:
                    # Fall back to treating the full model name as a Hub ID.
                    huggingface_model_name = request.model

                pipeline = DiffusionPipeline.from_pretrained(
                    huggingface_model_name,
                    # fp16 only makes sense on GPU; use full precision on CPU.
                    torch_dtype=torch.float16 if is_cuda_available() else torch.float,
                    use_auth_token=self._hf_auth_token,
                )
                _models[model_engine] = pipeline.to(get_torch_device_name())
            return _models[model_engine]

    def make_request(self, request: Request) -> RequestResult:
        """
        Generate `request.num_completions` images for `request.prompt`.

        Each image is rendered one at a time, written to the file cache, and
        returned as a completion referencing the saved file. Results are cached;
        generation errors surface as a failed RequestResult.
        """
        try:
            from diffusers import DiffusionPipeline
            from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        raw_request = {
            "prompt": request.prompt,
            # Setting this to a higher value can cause CUDA OOM
            # Fix it to 1 and generate an image `request.num_completions` times
            "num_images_per_prompt": 1,
        }

        # Translate HELM's image-generation parameters to diffusers kwargs,
        # only when explicitly set (otherwise the pipeline defaults apply).
        assert request.image_generation_parameters is not None
        if request.image_generation_parameters.guidance_scale is not None:
            raw_request["guidance_scale"] = request.image_generation_parameters.guidance_scale
        if request.image_generation_parameters.diffusion_denoising_steps is not None:
            raw_request["num_inference_steps"] = request.image_generation_parameters.diffusion_denoising_steps
        if request.image_generation_parameters.output_image_width is not None:
            raw_request["width"] = request.image_generation_parameters.output_image_width
        if request.image_generation_parameters.output_image_height is not None:
            raw_request["height"] = request.image_generation_parameters.output_image_height

        # Add the additional pre-configured parameters for Safe Stable Diffusion
        if request.model_engine == "stable-diffusion-safe-weak":
            raw_request = {**raw_request, **SafetyConfig.WEAK}
        elif request.model_engine == "stable-diffusion-safe-medium":
            raw_request = {**raw_request, **SafetyConfig.MEDIUM}
        elif request.model_engine == "stable-diffusion-safe-strong":
            raw_request = {**raw_request, **SafetyConfig.STRONG}
        elif request.model_engine == "stable-diffusion-safe-max":
            raw_request = {**raw_request, **SafetyConfig.MAX}

        try:

            def replace_prompt(request_to_update: Dict, new_prompt: str) -> Dict:
                """Return a copy of the raw request with the prompt swapped out."""
                new_request: Dict = dict(request_to_update)
                assert "prompt" in new_request
                new_request["prompt"] = new_prompt
                return new_request

            def do_it() -> Dict[str, Any]:
                prompt: str = request.prompt

                with htrack_block(f"Generating images for prompt: {prompt}"):
                    diffuser: DiffusionPipeline = self._get_diffuser(request)
                    promptist_prompt: Optional[str] = None

                    images = []
                    for _ in range(request.num_completions):
                        if request.model_engine == "promptist-stable-diffusion-v1-4":
                            # NOTE: the rewritten prompt is recomputed every iteration;
                            # only the last one is stored in the result below.
                            promptist_prompt = self._generate_promptist_prompt(prompt)
                            hlog(f"Promptist: {prompt} -> {promptist_prompt}")
                            image = diffuser(**replace_prompt(raw_request, promptist_prompt)).images[0]  # type: ignore
                        elif request.model_engine == "openjourney-v1-0":
                            # It is required to include "mdjrny-v4 style" in prompt for Openjourney v1
                            image = diffuser(
                                **replace_prompt(raw_request, f"mdjrny-v4 style {prompt}")  # type: ignore
                            ).images[0]
                        elif request.model_engine == "redshift-diffusion":
                            # It is required to include "redshift style" to generate 3D images
                            image = diffuser(
                                **replace_prompt(raw_request, f"redshift style {prompt}")  # type: ignore
                            ).images[0]
                        else:
                            image = diffuser(**raw_request).images[0]  # type: ignore
                        images.append(image)

                    assert (
                        len(images) == request.num_completions
                    ), f"Expected {request.num_completions} images, but got {len(images)}"

                    result: Dict = {"file_locations": []}
                    if promptist_prompt is not None:
                        # Save the Promptist version of the prompts in the cache, just in case we need it later
                        result["promptist_prompt"] = promptist_prompt

                    for image in images:
                        # Write out the image to a file and save the path
                        file_location: str = self._file_cache.generate_unique_new_file_path()  # type: ignore
                        image.save(file_location)
                        hlog(f"Image saved at {file_location}")
                        result["file_locations"].append(file_location)
                    return result

            # Include the model name and number of completions in the cache key
            cache_key = CachingClient.make_cache_key(
                {"model": request.model_engine, "n": request.num_completions, **raw_request}, request
            )
            results, cached = self._cache.get(cache_key, wrap_request_time(do_it))
        except RuntimeError as ex:
            error: str = f"HuggingFaceDiffusersClient error: {ex}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

        completions: List[GeneratedOutput] = [
            GeneratedOutput(
                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(file_location)
            )
            for file_location in results["file_locations"]
        ]
        return RequestResult(
            success=True,
            cached=cached,
            request_time=results["request_time"],
            completions=completions,
            embedding=[],
        )

    def _generate_promptist_prompt(self, prompt: str) -> str:
        """
        Generate a better version of the prompt with Promptist.
        Promptist was trained specifically with CompVis/stable-diffusion-v1-4.
        Adapted from https://huggingface.co/spaces/microsoft/Promptist/blob/main/app.py.
        """

        def load_promptist():
            # Promptist is a GPT-2-style causal LM; it reuses the GPT-2 tokenizer.
            prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"
            return prompter_model, tokenizer

        def generate(plain_text: str) -> str:
            # Lazy-load the model/tokenizer on first use and keep them on the instance.
            if self._promptist_model is None or self._promptist_tokenizer is None:
                self._promptist_model, self._promptist_tokenizer = load_promptist()
            assert self._promptist_model is not None
            assert self._promptist_tokenizer is not None

            input_ids = self._promptist_tokenizer(f"{plain_text.strip()} Rephrase:", return_tensors="pt").input_ids
            eos_id = self._promptist_tokenizer.eos_token_id
            # Used the same hyperparameters from the example
            outputs = self._promptist_model.generate(
                input_ids,
                do_sample=False,
                max_new_tokens=75,
                num_beams=8,
                num_return_sequences=8,
                eos_token_id=eos_id,
                pad_token_id=eos_id,
                length_penalty=-1.0,
            )
            output_texts: List[str] = self._promptist_tokenizer.batch_decode(outputs, skip_special_tokens=True)

            for output_text in output_texts:
                # Strip the "<prompt> Rephrase:" prefix to recover just the rewrite.
                res: str = output_text.replace(f"{plain_text} Rephrase:", "").strip()
                # The Promptist model sometimes generates empty string results.
                # Return the first non-empty string result.
                if len(res) > 0:
                    return res

            # If all fails, just return the original text.
            return plain_text

        return generate(prompt)

    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
        raise NotImplementedError("This client does not support tokenizing.")

    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
        raise NotImplementedError("This client does not support decoding.")
@@ -0,0 +1,9 @@
1
+ from helm.common.media_object import MediaObject, MultimediaObject
2
+
3
+
4
def get_single_image_multimedia_object(image_location: str) -> MultimediaObject:
    """
    Wrap a single generated image file in a `MultimediaObject` for
    text-to-image generation clients.
    """
    # Everything after the last "." is treated as the extension; when there is
    # no dot, rpartition yields the whole path (same result as split(".")[-1]).
    extension: str = image_location.rpartition(".")[2]
    media_object = MediaObject(content_type=f"image/{extension}", location=image_location)
    return MultimediaObject([media_object])
@@ -0,0 +1,86 @@
1
+ from typing import Any, List, Dict, Union
2
+ import base64
3
+ import requests
4
+ import urllib.parse
5
+
6
+ from helm.common.cache import CacheConfig, Cache
7
+ from helm.common.file_caches.file_cache import FileCache
8
+ from helm.common.images_utils import encode_base64
9
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
10
+ from helm.common.tokenization_request import (
11
+ TokenizationRequest,
12
+ TokenizationRequestResult,
13
+ DecodeRequest,
14
+ DecodeRequestResult,
15
+ )
16
+ from helm.clients.client import Client, CachingClient
17
+ from .image_generation_client_utils import get_single_image_multimedia_object
18
+
19
+
20
class LexicaClient(Client):
    """
    Client for the Lexica API. This client only searches for existing images;
    it does not support generating new ones.
    """

    def __init__(self, cache_config: CacheConfig, file_cache: FileCache):
        self.cache = Cache(cache_config)
        self.file_cache: FileCache = file_cache

    def make_request(self, request: Request) -> RequestResult:
        """
        Retrieves images through Lexica's search API (https://lexica.art/docs),
        which is powered by CLIP and returns the most relevant images for a query.
        Retrieved images are stored in the file cache and results are cached.
        """
        if request.model_engine != "search-stable-diffusion-1.5":
            # Only Stable Diffusion 1.5 is supported at the moment
            raise ValueError(f"Invalid model: {request.model_engine}")

        raw_request: Dict[str, Union[str, int]] = {
            "model": request.model_engine,
            "prompt": request.prompt,
            "n": request.num_completions,
        }
        cache_key = CachingClient.make_cache_key(raw_request, request)

        try:

            def do_it() -> Dict[str, Any]:
                requested_count: int = int(raw_request["n"])
                query_string: str = urllib.parse.urlencode({"q": request.prompt})
                result = requests.get(f"https://lexica.art/api/v1/search?{query_string}").json()
                assert "images" in result, f"Invalid response: {result} from prompt: {request.prompt}"
                assert len(result["images"]) >= requested_count, "Did not retrieve enough images"

                image_locations: List[str] = []
                # The API orders results by relevance, so keep the head of the list.
                for retrieved in result["images"][:requested_count]:
                    # Download + base64-encode the image, then persist the decoded
                    # bytes in the file cache, recording where it was written.
                    image_base64: str = encode_base64(retrieved["src"])
                    stored_location: str = self.file_cache.store(lambda: base64.b64decode(image_base64))
                    image_locations.append(stored_location)
                return {"image_locations": image_locations}

            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        except RuntimeError as e:
            error: str = f"LexicaClient error: {e}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

        completions: List[GeneratedOutput] = []
        for location in response["image_locations"]:
            completions.append(
                GeneratedOutput(
                    text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(location)
                )
            )
        return RequestResult(
            success=True,
            cached=cached,
            request_time=response["request_time"],
            completions=completions,
            embedding=[],
        )

    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
        raise NotImplementedError("This client does not support tokenizing.")

    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
        raise NotImplementedError("This client does not support decoding.")
File without changes