crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546)
  1. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +37 -2
  5. helm/benchmark/adaptation/adapters/adapter.py +4 -42
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
  9. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
  10. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
  11. helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
  12. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  13. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  14. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
  15. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
  16. helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
  17. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  18. helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
  19. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
  20. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
  21. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  22. helm/benchmark/adaptation/prompt.py +7 -1
  23. helm/benchmark/adaptation/request_state.py +6 -1
  24. helm/benchmark/adaptation/scenario_state.py +6 -2
  25. helm/benchmark/annotation/annotator.py +43 -0
  26. helm/benchmark/annotation/annotator_factory.py +61 -0
  27. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  28. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  29. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  30. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  31. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  32. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  33. helm/benchmark/annotation_executor.py +124 -0
  34. helm/benchmark/augmentations/cleva_perturbation.py +7 -14
  35. helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
  36. helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
  37. helm/benchmark/augmentations/data_augmenter.py +0 -2
  38. helm/benchmark/augmentations/dialect_perturbation.py +2 -2
  39. helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
  40. helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
  41. helm/benchmark/augmentations/gender_perturbation.py +3 -3
  42. helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
  43. helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
  44. helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
  45. helm/benchmark/augmentations/person_name_perturbation.py +0 -7
  46. helm/benchmark/augmentations/perturbation.py +20 -7
  47. helm/benchmark/augmentations/perturbation_description.py +1 -1
  48. helm/benchmark/augmentations/space_perturbation.py +2 -2
  49. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  50. helm/benchmark/augmentations/synonym_perturbation.py +2 -2
  51. helm/benchmark/augmentations/test_perturbation.py +11 -7
  52. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  53. helm/benchmark/augmentations/typos_perturbation.py +2 -2
  54. helm/benchmark/config_registry.py +38 -0
  55. helm/benchmark/executor.py +46 -16
  56. helm/benchmark/huggingface_registration.py +37 -7
  57. helm/benchmark/metrics/basic_metrics.py +172 -641
  58. helm/benchmark/metrics/bbq_metrics.py +3 -4
  59. helm/benchmark/metrics/bias_metrics.py +6 -6
  60. helm/benchmark/metrics/classification_metrics.py +11 -8
  61. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  62. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  63. helm/benchmark/metrics/code_metrics.py +4 -3
  64. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  65. helm/benchmark/metrics/common_metric_specs.py +167 -0
  66. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  67. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  68. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  69. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  70. helm/benchmark/metrics/disinformation_metrics.py +6 -112
  71. helm/benchmark/metrics/dry_run_metrics.py +5 -3
  72. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  73. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  74. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  75. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  76. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  77. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  78. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  79. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  80. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  81. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  82. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  83. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  84. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  85. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  86. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  87. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  88. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  89. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  90. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  91. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  92. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  93. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  94. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  95. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  96. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  97. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  98. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  99. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  100. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  101. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  102. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  103. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  104. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  105. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  106. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  107. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  108. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  109. helm/benchmark/metrics/machine_translation_metrics.py +5 -5
  110. helm/benchmark/metrics/metric.py +93 -172
  111. helm/benchmark/metrics/metric_name.py +0 -1
  112. helm/benchmark/metrics/metric_service.py +16 -0
  113. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  114. helm/benchmark/metrics/ranking_metrics.py +6 -7
  115. helm/benchmark/metrics/reference_metric.py +148 -0
  116. helm/benchmark/metrics/summac/model_summac.py +0 -2
  117. helm/benchmark/metrics/summarization_metrics.py +8 -8
  118. helm/benchmark/metrics/test_classification_metrics.py +9 -6
  119. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  120. helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
  121. helm/benchmark/metrics/test_metric.py +2 -2
  122. helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
  123. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
  124. helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
  125. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
  126. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
  127. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  128. helm/benchmark/metrics/toxicity_utils.py +23 -0
  129. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  130. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  131. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  132. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  133. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  134. helm/benchmark/model_deployment_registry.py +164 -41
  135. helm/benchmark/model_metadata_registry.py +181 -35
  136. helm/benchmark/multi_gpu_runner.py +133 -0
  137. helm/benchmark/presentation/contamination.py +3 -3
  138. helm/benchmark/presentation/create_plots.py +8 -7
  139. helm/benchmark/presentation/run_display.py +50 -17
  140. helm/benchmark/presentation/schema.py +28 -46
  141. helm/benchmark/presentation/summarize.py +213 -96
  142. helm/benchmark/presentation/table.py +8 -8
  143. helm/benchmark/presentation/test_contamination.py +2 -2
  144. helm/benchmark/presentation/test_run_entry.py +14 -9
  145. helm/benchmark/presentation/test_summarize.py +5 -0
  146. helm/benchmark/run.py +66 -54
  147. helm/benchmark/run_expander.py +342 -31
  148. helm/benchmark/run_spec.py +93 -0
  149. helm/benchmark/run_spec_factory.py +162 -0
  150. helm/benchmark/run_specs/__init__.py +0 -0
  151. helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
  152. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  153. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  154. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  155. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  156. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  157. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  158. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  159. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  160. helm/benchmark/runner.py +116 -69
  161. helm/benchmark/runner_config_registry.py +21 -0
  162. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  163. helm/benchmark/scenarios/bold_scenario.py +2 -2
  164. helm/benchmark/scenarios/cleva_scenario.py +43 -46
  165. helm/benchmark/scenarios/code_scenario.py +3 -2
  166. helm/benchmark/scenarios/commonsense_scenario.py +171 -191
  167. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  168. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  169. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  170. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  171. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  172. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  173. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  174. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  175. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  176. helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
  177. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  178. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  179. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  180. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  181. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  182. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  183. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  184. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  185. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  186. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  187. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  188. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  189. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  190. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  191. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  192. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  193. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  194. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  195. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  196. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  197. helm/benchmark/scenarios/legalbench_scenario.py +123 -0
  198. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  199. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  200. helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
  201. helm/benchmark/scenarios/math_scenario.py +19 -2
  202. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  203. helm/benchmark/scenarios/numeracy_scenario.py +3 -3
  204. helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
  205. helm/benchmark/scenarios/raft_scenario.py +2 -6
  206. helm/benchmark/scenarios/scenario.py +14 -2
  207. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  208. helm/benchmark/scenarios/test_math_scenario.py +22 -0
  209. helm/benchmark/scenarios/test_scenario.py +6 -3
  210. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  211. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  212. helm/benchmark/scenarios/the_pile_scenario.py +6 -7
  213. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  214. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  215. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  216. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  217. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  218. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  219. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  220. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  221. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  222. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  223. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  224. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  225. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  226. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  227. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  228. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  229. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  230. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  231. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  232. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  233. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  234. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  235. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  236. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  237. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
  238. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  239. helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
  240. helm/benchmark/server.py +59 -2
  241. helm/benchmark/slurm_jobs.py +12 -0
  242. helm/benchmark/slurm_runner.py +79 -51
  243. helm/benchmark/static/benchmarking.js +3 -4
  244. helm/benchmark/static/contamination.yaml +1 -1
  245. helm/benchmark/static/images/organizations/together.png +0 -0
  246. helm/benchmark/static/json-urls.js +4 -0
  247. helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
  248. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  249. helm/benchmark/static/schema_lite.yaml +824 -0
  250. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  251. helm/benchmark/static/schema_unitxt.yaml +428 -0
  252. helm/benchmark/static/schema_vlm.yaml +576 -0
  253. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  254. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  255. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  256. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  257. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  258. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  259. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  260. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  261. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  262. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  263. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  264. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  265. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  266. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  267. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  268. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  269. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  270. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  271. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  272. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  273. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  274. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  275. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  276. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  277. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  278. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  279. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  280. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  281. helm/benchmark/static_build/config.js +4 -0
  282. helm/benchmark/static_build/index.html +20 -0
  283. helm/benchmark/test_data_preprocessor.py +3 -3
  284. helm/benchmark/test_model_deployment_definition.py +90 -0
  285. helm/benchmark/test_run_expander.py +1 -1
  286. helm/benchmark/tokenizer_config_registry.py +10 -14
  287. helm/benchmark/window_services/ai21_window_service.py +22 -33
  288. helm/benchmark/window_services/cohere_window_service.py +1 -63
  289. helm/benchmark/window_services/default_window_service.py +2 -35
  290. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  291. helm/benchmark/window_services/ice_window_service.py +0 -34
  292. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  293. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  294. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  295. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  296. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  297. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  298. helm/benchmark/window_services/local_window_service.py +21 -4
  299. helm/benchmark/window_services/no_decoding_window_service.py +32 -0
  300. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  301. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  302. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  303. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  304. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  305. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  306. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  307. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  308. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  309. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  310. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  311. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  312. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  313. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  314. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  315. helm/benchmark/window_services/test_utils.py +3 -2
  316. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  317. helm/benchmark/window_services/window_service.py +42 -0
  318. helm/benchmark/window_services/window_service_factory.py +24 -269
  319. helm/benchmark/window_services/yalm_window_service.py +0 -27
  320. helm/clients/__init__.py +0 -0
  321. helm/{proxy/clients → clients}/ai21_client.py +5 -12
  322. helm/clients/aleph_alpha_client.py +112 -0
  323. helm/{proxy/clients → clients}/anthropic_client.py +213 -24
  324. helm/clients/auto_client.py +215 -0
  325. helm/clients/bedrock_client.py +128 -0
  326. helm/clients/bedrock_utils.py +72 -0
  327. helm/{proxy/clients → clients}/client.py +67 -55
  328. helm/clients/clip_score_client.py +49 -0
  329. helm/clients/clip_scorers/__init__.py +0 -0
  330. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  331. helm/clients/clip_scorers/clip_scorer.py +50 -0
  332. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  333. helm/{proxy/clients → clients}/cohere_client.py +6 -17
  334. helm/clients/gcs_client.py +82 -0
  335. helm/{proxy/clients → clients}/google_client.py +7 -8
  336. helm/clients/google_translate_client.py +35 -0
  337. helm/{proxy/clients → clients}/http_model_client.py +6 -10
  338. helm/{proxy/clients → clients}/huggingface_client.py +134 -92
  339. helm/clients/image_generation/__init__.py +0 -0
  340. helm/clients/image_generation/adobe_vision_client.py +78 -0
  341. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  342. helm/clients/image_generation/cogview2/__init__.py +0 -0
  343. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  344. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  345. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  346. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  347. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  348. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  349. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  350. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  351. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  352. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  353. helm/clients/image_generation/cogview2_client.py +191 -0
  354. helm/clients/image_generation/dalle2_client.py +192 -0
  355. helm/clients/image_generation/dalle3_client.py +108 -0
  356. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  357. helm/clients/image_generation/dalle_mini/data.py +442 -0
  358. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  359. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  360. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  361. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  362. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  363. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  364. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  365. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  366. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  367. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  368. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  369. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  370. helm/clients/image_generation/dalle_mini_client.py +190 -0
  371. helm/clients/image_generation/deep_floyd_client.py +78 -0
  372. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  373. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  374. helm/clients/image_generation/lexica_client.py +86 -0
  375. helm/clients/image_generation/mindalle/__init__.py +0 -0
  376. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  377. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  378. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  379. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  380. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  381. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  382. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  383. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  384. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  385. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  386. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  387. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  388. helm/clients/image_generation/mindalle_client.py +115 -0
  389. helm/clients/image_generation/nudity_check_client.py +64 -0
  390. helm/clients/image_generation/together_image_generation_client.py +111 -0
  391. helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
  392. helm/{proxy/clients → clients}/megatron_client.py +13 -7
  393. helm/clients/mistral_client.py +134 -0
  394. helm/clients/moderation_api_client.py +109 -0
  395. helm/clients/open_lm_client.py +43 -0
  396. helm/clients/openai_client.py +302 -0
  397. helm/{proxy/clients → clients}/palmyra_client.py +15 -12
  398. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  399. helm/clients/simple_client.py +64 -0
  400. helm/{proxy/clients → clients}/test_auto_client.py +15 -15
  401. helm/clients/test_client.py +100 -0
  402. helm/clients/test_huggingface_client.py +70 -0
  403. helm/clients/test_simple_client.py +19 -0
  404. helm/{proxy/clients → clients}/test_together_client.py +23 -12
  405. helm/{proxy/clients → clients}/together_client.py +18 -71
  406. helm/clients/vertexai_client.py +391 -0
  407. helm/clients/vision_language/__init__.py +0 -0
  408. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  409. helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
  410. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  411. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  412. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  413. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  414. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  415. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  416. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  417. helm/clients/vision_language/open_flamingo_client.py +155 -0
  418. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  419. helm/clients/vllm_client.py +46 -0
  420. helm/common/cache.py +24 -179
  421. helm/common/cache_backend_config.py +47 -0
  422. helm/common/clip_score_request.py +41 -0
  423. helm/common/concurrency.py +32 -0
  424. helm/common/credentials_utils.py +28 -0
  425. helm/common/file_caches/__init__.py +0 -0
  426. helm/common/file_caches/file_cache.py +16 -0
  427. helm/common/file_caches/local_file_cache.py +61 -0
  428. helm/common/file_caches/test_local_file_cache.py +25 -0
  429. helm/common/file_upload_request.py +27 -0
  430. helm/common/general.py +29 -10
  431. helm/common/image_generation_parameters.py +25 -0
  432. helm/common/images_utils.py +24 -1
  433. helm/common/key_value_store.py +113 -0
  434. helm/common/media_object.py +13 -0
  435. helm/common/moderations_api_request.py +71 -0
  436. helm/common/mongo_key_value_store.py +88 -0
  437. helm/common/multimodal_request_utils.py +31 -0
  438. helm/common/nudity_check_request.py +29 -0
  439. helm/common/object_spec.py +2 -2
  440. helm/common/request.py +36 -27
  441. helm/common/test_general.py +6 -0
  442. helm/common/tokenization_request.py +6 -3
  443. helm/config/__init__.py +0 -0
  444. helm/config/model_deployments.yaml +1942 -0
  445. helm/config/model_metadata.yaml +2201 -0
  446. helm/config/tokenizer_configs.yaml +362 -0
  447. helm/proxy/accounts.py +31 -4
  448. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  449. helm/proxy/critique/model_critique_client.py +13 -5
  450. helm/proxy/example_queries.py +29 -17
  451. helm/proxy/retry.py +8 -2
  452. helm/proxy/server.py +77 -5
  453. helm/proxy/services/remote_service.py +31 -0
  454. helm/proxy/services/server_service.py +103 -20
  455. helm/proxy/services/service.py +34 -2
  456. helm/proxy/services/test_remote_service.py +7 -6
  457. helm/proxy/services/test_service.py +27 -18
  458. helm/proxy/test_accounts.py +32 -0
  459. helm/proxy/token_counters/auto_token_counter.py +37 -37
  460. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  461. helm/proxy/token_counters/token_counter.py +3 -5
  462. helm/py.typed +0 -0
  463. helm/tokenizers/__init__.py +0 -0
  464. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  465. helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
  466. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
  467. helm/tokenizers/auto_tokenizer.py +93 -0
  468. helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
  469. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  470. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  471. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
  472. helm/tokenizers/simple_tokenizer.py +33 -0
  473. helm/tokenizers/test_anthropic_tokenizer.py +82 -0
  474. helm/tokenizers/test_huggingface_tokenizer.py +136 -0
  475. helm/tokenizers/test_simple_tokenizer.py +33 -0
  476. helm/tokenizers/vertexai_tokenizer.py +97 -0
  477. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  478. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  479. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  480. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  481. crfm_helm-0.3.0.dist-info/RECORD +0 -396
  482. helm/benchmark/vlm_run_specs.py +0 -71
  483. helm/benchmark/window_services/anthropic_window_service.py +0 -68
  484. helm/benchmark/window_services/bloom_window_service.py +0 -35
  485. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  486. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  487. helm/benchmark/window_services/gptj_window_service.py +0 -38
  488. helm/benchmark/window_services/gptneox_window_service.py +0 -41
  489. helm/benchmark/window_services/http_model_window_service.py +0 -28
  490. helm/benchmark/window_services/huggingface_window_service.py +0 -59
  491. helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
  492. helm/benchmark/window_services/llama_window_service.py +0 -28
  493. helm/benchmark/window_services/luminous_window_service.py +0 -67
  494. helm/benchmark/window_services/megatron_window_service.py +0 -10
  495. helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
  496. helm/benchmark/window_services/openai_window_service.py +0 -13
  497. helm/benchmark/window_services/opt_window_service.py +0 -35
  498. helm/benchmark/window_services/palmyra_window_service.py +0 -45
  499. helm/benchmark/window_services/remote_window_service.py +0 -48
  500. helm/benchmark/window_services/santacoder_window_service.py +0 -27
  501. helm/benchmark/window_services/starcoder_window_service.py +0 -27
  502. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  503. helm/benchmark/window_services/t511b_window_service.py +0 -30
  504. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  505. helm/benchmark/window_services/ul2_window_service.py +0 -30
  506. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  507. helm/benchmark/window_services/wider_openai_window_service.py +0 -52
  508. helm/proxy/clients/aleph_alpha_client.py +0 -99
  509. helm/proxy/clients/auto_client.py +0 -461
  510. helm/proxy/clients/goose_ai_client.py +0 -100
  511. helm/proxy/clients/microsoft_client.py +0 -182
  512. helm/proxy/clients/openai_client.py +0 -206
  513. helm/proxy/clients/remote_model_registry.py +0 -28
  514. helm/proxy/clients/simple_client.py +0 -61
  515. helm/proxy/clients/test_anthropic_client.py +0 -63
  516. helm/proxy/clients/test_client.py +0 -31
  517. helm/proxy/clients/test_huggingface_client.py +0 -87
  518. helm/proxy/models.py +0 -963
  519. helm/proxy/test_models.py +0 -27
  520. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  521. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  522. helm/proxy/token_counters/free_token_counter.py +0 -12
  523. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  524. helm/proxy/token_counters/openai_token_counter.py +0 -22
  525. helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
  526. helm/proxy/token_counters/test_openai_token_counter.py +0 -79
  527. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  528. helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
  529. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  530. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  531. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  532. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  533. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  534. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  535. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  536. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  537. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  538. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  539. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  540. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  541. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  542. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  543. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  544. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  545. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  546. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
@@ -0,0 +1,149 @@
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import torch
8
+ from typing import Optional
9
+ from tqdm import tqdm
10
+ from torch.nn import functional as F
11
+
12
+
13
+ def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
14
+ if k is None:
15
+ return logits
16
+ else:
17
+ v, ix = torch.topk(logits, k)
18
+ out = logits.clone()
19
+ out[out < v[:, [-1]]] = -float("Inf")
20
+ return out
21
+
22
+
23
+ def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
24
+ if p is None:
25
+ return probs
26
+ else:
27
+ sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
28
+ cum_probs = torch.cumsum(sorted_probs, dim=-1)
29
+
30
+ sorted_idx_remove_cond = cum_probs >= p
31
+
32
+ sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
33
+ sorted_idx_remove_cond[..., 0] = 0
34
+
35
+ indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
36
+ probs = probs.masked_fill(indices_to_remove, 0.0)
37
+ norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
38
+ return norm_probs
39
+
40
+
41
+ def get_positional_encoding(inputs: torch.LongTensor, mode: str = "1d") -> torch.LongTensor:
42
+ device = inputs.device
43
+ if mode == "1d":
44
+ B, N = inputs.shape
45
+ xs_pos = torch.arange(N, device=device).repeat((B, 1))
46
+ elif mode == "2d":
47
+ B, H, W = inputs.shape
48
+ xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
49
+ xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
50
+ xs_pos = (xs_pos_h, xs_pos_w)
51
+ else:
52
+ raise ValueError("%s positional encoding invalid" % mode)
53
+ return xs_pos
54
+
55
+
56
+ @torch.no_grad()
57
+ def sampling(
58
+ model: torch.nn.Module,
59
+ tokens: torch.LongTensor,
60
+ top_k: Optional[float] = None,
61
+ top_p: Optional[float] = None,
62
+ softmax_temperature: float = 1.0,
63
+ is_tqdm: bool = True,
64
+ use_fp16: bool = True,
65
+ max_seq_len: int = 256,
66
+ ) -> torch.LongTensor:
67
+ code = None
68
+ past = None
69
+
70
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
71
+ pos_enc_tokens = get_positional_encoding(tokens, mode="1d")
72
+
73
+ for cnt, h in enumerate(pbar):
74
+ if code is None:
75
+ code_ = None
76
+ pos_enc_code_ = None
77
+ else:
78
+ code_ = code.clone().detach()
79
+ pos_enc_code_ = get_positional_encoding(code_, mode="1d")
80
+ code_ = code_[:, cnt - 1].unsqueeze(-1)
81
+ pos_enc_code_ = pos_enc_code_[:, cnt - 1].unsqueeze(-1)
82
+
83
+ logits, present = model.sampling(
84
+ images=code_, texts=tokens, pos_images=pos_enc_code_, pos_texts=pos_enc_tokens, use_fp16=use_fp16, past=past
85
+ )
86
+ logits = logits.to(dtype=torch.float32)
87
+ logits = logits / softmax_temperature
88
+
89
+ present = torch.stack(present).clone().detach()
90
+ if past is None:
91
+ past = [present]
92
+ else:
93
+ past.append(present)
94
+
95
+ logits = cutoff_topk_logits(logits, top_k)
96
+ probs = F.softmax(logits, dim=-1)
97
+ probs = cutoff_topp_probs(probs, top_p)
98
+
99
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
100
+ code = idx if code is None else torch.cat([code, idx], axis=1)
101
+
102
+ del past
103
+ return code
104
+
105
+
106
+ @torch.no_grad()
107
+ def sampling_igpt(
108
+ model: torch.nn.Module,
109
+ sos: torch.FloatTensor,
110
+ top_k: Optional[float] = None,
111
+ top_p: Optional[float] = None,
112
+ softmax_temperature: float = 1.0,
113
+ is_tqdm: bool = True,
114
+ use_fp16: bool = True,
115
+ max_seq_len: int = 256,
116
+ ) -> torch.LongTensor:
117
+ code = None
118
+ past = None
119
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
120
+
121
+ for cnt, h in enumerate(pbar):
122
+ if code is None:
123
+ code_ = None
124
+ pos_enc_code_ = None
125
+ else:
126
+ code_ = code.clone().detach()
127
+ pos_enc_code_ = get_positional_encoding(code_, mode="1d")
128
+ code_ = code_[:, cnt - 1].unsqueeze(-1)
129
+ pos_enc_code_ = pos_enc_code_[:, cnt - 1].unsqueeze(-1)
130
+
131
+ logits, present = model.sampling(sos=sos, codes=code_, pos_codes=pos_enc_code_, use_fp16=use_fp16, past=past)
132
+ logits = logits.to(dtype=torch.float32)
133
+ logits = logits / softmax_temperature
134
+
135
+ present = torch.stack(present).clone().detach()
136
+ if past is None:
137
+ past = [present]
138
+ else:
139
+ past.append(present)
140
+
141
+ logits = cutoff_topk_logits(logits, top_k)
142
+ probs = F.softmax(logits, dim=-1)
143
+ probs = cutoff_topp_probs(probs, top_p)
144
+
145
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
146
+ code = idx if code is None else torch.cat([code, idx], axis=1)
147
+
148
+ del past
149
+ return code
@@ -0,0 +1,89 @@
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import random
9
+ import urllib
10
+ import hashlib
11
+ import tarfile
12
+ import torch
13
+ import numpy as np
14
+ from torch.nn import functional as F
15
+ from tqdm import tqdm
16
+
17
+ from helm.common.optional_dependencies import handle_module_not_found_error
18
+
19
+
20
+ def set_seed(seed: int):
21
+ random.seed(seed)
22
+ np.random.seed(seed)
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+
26
+
27
+ @torch.no_grad()
28
+ def clip_score(
29
+ prompt: str, images: np.ndarray, model_clip: torch.nn.Module, preprocess_clip, device: str
30
+ ) -> np.ndarray:
31
+ try:
32
+ import clip
33
+ from PIL import Image
34
+ except ModuleNotFoundError as e:
35
+ handle_module_not_found_error(e, ["heim"])
36
+
37
+ images = [preprocess_clip(Image.fromarray((image * 255).astype(np.uint8))) for image in images]
38
+ images = torch.stack(images, dim=0).to(device=device)
39
+ texts = clip.tokenize(prompt).to(device=device)
40
+ texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
41
+
42
+ image_features = model_clip.encode_image(images)
43
+ text_features = model_clip.encode_text(texts)
44
+
45
+ scores = F.cosine_similarity(image_features, text_features).squeeze()
46
+ rank = torch.argsort(scores, descending=True).cpu().numpy()
47
+ return rank
48
+
49
+
50
+ def download(url: str, root: str) -> str:
51
+ os.makedirs(root, exist_ok=True)
52
+ filename = os.path.basename(url)
53
+ pathname = filename[: -len(".tar.gz")]
54
+
55
+ expected_md5 = url.split("/")[-2]
56
+ download_target = os.path.join(root, filename)
57
+ result_path = os.path.join(root, pathname)
58
+
59
+ if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
60
+ return result_path
61
+
62
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
63
+ with tqdm(
64
+ total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
65
+ ) as loop:
66
+ while True:
67
+ buffer = source.read(8192)
68
+ if not buffer:
69
+ break
70
+
71
+ output.write(buffer)
72
+ loop.update(len(buffer))
73
+
74
+ if hashlib.md5(open(download_target, "rb").read()).hexdigest() != expected_md5:
75
+ raise RuntimeError(f"Model has been downloaded but the md5 checksum does not not match")
76
+
77
+ with tarfile.open(download_target, "r:gz") as f:
78
+ pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
79
+ for member in pbar:
80
+ pbar.set_description(f"extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)")
81
+ f.extract(member=member, path=root)
82
+
83
+ return result_path
84
+
85
+
86
+ def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
87
+ if urllib.parse.urlparse(url_or_path).scheme in ("http", "https"):
88
+ return download(url_or_path, root)
89
+ return url_or_path
@@ -0,0 +1,115 @@
1
+ from typing import Any, Dict, List
2
+
3
+ import numpy as np
4
+
5
+ from helm.common.cache import CacheConfig, Cache
6
+ from helm.common.file_caches.file_cache import FileCache
7
+ from helm.common.gpu_utils import get_torch_device_name
8
+ from helm.common.hierarchical_logger import hlog, htrack_block
9
+ from helm.common.optional_dependencies import handle_module_not_found_error
10
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
11
+ from helm.common.tokenization_request import (
12
+ DecodeRequest,
13
+ DecodeRequestResult,
14
+ TokenizationRequest,
15
+ TokenizationRequestResult,
16
+ )
17
+ from helm.clients.client import Client, CachingClient
18
+ from .image_generation_client_utils import get_single_image_multimedia_object
19
+
20
+ try:
21
+ from PIL import Image
22
+ except ModuleNotFoundError as e:
23
+ handle_module_not_found_error(e, ["heim"])
24
+
25
+
26
+ class MinDALLEClient(Client):
27
+ """
28
+ Source: https://github.com/kakaobrain/mindall-e
29
+ """
30
+
31
+ def __init__(self, cache_config: CacheConfig, file_cache: FileCache):
32
+ self._cache = Cache(cache_config)
33
+ self._file_cache: FileCache = file_cache
34
+
35
+ self._model = None
36
+
37
+ def _get_model(self):
38
+ try:
39
+ from helm.clients.image_generation.mindalle.models import Dalle
40
+ except ModuleNotFoundError as e:
41
+ handle_module_not_found_error(e, ["heim"])
42
+
43
+ if self._model is None:
44
+ self._model = Dalle.from_pretrained("minDALL-E/1.3B")
45
+ self._model = self._model.to(get_torch_device_name())
46
+ return self._model
47
+
48
+ def make_request(self, request: Request) -> RequestResult:
49
+ raw_request = {
50
+ "prompt": request.prompt,
51
+ # Setting this to a higher value can cause CUDA OOM
52
+ # Fix it to 1 and generate an image `request.num_completions` times
53
+ "num_candidates": 1,
54
+ "softmax_temperature": 1.0,
55
+ "top_k": 256, # It is recommended that top_k is set lower than 256.
56
+ "top_p": None,
57
+ "device": "cuda",
58
+ }
59
+
60
+ try:
61
+
62
+ def do_it() -> Dict[str, Any]:
63
+ prompt: str = request.prompt
64
+
65
+ with htrack_block(f"Generating images for prompt: {prompt}"):
66
+ model = self._get_model()
67
+
68
+ images: List[Image] = []
69
+ for _ in range(request.num_completions):
70
+ output = model.sampling(**raw_request).cpu().numpy()
71
+ output = np.transpose(output, (0, 2, 3, 1))
72
+ image = Image.fromarray(np.asarray(output[0] * 255, dtype=np.uint8))
73
+ images.append(image)
74
+
75
+ assert (
76
+ len(images) == request.num_completions
77
+ ), f"Expected {request.num_completions} images, but got {len(images)}"
78
+
79
+ result = {"file_locations": []}
80
+ for image in images:
81
+ # Write out the image to a file and save the path
82
+ file_location: str = self._file_cache.get_unique_file_location()
83
+ image.save(file_location)
84
+ hlog(f"Image saved at {file_location}.")
85
+ result["file_locations"].append(file_location)
86
+ return result
87
+
88
+ # Include the model name and number of completions in the cache key
89
+ cache_key: Dict = CachingClient.make_cache_key(
90
+ {"model": request.model_engine, "n": request.num_completions, **raw_request}, request
91
+ )
92
+ results, cached = self._cache.get(cache_key, wrap_request_time(do_it))
93
+ except RuntimeError as ex:
94
+ error: str = f"MinDALLEClient error: {ex}"
95
+ return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
96
+
97
+ completions: List[GeneratedOutput] = [
98
+ GeneratedOutput(
99
+ text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(location)
100
+ )
101
+ for location in results["file_locations"]
102
+ ]
103
+ return RequestResult(
104
+ success=True,
105
+ cached=cached,
106
+ request_time=results["request_time"],
107
+ completions=completions,
108
+ embedding=[],
109
+ )
110
+
111
+ def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
112
+ raise NotImplementedError("This client does not support tokenizing.")
113
+
114
+ def decode(self, request: DecodeRequest) -> DecodeRequestResult:
115
+ raise NotImplementedError("This client does not support decoding.")
@@ -0,0 +1,64 @@
1
+ from typing import Any, Dict, Optional
2
+ import os
3
+
4
+ from helm.common.cache import Cache, CacheConfig
5
+ from helm.common.general import ensure_file_downloaded, ensure_directory_exists
6
+ from helm.common.optional_dependencies import handle_module_not_found_error
7
+ from helm.common.nudity_check_request import NudityCheckRequest, NudityCheckResult
8
+
9
+
10
+ class NudityCheckClientError(Exception):
11
+ pass
12
+
13
+
14
+ class NudityCheckClient:
15
+ MODEL_DOWNLOAD_URL: str = (
16
+ "https://worksheets.codalab.org/rest/bundles/0x81f7d1febb374fdcad6d9ae970f46750/contents/blob/"
17
+ )
18
+
19
+ def __init__(self, cache_config: CacheConfig):
20
+ try:
21
+ from nudenet import NudeClassifier
22
+ except ModuleNotFoundError as e:
23
+ handle_module_not_found_error(e, ["heim"])
24
+
25
+ self.cache = Cache(cache_config)
26
+ self._nudity_classifier: Optional[NudeClassifier] = None
27
+
28
+ def check_nudity(self, request: NudityCheckRequest) -> NudityCheckResult:
29
+ """Check for nudity for a batch of images using NudeNet."""
30
+ try:
31
+ from nudenet import NudeClassifier
32
+ except ModuleNotFoundError as e:
33
+ handle_module_not_found_error(e, ["heim"])
34
+
35
+ try:
36
+
37
+ def do_it() -> Dict[str, Any]:
38
+ if self._nudity_classifier is None:
39
+ # The NudeNet library does not automatically provide model weights that work, so
40
+ # manually download them. The path is hardcoded in the NudeNet library.
41
+ base_path: str = os.path.join(os.path.expanduser("~"), ".NudeNet")
42
+ ensure_directory_exists(base_path)
43
+ model_path: str = os.path.join(base_path, "classifier_model.onnx")
44
+ ensure_file_downloaded(source_url=self.MODEL_DOWNLOAD_URL, target_path=model_path)
45
+ self._nudity_classifier = NudeClassifier()
46
+
47
+ path_to_nudity_scores: Dict[str, Dict[str, float]] = self._nudity_classifier.classify(
48
+ request.image_locations
49
+ )
50
+ return path_to_nudity_scores
51
+
52
+ results, cached = self.cache.get({"locations": sorted(request.image_locations)}, do_it)
53
+ except Exception as e:
54
+ raise NudityCheckClientError(e)
55
+
56
+ nudity_results: Dict[str, bool] = {
57
+ image_location: nudity_result["unsafe"] > nudity_result["safe"]
58
+ for image_location, nudity_result in results.items()
59
+ }
60
+ return NudityCheckResult(
61
+ success=True,
62
+ cached=cached,
63
+ image_to_nudity=nudity_results,
64
+ )
@@ -0,0 +1,111 @@
1
+ from typing import Any, Dict, List, Optional
2
+ import base64
3
+ import requests
4
+
5
+ from helm.common.cache import CacheConfig, Cache
6
+ from helm.common.file_caches.file_cache import FileCache
7
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
8
+ from helm.common.tokenization_request import (
9
+ TokenizationRequest,
10
+ TokenizationRequestResult,
11
+ DecodeRequest,
12
+ DecodeRequestResult,
13
+ )
14
+
15
+ from helm.clients.client import CachingClient, Client
16
+ from .image_generation_client_utils import get_single_image_multimedia_object
17
+
18
+
19
+ class TogetherImageGenerationClient(Client):
20
+ """
21
+ Client for image generation via the Together API.
22
+ """
23
+
24
+ DEFAULT_IMAGE_HEIGHT: int = 512
25
+ DEFAULT_IMAGE_WIDTH: int = 512
26
+
27
+ DEFAULT_GUIDANCE_SCALE: float = 7.5
28
+ DEFAULT_STEPS: int = 50
29
+
30
+ INFERENCE_ENDPOINT: str = "https://api.together.xyz/api/inference"
31
+
32
+ def __init__(self, cache_config: CacheConfig, file_cache: FileCache, api_key: Optional[str] = None):
33
+ self._cache = Cache(cache_config)
34
+ self.file_cache: FileCache = file_cache
35
+
36
+ self._promptist_model = None
37
+ self._promptist_tokenizer = None
38
+
39
+ self.api_key: Optional[str] = api_key
40
+
41
+ def make_request(self, request: Request) -> RequestResult:
42
+ # Following https://docs.together.xyz/en/api
43
+ assert request.image_generation_parameters is not None
44
+ raw_request = {
45
+ "request_type": "image-model-inference",
46
+ "model": request.model_engine,
47
+ "prompt": request.prompt,
48
+ "n": request.num_completions,
49
+ "guidance_scale": (
50
+ request.image_generation_parameters.guidance_scale
51
+ if request.image_generation_parameters.guidance_scale is not None
52
+ else self.DEFAULT_GUIDANCE_SCALE
53
+ ),
54
+ "steps": (
55
+ request.image_generation_parameters.diffusion_denoising_steps
56
+ if request.image_generation_parameters.diffusion_denoising_steps is not None
57
+ else self.DEFAULT_STEPS
58
+ ),
59
+ }
60
+
61
+ if (
62
+ request.image_generation_parameters.output_image_width is None
63
+ or request.image_generation_parameters.output_image_height is None
64
+ ):
65
+ raw_request["width"] = self.DEFAULT_IMAGE_WIDTH
66
+ raw_request["height"] = self.DEFAULT_IMAGE_HEIGHT
67
+ else:
68
+ raw_request["width"] = request.image_generation_parameters.output_image_width
69
+ raw_request["height"] = request.image_generation_parameters.output_image_height
70
+
71
+ cache_key = CachingClient.make_cache_key(raw_request, request)
72
+
73
+ try:
74
+
75
+ def do_it() -> Dict[str, Any]:
76
+ result = requests.post(self.INFERENCE_ENDPOINT, json=raw_request).json()
77
+ assert "output" in result, f"Invalid response: {result} from prompt: {request.prompt}"
78
+
79
+ for choice in result["output"]["choices"]:
80
+ # Write out the image to a file and save the path
81
+ choice["file_path"] = self.file_cache.store(lambda: base64.b64decode(choice["image_base64"]))
82
+ choice.pop("image_base64", None)
83
+ return result["output"]
84
+
85
+ response, cached = self._cache.get(cache_key, wrap_request_time(do_it))
86
+ except RuntimeError as e:
87
+ error: str = f"TogetherVisionClient error: {e}"
88
+ return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
89
+
90
+ completions: List[GeneratedOutput] = [
91
+ GeneratedOutput(
92
+ text="",
93
+ logprob=0,
94
+ tokens=[],
95
+ multimodal_content=get_single_image_multimedia_object(choice["file_path"]),
96
+ )
97
+ for choice in response["choices"]
98
+ ]
99
+ return RequestResult(
100
+ success=True,
101
+ cached=cached,
102
+ request_time=response["request_time"],
103
+ completions=completions,
104
+ embedding=[],
105
+ )
106
+
107
+ def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
108
+ raise NotImplementedError("This client does not support tokenizing.")
109
+
110
+ def decode(self, request: DecodeRequest) -> DecodeRequestResult:
111
+ raise NotImplementedError("This client does not support decoding.")
@@ -9,8 +9,8 @@ import torch
9
9
 
10
10
  from helm.common.cache import CacheConfig
11
11
  from helm.common.optional_dependencies import OptionalDependencyNotInstalled
12
- from helm.common.request import Request, RequestResult, Sequence, Token
13
- from helm.proxy.tokenizers.tokenizer import Tokenizer
12
+ from helm.common.request import Request, RequestResult, GeneratedOutput, Token
13
+ from helm.tokenizers.tokenizer import Tokenizer
14
14
 
15
15
  from .client import CachingClient
16
16
  from .lit_gpt_generate import generate # type: ignore
@@ -89,6 +89,7 @@ class LitGPTClient(CachingClient):
89
89
  def __init__(
90
90
  self,
91
91
  tokenizer: Tokenizer,
92
+ tokenizer_name: str,
92
93
  cache_config: CacheConfig,
93
94
  checkpoint_dir: Path = Path(""),
94
95
  precision: str = "bf16-true",
@@ -97,7 +98,8 @@ class LitGPTClient(CachingClient):
97
98
  strategy: str = "auto",
98
99
  quantize: Optional[QuantizationType] = None,
99
100
  ):
100
- super().__init__(cache_config=cache_config, tokenizer=tokenizer)
101
+ super().__init__(cache_config=cache_config)
102
+ self.tokenizer = tokenizer
101
103
  lit_gpt = LitGPT(checkpoint_dir, precision, device, devices, strategy, quantize)
102
104
  self.model = lit_gpt.model
103
105
  self.fabric = lit_gpt.fabric
@@ -154,8 +156,8 @@ class LitGPTClient(CachingClient):
154
156
 
155
157
  generated_tokens = []
156
158
  for token in tokens:
157
- generated_tokens.append(Token(text=tokenizer.decode(token), logprob=0, top_logprobs={}))
158
- completions = [Sequence(text=output, logprob=0, tokens=generated_tokens)]
159
+ generated_tokens.append(Token(text=tokenizer.decode(token), logprob=0))
160
+ completions = [GeneratedOutput(text=output, logprob=0, tokens=generated_tokens)]
159
161
 
160
162
  return RequestResult(
161
163
  success=True,
@@ -2,21 +2,22 @@ import json
2
2
  import requests
3
3
  from typing import Any, Dict, List
4
4
  import traceback
5
+ from helm.common.cache import CacheConfig
5
6
 
6
7
  from helm.common.request import (
7
8
  wrap_request_time,
8
9
  EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
9
10
  Request,
10
11
  RequestResult,
11
- Sequence,
12
+ GeneratedOutput,
12
13
  Token,
13
14
  )
14
15
  from helm.common.tokenization_request import TokenizationRequest
15
- from helm.proxy.clients.huggingface_client import HuggingFaceClient
16
- from helm.proxy.clients.client import CachingClient, truncate_sequence
16
+ from helm.clients.client import CachingClient, truncate_sequence
17
+ from helm.tokenizers.tokenizer import Tokenizer
17
18
 
18
19
 
19
- class MegatronClient(HuggingFaceClient):
20
+ class MegatronClient(CachingClient):
20
21
  """Client for remote Megatron-LM server.
21
22
 
22
23
  This client expects an external Megatron-LM server to be run on localhost:5000. See the
@@ -25,6 +26,11 @@ class MegatronClient(HuggingFaceClient):
25
26
  https://github.com/NVIDIA/Megatron-LM#gpt-text-generation
26
27
  """
27
28
 
29
+ def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
30
+ super().__init__(cache_config=cache_config)
31
+ self.tokenizer = tokenizer
32
+ self.tokenizer_name = tokenizer_name
33
+
28
34
  def _send_request(self, raw_request: Dict[str, Any]) -> Dict[str, Any]:
29
35
  response = requests.request(
30
36
  method="PUT",
@@ -43,10 +49,10 @@ class MegatronClient(HuggingFaceClient):
43
49
  return out
44
50
 
45
51
  def _tokenize_response(self, text: str) -> List[Token]:
46
- tokenized_text = self.tokenizer.tokenize(TokenizationRequest(text, tokenizer="huggingface/gpt2"))
52
+ tokenized_text = self.tokenizer.tokenize(TokenizationRequest(text, tokenizer=self.tokenizer_name))
47
53
 
48
54
  # TODO(tgale): Support logprobs.
49
- tokens = [Token(text=str(token), logprob=0, top_logprobs={}) for token in tokenized_text.raw_tokens]
55
+ tokens = [Token(text=str(token), logprob=0) for token in tokenized_text.raw_tokens]
50
56
  return tokens
51
57
 
52
58
  def _make_request(self, request: Request) -> RequestResult:
@@ -81,7 +87,7 @@ class MegatronClient(HuggingFaceClient):
81
87
 
82
88
  # NOTE: Megatron returns the de-tokenized response. Re-tokenize.
83
89
  tokens = self._tokenize_response(generated_text)
84
- completion = Sequence(text=generated_text, logprob=0, tokens=tokens)
90
+ completion = GeneratedOutput(text=generated_text, logprob=0, tokens=tokens)
85
91
  completion = truncate_sequence(completion, request, print_warning=True)
86
92
 
87
93
  return RequestResult(