crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546)
  1. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +37 -2
  5. helm/benchmark/adaptation/adapters/adapter.py +4 -42
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
  9. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
  10. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
  11. helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
  12. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  13. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  14. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
  15. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
  16. helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
  17. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  18. helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
  19. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
  20. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
  21. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  22. helm/benchmark/adaptation/prompt.py +7 -1
  23. helm/benchmark/adaptation/request_state.py +6 -1
  24. helm/benchmark/adaptation/scenario_state.py +6 -2
  25. helm/benchmark/annotation/annotator.py +43 -0
  26. helm/benchmark/annotation/annotator_factory.py +61 -0
  27. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  28. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  29. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  30. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  31. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  32. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  33. helm/benchmark/annotation_executor.py +124 -0
  34. helm/benchmark/augmentations/cleva_perturbation.py +7 -14
  35. helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
  36. helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
  37. helm/benchmark/augmentations/data_augmenter.py +0 -2
  38. helm/benchmark/augmentations/dialect_perturbation.py +2 -2
  39. helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
  40. helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
  41. helm/benchmark/augmentations/gender_perturbation.py +3 -3
  42. helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
  43. helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
  44. helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
  45. helm/benchmark/augmentations/person_name_perturbation.py +0 -7
  46. helm/benchmark/augmentations/perturbation.py +20 -7
  47. helm/benchmark/augmentations/perturbation_description.py +1 -1
  48. helm/benchmark/augmentations/space_perturbation.py +2 -2
  49. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  50. helm/benchmark/augmentations/synonym_perturbation.py +2 -2
  51. helm/benchmark/augmentations/test_perturbation.py +11 -7
  52. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  53. helm/benchmark/augmentations/typos_perturbation.py +2 -2
  54. helm/benchmark/config_registry.py +38 -0
  55. helm/benchmark/executor.py +46 -16
  56. helm/benchmark/huggingface_registration.py +37 -7
  57. helm/benchmark/metrics/basic_metrics.py +172 -641
  58. helm/benchmark/metrics/bbq_metrics.py +3 -4
  59. helm/benchmark/metrics/bias_metrics.py +6 -6
  60. helm/benchmark/metrics/classification_metrics.py +11 -8
  61. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  62. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  63. helm/benchmark/metrics/code_metrics.py +4 -3
  64. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  65. helm/benchmark/metrics/common_metric_specs.py +167 -0
  66. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  67. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  68. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  69. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  70. helm/benchmark/metrics/disinformation_metrics.py +6 -112
  71. helm/benchmark/metrics/dry_run_metrics.py +5 -3
  72. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  73. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  74. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  75. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  76. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  77. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  78. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  79. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  80. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  81. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  82. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  83. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  84. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  85. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  86. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  87. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  88. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  89. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  90. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  91. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  92. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  93. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  94. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  95. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  96. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  97. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  98. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  99. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  100. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  101. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  102. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  103. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  104. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  105. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  106. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  107. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  108. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  109. helm/benchmark/metrics/machine_translation_metrics.py +5 -5
  110. helm/benchmark/metrics/metric.py +93 -172
  111. helm/benchmark/metrics/metric_name.py +0 -1
  112. helm/benchmark/metrics/metric_service.py +16 -0
  113. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  114. helm/benchmark/metrics/ranking_metrics.py +6 -7
  115. helm/benchmark/metrics/reference_metric.py +148 -0
  116. helm/benchmark/metrics/summac/model_summac.py +0 -2
  117. helm/benchmark/metrics/summarization_metrics.py +8 -8
  118. helm/benchmark/metrics/test_classification_metrics.py +9 -6
  119. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  120. helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
  121. helm/benchmark/metrics/test_metric.py +2 -2
  122. helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
  123. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
  124. helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
  125. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
  126. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
  127. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  128. helm/benchmark/metrics/toxicity_utils.py +23 -0
  129. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  130. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  131. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  132. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  133. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  134. helm/benchmark/model_deployment_registry.py +164 -41
  135. helm/benchmark/model_metadata_registry.py +181 -35
  136. helm/benchmark/multi_gpu_runner.py +133 -0
  137. helm/benchmark/presentation/contamination.py +3 -3
  138. helm/benchmark/presentation/create_plots.py +8 -7
  139. helm/benchmark/presentation/run_display.py +50 -17
  140. helm/benchmark/presentation/schema.py +28 -46
  141. helm/benchmark/presentation/summarize.py +213 -96
  142. helm/benchmark/presentation/table.py +8 -8
  143. helm/benchmark/presentation/test_contamination.py +2 -2
  144. helm/benchmark/presentation/test_run_entry.py +14 -9
  145. helm/benchmark/presentation/test_summarize.py +5 -0
  146. helm/benchmark/run.py +66 -54
  147. helm/benchmark/run_expander.py +342 -31
  148. helm/benchmark/run_spec.py +93 -0
  149. helm/benchmark/run_spec_factory.py +162 -0
  150. helm/benchmark/run_specs/__init__.py +0 -0
  151. helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
  152. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  153. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  154. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  155. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  156. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  157. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  158. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  159. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  160. helm/benchmark/runner.py +116 -69
  161. helm/benchmark/runner_config_registry.py +21 -0
  162. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  163. helm/benchmark/scenarios/bold_scenario.py +2 -2
  164. helm/benchmark/scenarios/cleva_scenario.py +43 -46
  165. helm/benchmark/scenarios/code_scenario.py +3 -2
  166. helm/benchmark/scenarios/commonsense_scenario.py +171 -191
  167. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  168. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  169. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  170. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  171. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  172. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  173. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  174. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  175. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  176. helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
  177. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  178. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  179. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  180. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  181. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  182. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  183. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  184. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  185. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  186. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  187. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  188. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  189. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  190. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  191. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  192. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  193. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  194. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  195. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  196. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  197. helm/benchmark/scenarios/legalbench_scenario.py +123 -0
  198. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  199. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  200. helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
  201. helm/benchmark/scenarios/math_scenario.py +19 -2
  202. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  203. helm/benchmark/scenarios/numeracy_scenario.py +3 -3
  204. helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
  205. helm/benchmark/scenarios/raft_scenario.py +2 -6
  206. helm/benchmark/scenarios/scenario.py +14 -2
  207. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  208. helm/benchmark/scenarios/test_math_scenario.py +22 -0
  209. helm/benchmark/scenarios/test_scenario.py +6 -3
  210. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  211. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  212. helm/benchmark/scenarios/the_pile_scenario.py +6 -7
  213. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  214. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  215. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  216. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  217. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  218. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  219. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  220. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  221. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  222. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  223. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  224. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  225. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  226. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  227. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  228. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  229. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  230. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  231. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  232. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  233. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  234. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  235. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  236. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  237. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
  238. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  239. helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
  240. helm/benchmark/server.py +59 -2
  241. helm/benchmark/slurm_jobs.py +12 -0
  242. helm/benchmark/slurm_runner.py +79 -51
  243. helm/benchmark/static/benchmarking.js +3 -4
  244. helm/benchmark/static/contamination.yaml +1 -1
  245. helm/benchmark/static/images/organizations/together.png +0 -0
  246. helm/benchmark/static/json-urls.js +4 -0
  247. helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
  248. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  249. helm/benchmark/static/schema_lite.yaml +824 -0
  250. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  251. helm/benchmark/static/schema_unitxt.yaml +428 -0
  252. helm/benchmark/static/schema_vlm.yaml +576 -0
  253. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  254. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  255. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  256. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  257. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  258. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  259. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  260. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  261. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  262. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  263. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  264. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  265. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  266. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  267. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  268. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  269. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  270. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  271. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  272. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  273. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  274. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  275. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  276. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  277. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  278. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  279. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  280. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  281. helm/benchmark/static_build/config.js +4 -0
  282. helm/benchmark/static_build/index.html +20 -0
  283. helm/benchmark/test_data_preprocessor.py +3 -3
  284. helm/benchmark/test_model_deployment_definition.py +90 -0
  285. helm/benchmark/test_run_expander.py +1 -1
  286. helm/benchmark/tokenizer_config_registry.py +10 -14
  287. helm/benchmark/window_services/ai21_window_service.py +22 -33
  288. helm/benchmark/window_services/cohere_window_service.py +1 -63
  289. helm/benchmark/window_services/default_window_service.py +2 -35
  290. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  291. helm/benchmark/window_services/ice_window_service.py +0 -34
  292. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  293. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  294. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  295. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  296. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  297. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  298. helm/benchmark/window_services/local_window_service.py +21 -4
  299. helm/benchmark/window_services/no_decoding_window_service.py +32 -0
  300. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  301. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  302. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  303. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  304. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  305. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  306. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  307. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  308. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  309. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  310. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  311. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  312. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  313. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  314. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  315. helm/benchmark/window_services/test_utils.py +3 -2
  316. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  317. helm/benchmark/window_services/window_service.py +42 -0
  318. helm/benchmark/window_services/window_service_factory.py +24 -269
  319. helm/benchmark/window_services/yalm_window_service.py +0 -27
  320. helm/clients/__init__.py +0 -0
  321. helm/{proxy/clients → clients}/ai21_client.py +5 -12
  322. helm/clients/aleph_alpha_client.py +112 -0
  323. helm/{proxy/clients → clients}/anthropic_client.py +213 -24
  324. helm/clients/auto_client.py +215 -0
  325. helm/clients/bedrock_client.py +128 -0
  326. helm/clients/bedrock_utils.py +72 -0
  327. helm/{proxy/clients → clients}/client.py +67 -55
  328. helm/clients/clip_score_client.py +49 -0
  329. helm/clients/clip_scorers/__init__.py +0 -0
  330. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  331. helm/clients/clip_scorers/clip_scorer.py +50 -0
  332. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  333. helm/{proxy/clients → clients}/cohere_client.py +6 -17
  334. helm/clients/gcs_client.py +82 -0
  335. helm/{proxy/clients → clients}/google_client.py +7 -8
  336. helm/clients/google_translate_client.py +35 -0
  337. helm/{proxy/clients → clients}/http_model_client.py +6 -10
  338. helm/{proxy/clients → clients}/huggingface_client.py +134 -92
  339. helm/clients/image_generation/__init__.py +0 -0
  340. helm/clients/image_generation/adobe_vision_client.py +78 -0
  341. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  342. helm/clients/image_generation/cogview2/__init__.py +0 -0
  343. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  344. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  345. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  346. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  347. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  348. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  349. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  350. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  351. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  352. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  353. helm/clients/image_generation/cogview2_client.py +191 -0
  354. helm/clients/image_generation/dalle2_client.py +192 -0
  355. helm/clients/image_generation/dalle3_client.py +108 -0
  356. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  357. helm/clients/image_generation/dalle_mini/data.py +442 -0
  358. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  359. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  360. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  361. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  362. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  363. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  364. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  365. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  366. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  367. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  368. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  369. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  370. helm/clients/image_generation/dalle_mini_client.py +190 -0
  371. helm/clients/image_generation/deep_floyd_client.py +78 -0
  372. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  373. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  374. helm/clients/image_generation/lexica_client.py +86 -0
  375. helm/clients/image_generation/mindalle/__init__.py +0 -0
  376. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  377. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  378. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  379. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  380. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  381. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  382. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  383. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  384. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  385. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  386. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  387. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  388. helm/clients/image_generation/mindalle_client.py +115 -0
  389. helm/clients/image_generation/nudity_check_client.py +64 -0
  390. helm/clients/image_generation/together_image_generation_client.py +111 -0
  391. helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
  392. helm/{proxy/clients → clients}/megatron_client.py +13 -7
  393. helm/clients/mistral_client.py +134 -0
  394. helm/clients/moderation_api_client.py +109 -0
  395. helm/clients/open_lm_client.py +43 -0
  396. helm/clients/openai_client.py +302 -0
  397. helm/{proxy/clients → clients}/palmyra_client.py +15 -12
  398. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  399. helm/clients/simple_client.py +64 -0
  400. helm/{proxy/clients → clients}/test_auto_client.py +15 -15
  401. helm/clients/test_client.py +100 -0
  402. helm/clients/test_huggingface_client.py +70 -0
  403. helm/clients/test_simple_client.py +19 -0
  404. helm/{proxy/clients → clients}/test_together_client.py +23 -12
  405. helm/{proxy/clients → clients}/together_client.py +18 -71
  406. helm/clients/vertexai_client.py +391 -0
  407. helm/clients/vision_language/__init__.py +0 -0
  408. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  409. helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
  410. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  411. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  412. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  413. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  414. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  415. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  416. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  417. helm/clients/vision_language/open_flamingo_client.py +155 -0
  418. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  419. helm/clients/vllm_client.py +46 -0
  420. helm/common/cache.py +24 -179
  421. helm/common/cache_backend_config.py +47 -0
  422. helm/common/clip_score_request.py +41 -0
  423. helm/common/concurrency.py +32 -0
  424. helm/common/credentials_utils.py +28 -0
  425. helm/common/file_caches/__init__.py +0 -0
  426. helm/common/file_caches/file_cache.py +16 -0
  427. helm/common/file_caches/local_file_cache.py +61 -0
  428. helm/common/file_caches/test_local_file_cache.py +25 -0
  429. helm/common/file_upload_request.py +27 -0
  430. helm/common/general.py +29 -10
  431. helm/common/image_generation_parameters.py +25 -0
  432. helm/common/images_utils.py +24 -1
  433. helm/common/key_value_store.py +113 -0
  434. helm/common/media_object.py +13 -0
  435. helm/common/moderations_api_request.py +71 -0
  436. helm/common/mongo_key_value_store.py +88 -0
  437. helm/common/multimodal_request_utils.py +31 -0
  438. helm/common/nudity_check_request.py +29 -0
  439. helm/common/object_spec.py +2 -2
  440. helm/common/request.py +36 -27
  441. helm/common/test_general.py +6 -0
  442. helm/common/tokenization_request.py +6 -3
  443. helm/config/__init__.py +0 -0
  444. helm/config/model_deployments.yaml +1942 -0
  445. helm/config/model_metadata.yaml +2201 -0
  446. helm/config/tokenizer_configs.yaml +362 -0
  447. helm/proxy/accounts.py +31 -4
  448. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  449. helm/proxy/critique/model_critique_client.py +13 -5
  450. helm/proxy/example_queries.py +29 -17
  451. helm/proxy/retry.py +8 -2
  452. helm/proxy/server.py +77 -5
  453. helm/proxy/services/remote_service.py +31 -0
  454. helm/proxy/services/server_service.py +103 -20
  455. helm/proxy/services/service.py +34 -2
  456. helm/proxy/services/test_remote_service.py +7 -6
  457. helm/proxy/services/test_service.py +27 -18
  458. helm/proxy/test_accounts.py +32 -0
  459. helm/proxy/token_counters/auto_token_counter.py +37 -37
  460. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  461. helm/proxy/token_counters/token_counter.py +3 -5
  462. helm/py.typed +0 -0
  463. helm/tokenizers/__init__.py +0 -0
  464. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  465. helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
  466. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
  467. helm/tokenizers/auto_tokenizer.py +93 -0
  468. helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
  469. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  470. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  471. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
  472. helm/tokenizers/simple_tokenizer.py +33 -0
  473. helm/tokenizers/test_anthropic_tokenizer.py +82 -0
  474. helm/tokenizers/test_huggingface_tokenizer.py +136 -0
  475. helm/tokenizers/test_simple_tokenizer.py +33 -0
  476. helm/tokenizers/vertexai_tokenizer.py +97 -0
  477. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  478. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  479. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  480. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  481. crfm_helm-0.3.0.dist-info/RECORD +0 -396
  482. helm/benchmark/vlm_run_specs.py +0 -71
  483. helm/benchmark/window_services/anthropic_window_service.py +0 -68
  484. helm/benchmark/window_services/bloom_window_service.py +0 -35
  485. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  486. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  487. helm/benchmark/window_services/gptj_window_service.py +0 -38
  488. helm/benchmark/window_services/gptneox_window_service.py +0 -41
  489. helm/benchmark/window_services/http_model_window_service.py +0 -28
  490. helm/benchmark/window_services/huggingface_window_service.py +0 -59
  491. helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
  492. helm/benchmark/window_services/llama_window_service.py +0 -28
  493. helm/benchmark/window_services/luminous_window_service.py +0 -67
  494. helm/benchmark/window_services/megatron_window_service.py +0 -10
  495. helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
  496. helm/benchmark/window_services/openai_window_service.py +0 -13
  497. helm/benchmark/window_services/opt_window_service.py +0 -35
  498. helm/benchmark/window_services/palmyra_window_service.py +0 -45
  499. helm/benchmark/window_services/remote_window_service.py +0 -48
  500. helm/benchmark/window_services/santacoder_window_service.py +0 -27
  501. helm/benchmark/window_services/starcoder_window_service.py +0 -27
  502. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  503. helm/benchmark/window_services/t511b_window_service.py +0 -30
  504. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  505. helm/benchmark/window_services/ul2_window_service.py +0 -30
  506. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  507. helm/benchmark/window_services/wider_openai_window_service.py +0 -52
  508. helm/proxy/clients/aleph_alpha_client.py +0 -99
  509. helm/proxy/clients/auto_client.py +0 -461
  510. helm/proxy/clients/goose_ai_client.py +0 -100
  511. helm/proxy/clients/microsoft_client.py +0 -182
  512. helm/proxy/clients/openai_client.py +0 -206
  513. helm/proxy/clients/remote_model_registry.py +0 -28
  514. helm/proxy/clients/simple_client.py +0 -61
  515. helm/proxy/clients/test_anthropic_client.py +0 -63
  516. helm/proxy/clients/test_client.py +0 -31
  517. helm/proxy/clients/test_huggingface_client.py +0 -87
  518. helm/proxy/models.py +0 -963
  519. helm/proxy/test_models.py +0 -27
  520. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  521. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  522. helm/proxy/token_counters/free_token_counter.py +0 -12
  523. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  524. helm/proxy/token_counters/openai_token_counter.py +0 -22
  525. helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
  526. helm/proxy/token_counters/test_openai_token_counter.py +0 -79
  527. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  528. helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
  529. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  530. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  531. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  532. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  533. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  534. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  535. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  536. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  537. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  538. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  539. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  540. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  541. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  542. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  543. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  544. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  545. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  546. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
@@ -0,0 +1,120 @@
1
+ # -*- encoding: utf-8 -*-
2
+ """
3
+ @File : itersr_sampling.py
4
+ @Time : 2022/03/03 14:24:28
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ """
8
+
9
+ # here put the import lib
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from icetk import icetk as tokenizer
13
+
14
+
15
+ def top_k_logits_(logits, top_k=0, filter_value=-float("Inf")):
16
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
17
+ logits[indices_to_remove] = filter_value
18
+ return logits
19
+
20
+
21
+ class IterativeEntfilterStrategy:
22
+ def __init__(self, invalid_slices=[], temperature=1.0, topk=10):
23
+ self.invalid_slices = invalid_slices
24
+ self.temperature = temperature
25
+ self.topk = topk
26
+
27
+ def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
28
+ # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
29
+ if temperature is None:
30
+ temperature = self.temperature
31
+
32
+ logits = logits.float() / temperature
33
+ for invalid_slice in self.invalid_slices:
34
+ logits[..., invalid_slice] = -float("Inf")
35
+
36
+ # debiased topk
37
+ # probs = F.softmax(logits, dim=-1)
38
+ # tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
39
+ # pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
40
+ # edge_idx = tk_idx[:, :, -1:]
41
+ # edge_value = tk_value[:, :, -1:]
42
+ # edge_mask = probs.gather(dim=-1, index=pred) < edge_value
43
+ # pred[edge_mask] = edge_idx[edge_mask] # replace outliers as the "filter_topk"-th token
44
+ # pred.squeeze_(-1) # [batch_size, seq_length]
45
+
46
+ top_k_logits_(logits, self.topk)
47
+ probs = F.softmax(logits, dim=-1)
48
+ pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
49
+ pred.squeeze_(-1)
50
+
51
+ assert tokens.shape[1] == pred.shape[1]
52
+ tokens = pred
53
+ return tokens
54
+
55
+
56
+ def filling_sequence_itersr(
57
+ model,
58
+ seq0,
59
+ seq1,
60
+ warmup_steps=3,
61
+ block_hw=(4, 4),
62
+ strategy=IterativeEntfilterStrategy(topk=10),
63
+ ):
64
+ """
65
+ seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
66
+ 4095 {layout[2]} final_token.
67
+ Attention:
68
+ The sampling temperature are changing, temporally we hard code them here.
69
+ The temperature in the strategy is not used.
70
+ """
71
+ assert hasattr(model, "layout")
72
+ layout = model.layout
73
+
74
+ device = seq0.device
75
+ # concat and pad sequences
76
+ batch_size = seq0.shape[0]
77
+ n_pad = layout[0] - seq0.shape[1]
78
+ assert n_pad >= 0, "You should truncate long input before filling."
79
+ seq = torch.cat(
80
+ (torch.tensor([0] * n_pad, device=device, dtype=seq0.dtype).unsqueeze(0).expand(batch_size, n_pad), seq0, seq1),
81
+ dim=1,
82
+ ) # [b, layout[-1]+1]
83
+ assert seq.shape[1] == layout[-1]
84
+
85
+ # build initial tokens, attention_mask, and position_ids
86
+ tokens = seq.clone()
87
+ attention_mask = torch.ones(layout[0]).to(device)
88
+ attention_mask[:n_pad] = 0
89
+ attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters())) # if fp16
90
+ position_ids = torch.cat(
91
+ (
92
+ torch.zeros(n_pad, dtype=torch.long),
93
+ torch.arange(0, layout[0] - n_pad),
94
+ torch.arange(1024, 1024 + layout[1] - layout[0]),
95
+ )
96
+ ).to(device)
97
+ log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters()))
98
+ log_attention_weights[n_pad : layout[0]] = 0.0
99
+ log_attention_weights = log_attention_weights.unsqueeze(0)
100
+
101
+ # prepare for interation
102
+ unfixed = tokens == tokenizer["<start_of_image>"]
103
+ ll, rr = block_hw
104
+ # edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
105
+ num_steps = 1
106
+ # interative refining
107
+
108
+ # unfixed[..., -(layout[-1] - layout[-2]):].view(
109
+ # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
110
+
111
+ ret = []
112
+ # ret.append(tokens[:, layout[-2]:-1].clone())
113
+ for step_cnt in range(1, num_steps + 1):
114
+ logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights)
115
+ real_temp = 1.0
116
+ new_tokens = strategy.forward(logits, tokens, real_temp)
117
+ tokens[unfixed] = new_tokens[unfixed]
118
+
119
+ ret.append(tokens[:, layout[-2] :].clone())
120
+ return torch.cat(ret, dim=0)
@@ -0,0 +1,42 @@
1
+ # -*- encoding: utf-8 -*-
2
+ """
3
+ @File : sr_group.py
4
+ @Time : 2022/04/02 01:17:21
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ """
8
+
9
+ # here put the import lib
10
+ from .direct_sr import DirectSuperResolution
11
+ from .iterative_sr import IterativeSuperResolution
12
+
13
+ from helm.common.optional_dependencies import handle_module_not_found_error
14
+
15
+
16
+ class SRGroup:
17
+ def __init__(
18
+ self,
19
+ args,
20
+ home_path=None,
21
+ ):
22
+ try:
23
+ from SwissArmyTransformer.resources import auto_create
24
+ except ModuleNotFoundError as e:
25
+ handle_module_not_found_error(e, ["heim"])
26
+
27
+ dsr_path = auto_create("cogview2-dsr", path=home_path)
28
+ itersr_path = auto_create("cogview2-itersr", path=home_path)
29
+ dsr = DirectSuperResolution(args, dsr_path)
30
+ itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer)
31
+ self.dsr = dsr
32
+ self.itersr = itersr
33
+
34
+ def sr_base(self, img_tokens, txt_tokens):
35
+ assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
36
+ batch_size = img_tokens.shape[0]
37
+ txt_len = txt_tokens.shape[-1]
38
+ if len(txt_tokens.shape) == 1:
39
+ txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
40
+ sred_tokens = self.dsr(txt_tokens, img_tokens)
41
+ iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone())
42
+ return iter_tokens[-batch_size:]
@@ -0,0 +1,191 @@
1
+ import os
2
+ import argparse
3
+ from functools import partial
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import torch
7
+ from icetk import icetk as tokenizer
8
+ from torchvision.utils import save_image
9
+
10
+ from helm.common.cache import CacheConfig, Cache
11
+ from helm.common.file_caches.file_cache import FileCache
12
+ from helm.common.hierarchical_logger import hlog, htrack_block
13
+ from helm.common.optional_dependencies import handle_module_not_found_error
14
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
15
+ from helm.common.tokenization_request import (
16
+ DecodeRequest,
17
+ DecodeRequestResult,
18
+ TokenizationRequest,
19
+ TokenizationRequestResult,
20
+ )
21
+ from helm.clients.client import Client, CachingClient
22
+ from helm.clients.image_generation.cogview2.coglm_strategy import CoglmStrategy
23
+ from .image_generation_client_utils import get_single_image_multimedia_object
24
+
25
+
26
+ class CogView2Client(Client):
27
+ """
28
+ https://github.com/THUDM/CogView2
29
+ """
30
+
31
+ MAX_SEQ_LEN: int = 95
32
+ MODEL_URL: str = "https://nlp.stanford.edu/projects/vhelm/cogview2/sharefs.zip"
33
+
34
+ def __init__(self, cache_config: CacheConfig, file_cache: FileCache):
35
+ self._cache = Cache(cache_config)
36
+ self._file_cache: FileCache = file_cache
37
+
38
+ self._args: Optional[argparse.Namespace] = None
39
+ self._strategy: Optional[CoglmStrategy] = None
40
+ self._model = None
41
+ self._srg = None
42
+
43
+ def _get_model(self) -> None:
44
+ try:
45
+ from SwissArmyTransformer import get_args
46
+ from helm.clients.image_generation.cogview2.coglm_utils import (
47
+ get_recipe,
48
+ InferenceModel,
49
+ )
50
+ from helm.clients.image_generation.cogview2.sr_pipeline import SRGroup
51
+ except ModuleNotFoundError as e:
52
+ handle_module_not_found_error(e, ["heim"])
53
+
54
+ tokenizer.add_special_tokens(["<start_of_image>", "<start_of_english>", "<start_of_chinese>"])
55
+
56
+ model_local_path: str = f"{self._file_cache._location}/cogview2" # type: ignore
57
+ os.environ["SAT_HOME"] = f"{model_local_path}/sharefs/cogview-new"
58
+
59
+ # Download the model if not yet
60
+ if not os.path.exists(model_local_path):
61
+ os.system(f"mkdir -p {model_local_path}")
62
+ os.system(f"wget {self.MODEL_URL} -P {model_local_path}")
63
+ os.system(f"unzip {model_local_path}/sharefs.zip -d {model_local_path}")
64
+
65
+ if self._model is None:
66
+ # Set up args
67
+ args = get_args("--mode inference --fp16".split())
68
+ self._args = argparse.Namespace(**vars(args), **get_recipe("none"))
69
+ self._args.img_size = 160
70
+ self._args.only_first_stage = False
71
+ self._args.inverse_prompt = False
72
+ self._args.batch_size = 1
73
+ self._args.max_inference_batch_size = 1
74
+
75
+ # Load the model components
76
+ self._model, self._args = InferenceModel.from_pretrained(self._args, "coglm")
77
+ invalid_slices = [slice(tokenizer.num_image_tokens, None)]
78
+ self._strategy = CoglmStrategy(
79
+ invalid_slices,
80
+ temperature=getattr(self._args, "temp_all_gen"),
81
+ top_k=getattr(self._args, "topk_gen"),
82
+ top_k_cluster=getattr(self._args, "temp_cluster_gen"),
83
+ )
84
+ self._srg = SRGroup(self._args) # type: ignore
85
+
86
+ def _model_inference(self, prompt) -> torch.Tensor:
87
+ try:
88
+ from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
89
+ from helm.clients.image_generation.cogview2.coglm_utils import get_masks_and_position_ids_coglm
90
+ except ModuleNotFoundError as e:
91
+ handle_module_not_found_error(e, ["heim"])
92
+
93
+ with torch.no_grad():
94
+ text = getattr(self._args, "query_template").format(prompt)
95
+ seq = tokenizer.encode(text)
96
+ if len(seq) > self.MAX_SEQ_LEN:
97
+ seq = seq[: self.MAX_SEQ_LEN - 2] + seq[-2:]
98
+ txt_len = len(seq) - 1
99
+ device = getattr(self._args, "device")
100
+ seq = torch.tensor(seq + [-1] * 400, device=device)
101
+ # calibrate text length
102
+ log_attention_weights = torch.zeros(
103
+ len(seq), len(seq), device=device, dtype=torch.half if getattr(self._args, "fp16") else torch.float32
104
+ )
105
+ log_attention_weights[:, :txt_len] = getattr(self._args, "attn_plus")
106
+ # generation
107
+ mbz = getattr(self._args, "max_inference_batch_size")
108
+ batch_size = getattr(self._args, "batch_size")
109
+ assert batch_size < mbz or batch_size % mbz == 0
110
+ get_func = partial(get_masks_and_position_ids_coglm, context_length=txt_len)
111
+ output_list = []
112
+ for tim in range(max(batch_size // mbz, 1)):
113
+ setattr(self._strategy, "start_pos", txt_len + 1)
114
+ coarse_samples = filling_sequence(
115
+ self._model,
116
+ seq.clone(),
117
+ batch_size=min(batch_size, mbz),
118
+ strategy=self._strategy,
119
+ log_attention_weights=log_attention_weights,
120
+ get_masks_and_position_ids=get_func,
121
+ )[0]
122
+ output_list.append(coarse_samples)
123
+
124
+ output_tokens = torch.cat(output_list, dim=0)
125
+ images = []
126
+ iter_tokens = getattr(self._srg, "sr_base")(output_tokens[:, -400:], seq[:txt_len])
127
+ for seq in iter_tokens:
128
+ decoded_img = tokenizer.decode(image_ids=seq[-3600:])
129
+ decoded_img = torch.nn.functional.interpolate(decoded_img, size=(480, 480))
130
+ images.append(decoded_img) # only the last image (target)
131
+ return images[0]
132
+
133
+ def make_request(self, request: Request) -> RequestResult:
134
+ raw_request = {
135
+ "prompt": request.prompt,
136
+ }
137
+
138
+ try:
139
+
140
+ def do_it() -> Dict[str, Any]:
141
+ prompt: str = request.prompt
142
+
143
+ with htrack_block(f"Generating images for prompt: {prompt}"):
144
+ self._get_model()
145
+
146
+ images: List[torch.Tensor] = []
147
+ for _ in range(request.num_completions):
148
+ output = self._model_inference(**raw_request).cpu() # (1, 3, 480, 480)
149
+ images.append(output)
150
+
151
+ assert (
152
+ len(images) == request.num_completions
153
+ ), f"Expected {request.num_completions} images, but got {len(images)}"
154
+
155
+ result: Dict = {"file_locations": []}
156
+ for image in images:
157
+ # Write out the image to a file and save the path
158
+ file_location: str = self._file_cache.generate_unique_new_file_path() # type: ignore
159
+ save_image(image, file_location, normalize=True)
160
+ hlog(f"Image saved at {file_location}.")
161
+ result["file_locations"].append(file_location)
162
+ return result
163
+
164
+ # Include the model name and number of completions in the cache key
165
+ cache_key = CachingClient.make_cache_key(
166
+ {"model": request.model_engine, "n": request.num_completions, **raw_request}, request
167
+ )
168
+ results, cached = self._cache.get(cache_key, wrap_request_time(do_it))
169
+ except RuntimeError as e:
170
+ error: str = f"CogView2Client error: {e}"
171
+ return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
172
+
173
+ completions: List[GeneratedOutput] = [
174
+ GeneratedOutput(
175
+ text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(location)
176
+ )
177
+ for location in results["file_locations"]
178
+ ]
179
+ return RequestResult(
180
+ success=True,
181
+ cached=cached,
182
+ request_time=results["request_time"],
183
+ completions=completions,
184
+ embedding=[],
185
+ )
186
+
187
+ def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
188
+ raise NotImplementedError("This client does not support tokenizing.")
189
+
190
+ def decode(self, request: DecodeRequest) -> DecodeRequestResult:
191
+ raise NotImplementedError("This client does not support decoding.")
@@ -0,0 +1,192 @@
1
+ from typing import Any, Dict, List, Optional
2
+ import base64
3
+
4
+ from helm.common.cache import CacheConfig, Cache
5
+ from helm.common.general import hlog
6
+ from helm.common.file_caches.file_cache import FileCache
7
+ from helm.common.media_object import MultimediaObject
8
+ from helm.common.optional_dependencies import handle_module_not_found_error
9
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
10
+ from helm.common.tokenization_request import (
11
+ TokenizationRequest,
12
+ TokenizationRequestResult,
13
+ DecodeRequest,
14
+ DecodeRequestResult,
15
+ )
16
+ from helm.clients.moderation_api_client import ModerationAPIClient
17
+ from helm.clients.client import Client, CachingClient
18
+ from .image_generation_client_utils import get_single_image_multimedia_object
19
+
20
+ try:
21
+ import openai
22
+ from openai import OpenAI
23
+ except ModuleNotFoundError as missing_module_exception:
24
+ handle_module_not_found_error(missing_module_exception, ["openai"])
25
+
26
+
27
+ class DALLE2Client(Client):
28
+ MAX_PROMPT_LENGTH: int = 1000
29
+ DEFAULT_IMAGE_SIZE_STR: str = "512x512"
30
+ VALID_IMAGE_SIZES: List[str] = ["256x256", DEFAULT_IMAGE_SIZE_STR, "1024x1024"]
31
+
32
+ # Set the finish reason to this if the prompt violates OpenAI's content policy
33
+ CONTENT_POLICY_VIOLATED_FINISH_REASON: str = (
34
+ "The prompt violates OpenAI's content policy. "
35
+ "See https://labs.openai.com/policies/content-policy for more information."
36
+ )
37
+
38
+ # The DALL-E API will respond with the following error messages (or even a substring of the message)
39
+ # if it has any issues generating images for a particular prompt
40
+ PROMPT_FLAGGED_ERROR: str = (
41
+ "Your request was rejected as a result of our safety system. "
42
+ "Your prompt may contain text that is not allowed by our safety system."
43
+ )
44
+ PROMPT_FLAGGED_ERROR2: str = (
45
+ "Something went wrong with your generation. You may try again or ask for a different prompt"
46
+ )
47
+ PROMPT_FLAGGED_ERROR3: str = (
48
+ "The server had an error while processing your request. Sorry about that! You can retry your request, "
49
+ "or contact us through our help center at help.openai.com if the error persists."
50
+ )
51
+
52
+ def __init__(
53
+ self,
54
+ api_key: str,
55
+ cache_config: CacheConfig,
56
+ file_cache: FileCache,
57
+ moderation_api_client: ModerationAPIClient,
58
+ org_id: Optional[str] = None,
59
+ ):
60
+ self.file_cache: FileCache = file_cache
61
+ self._cache = Cache(cache_config)
62
+
63
+ self.client = OpenAI(api_key=api_key, organization=org_id)
64
+ self.moderation_api_client: ModerationAPIClient = moderation_api_client
65
+
66
+ def get_content_policy_violated_result(self, request: Request) -> RequestResult:
67
+ """
68
+ Return a RequestResult with no images and a finish reason indicating that the prompt / generated images
69
+ violate OpenAI's content policy.
70
+ """
71
+ no_image = GeneratedOutput(
72
+ text="",
73
+ logprob=0,
74
+ tokens=[],
75
+ multimodal_content=MultimediaObject(),
76
+ finish_reason={"reason": self.CONTENT_POLICY_VIOLATED_FINISH_REASON},
77
+ )
78
+ return RequestResult(
79
+ success=True,
80
+ cached=False,
81
+ request_time=0,
82
+ completions=[no_image] * request.num_completions,
83
+ embedding=[],
84
+ )
85
+
86
+ def get_size_str(self, request: Request) -> str:
87
+ """
88
+ Return the size string for the image generation request.
89
+ If the request does not specify a size, return the default size.
90
+ """
91
+ assert request.image_generation_parameters is not None
92
+ w: Optional[int] = request.image_generation_parameters.output_image_width
93
+ h: Optional[int] = request.image_generation_parameters.output_image_height
94
+ if w is None or h is None:
95
+ return self.DEFAULT_IMAGE_SIZE_STR
96
+
97
+ image_dimensions: str = f"{w}x{h}"
98
+ assert image_dimensions in self.VALID_IMAGE_SIZES, f"Valid image sizes are {self.VALID_IMAGE_SIZES}"
99
+ return image_dimensions
100
+
101
+ def fail_if_invalid_request(self, request: Request) -> None:
102
+ """
103
+ Validate the request to ensure it is a valid request for the DALL-E API.
104
+ """
105
+ assert request.image_generation_parameters is not None
106
+ if len(request.prompt) > self.MAX_PROMPT_LENGTH:
107
+ raise ValueError("The maximum length of the prompt is 1000 characters.")
108
+ if request.num_completions < 1 or request.num_completions > 10:
109
+ raise ValueError("`num_completions` must be between 1 and 10.")
110
+
111
+ def handle_openai_error(self, request: Request, error: Exception) -> RequestResult:
112
+ """
113
+ Handle a thrown error from the DALL-E API.
114
+ """
115
+ if (
116
+ str(error) in self.PROMPT_FLAGGED_ERROR
117
+ # Sometimes the DALL-E API will add additional information to the error message.
118
+ or self.PROMPT_FLAGGED_ERROR2 in str(error)
119
+ or self.PROMPT_FLAGGED_ERROR3 in str(error)
120
+ ):
121
+ # Some requests fail even if we check the prompt against the moderation API.
122
+ # For example, "black" in Spanish (negro) causes requests to DALL-E to fail even
123
+ # though the prompt does not get flagged by the Moderation API.
124
+ hlog(f"Failed safety check: {request.prompt}")
125
+ return self.get_content_policy_violated_result(request)
126
+ else:
127
+ return RequestResult(
128
+ success=False, cached=False, error=f"DALL-E error: {error}", completions=[], embedding=[]
129
+ )
130
+
131
+ def generate_with_dalle_api(self, raw_request: Dict[str, Any]) -> Dict:
132
+ """
133
+ Makes a single request to generate the images with the DALL-E API.
134
+ """
135
+ result = self.client.images.generate(**raw_request).model_dump(mode="json")
136
+ assert "data" in result, f"Invalid response: {result} from prompt: {raw_request['prompt']}"
137
+
138
+ for image in result["data"]:
139
+ # Write out the image to a file and save the path
140
+ image["file_path"] = self.file_cache.store(lambda: base64.b64decode(image["b64_json"]))
141
+ # Don't cache contents of `b64_json` as we already have the image stored
142
+ image.pop("b64_json", None)
143
+ return result
144
+
145
+ def make_request(self, request: Request) -> RequestResult:
146
+ self.fail_if_invalid_request(request)
147
+
148
+ # Use the Moderation API to check if the prompt violates OpenAI's content policy before generating images
149
+ if self.moderation_api_client.will_be_flagged(request.prompt):
150
+ return self.get_content_policy_violated_result(request)
151
+
152
+ # https://beta.openai.com/docs/api-reference/images/create#images/create-response_format
153
+ raw_request: Dict[str, Any] = {
154
+ "prompt": request.prompt,
155
+ "n": request.num_completions,
156
+ "size": self.get_size_str(request),
157
+ "response_format": "b64_json", # Always set to b64_json as URLs are only valid for an hour
158
+ }
159
+
160
+ try:
161
+
162
+ def do_it() -> Dict[str, Any]:
163
+ # To maintain backwards compatibility, specify the model in the request but not in the cache key
164
+ return self.generate_with_dalle_api({"model": "dall-e-2", **raw_request})
165
+
166
+ cache_key = CachingClient.make_cache_key(raw_request, request)
167
+ response, cached = self._cache.get(cache_key, wrap_request_time(do_it))
168
+ except openai.OpenAIError as e:
169
+ return self.handle_openai_error(request, e)
170
+
171
+ completions: List[GeneratedOutput] = [
172
+ GeneratedOutput(
173
+ text="",
174
+ logprob=0,
175
+ tokens=[],
176
+ multimodal_content=get_single_image_multimedia_object(generated_image["file_path"]),
177
+ )
178
+ for generated_image in response["data"]
179
+ ]
180
+ return RequestResult(
181
+ success=True,
182
+ cached=cached,
183
+ request_time=response["request_time"],
184
+ completions=completions,
185
+ embedding=[],
186
+ )
187
+
188
+ def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
189
+ raise NotImplementedError("This client does not support tokenizing.")
190
+
191
+ def decode(self, request: DecodeRequest) -> DecodeRequestResult:
192
+ raise NotImplementedError("This client does not support decoding.")
@@ -0,0 +1,108 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from helm.common.cache import CacheConfig
4
+ from helm.common.file_caches.file_cache import FileCache
5
+ from helm.common.general import singleton
6
+ from helm.common.optional_dependencies import handle_module_not_found_error
7
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
8
+ from helm.clients.moderation_api_client import ModerationAPIClient
9
+ from helm.clients.client import CachingClient
10
+ from .dalle2_client import DALLE2Client
11
+ from .image_generation_client_utils import get_single_image_multimedia_object
12
+
13
+ try:
14
+ import openai
15
+ except ModuleNotFoundError as missing_module_exception:
16
+ handle_module_not_found_error(missing_module_exception, ["openai"])
17
+
18
+
19
+ class DALLE3Client(DALLE2Client):
20
+ """
21
+ Client for the OpenAI's DALL-E 3 API.
22
+ DALL-E 3 cookbook with explanations for the different parameters:
23
+ https://cookbook.openai.com/articles/what_is_new_with_dalle_3
24
+ """
25
+
26
+ DEFAULT_IMAGE_SIZE_STR: str = "1024x1024"
27
+ VALID_IMAGE_SIZES: List[str] = [DEFAULT_IMAGE_SIZE_STR, "1792x1024", "1024x1792"]
28
+
29
+ def __init__(
30
+ self,
31
+ api_key: str,
32
+ cache_config: CacheConfig,
33
+ file_cache: FileCache,
34
+ moderation_api_client: ModerationAPIClient,
35
+ org_id: Optional[str] = None,
36
+ ):
37
+ super().__init__(api_key, cache_config, file_cache, moderation_api_client, org_id)
38
+
39
+ def make_request(self, request: Request) -> RequestResult:
40
+ self.fail_if_invalid_request(request)
41
+ if self.moderation_api_client.will_be_flagged(request.prompt):
42
+ return self.get_content_policy_violated_result(request)
43
+
44
+ raw_request: Dict[str, Any] = {
45
+ "model": "dall-e-3",
46
+ "prompt": request.prompt,
47
+ "n": 1, # As of December 2023, the DALL-E 3 API only supports a single generated image per request
48
+ "size": self.get_size_str(request),
49
+ "response_format": "b64_json", # Always set to b64_json as URLs are only valid for an hour
50
+ }
51
+
52
+ if request.model_engine == "dall-e-3":
53
+ raw_request["quality"] = "standard"
54
+ raw_request["style"] = "vivid"
55
+ elif request.model_engine == "dall-e-3-natural":
56
+ raw_request["quality"] = "standard"
57
+ raw_request["style"] = "natural"
58
+ elif request.model_engine == "dall-e-3-hd":
59
+ raw_request["quality"] = "hd"
60
+ raw_request["style"] = "vivid"
61
+ elif request.model_engine == "dall-e-3-hd-natural":
62
+ raw_request["quality"] = "hd"
63
+ raw_request["style"] = "natural"
64
+ else:
65
+ raise ValueError(f"Invalid DALL-E 3 model: {request.model_engine}")
66
+
67
+ responses: List[Dict[str, Any]] = []
68
+ all_cached: bool = True
69
+
70
+ # Since the DALL-E 3 API only supports a single generated image, make `request.num_completions` requests
71
+ for completion_index in range(request.num_completions):
72
+ try:
73
+
74
+ def do_it() -> Dict[str, Any]:
75
+ return self.generate_with_dalle_api({**raw_request})
76
+
77
+ cache_key = CachingClient.make_cache_key({"completion_index": completion_index, **raw_request}, request)
78
+ response, cached = self._cache.get(cache_key, wrap_request_time(do_it))
79
+
80
+ responses.append(response)
81
+ all_cached = all_cached and cached
82
+ except openai.OpenAIError as e:
83
+ return self.handle_openai_error(request, e)
84
+
85
+ completions: List[GeneratedOutput] = []
86
+ total_request_time: float = 0
87
+ for response in responses:
88
+ image_response: Dict[str, Any] = singleton(response["data"])
89
+ completions.append(
90
+ GeneratedOutput(
91
+ # From https://cookbook.openai.com/articles/what_is_new_with_dalle_3,
92
+ # "a new feature in the latest DALL·E-3 API is prompt rewriting, where we use
93
+ # GPT-4 to optimize all of your prompts before they’re passed to DALL-E."
94
+ text=image_response["revised_prompt"],
95
+ multimodal_content=get_single_image_multimedia_object(image_response["file_path"]),
96
+ logprob=0,
97
+ tokens=[],
98
+ )
99
+ )
100
+ total_request_time += response["request_time"]
101
+
102
+ return RequestResult(
103
+ success=True,
104
+ cached=all_cached,
105
+ request_time=total_request_time,
106
+ completions=completions,
107
+ embedding=[],
108
+ )
@@ -0,0 +1,3 @@
1
+ __version__ = "0.1.4"
2
+
3
+ from .model import DalleBart, DalleBartProcessor