crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +31 -3
  5. helm/benchmark/adaptation/adapters/adapter.py +2 -2
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
  9. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
  10. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  11. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
  12. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  13. helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
  14. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
  15. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
  16. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  17. helm/benchmark/adaptation/request_state.py +6 -1
  18. helm/benchmark/adaptation/scenario_state.py +6 -2
  19. helm/benchmark/annotation/annotator.py +43 -0
  20. helm/benchmark/annotation/annotator_factory.py +61 -0
  21. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  22. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  23. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  24. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  25. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  26. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  27. helm/benchmark/annotation_executor.py +124 -0
  28. helm/benchmark/augmentations/data_augmenter.py +0 -2
  29. helm/benchmark/augmentations/gender_perturbation.py +1 -1
  30. helm/benchmark/augmentations/perturbation.py +8 -2
  31. helm/benchmark/augmentations/perturbation_description.py +1 -1
  32. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  33. helm/benchmark/augmentations/test_perturbation.py +11 -7
  34. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  35. helm/benchmark/config_registry.py +7 -1
  36. helm/benchmark/executor.py +46 -16
  37. helm/benchmark/huggingface_registration.py +20 -7
  38. helm/benchmark/metrics/basic_metrics.py +169 -664
  39. helm/benchmark/metrics/bbq_metrics.py +3 -4
  40. helm/benchmark/metrics/bias_metrics.py +6 -6
  41. helm/benchmark/metrics/classification_metrics.py +11 -8
  42. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  43. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  44. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  45. helm/benchmark/metrics/common_metric_specs.py +167 -0
  46. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  47. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  48. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  49. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  50. helm/benchmark/metrics/disinformation_metrics.py +4 -110
  51. helm/benchmark/metrics/dry_run_metrics.py +2 -2
  52. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  53. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  54. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  55. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  56. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  57. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  58. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  59. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  60. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  61. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  62. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  63. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  64. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  65. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  66. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  67. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  68. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  69. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  70. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  71. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  72. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  73. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  74. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  75. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  76. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  77. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  78. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  79. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  80. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  81. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  82. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  83. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  84. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  85. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  86. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  87. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  88. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  89. helm/benchmark/metrics/machine_translation_metrics.py +89 -0
  90. helm/benchmark/metrics/metric.py +93 -172
  91. helm/benchmark/metrics/metric_name.py +0 -1
  92. helm/benchmark/metrics/metric_service.py +16 -0
  93. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  94. helm/benchmark/metrics/ranking_metrics.py +2 -2
  95. helm/benchmark/metrics/reference_metric.py +148 -0
  96. helm/benchmark/metrics/summac/model_summac.py +0 -2
  97. helm/benchmark/metrics/summarization_metrics.py +2 -2
  98. helm/benchmark/metrics/test_classification_metrics.py +8 -5
  99. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  100. helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
  101. helm/benchmark/metrics/test_metric.py +2 -2
  102. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
  103. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  104. helm/benchmark/metrics/toxicity_utils.py +23 -0
  105. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  106. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  107. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  108. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  109. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  110. helm/benchmark/model_deployment_registry.py +74 -0
  111. helm/benchmark/model_metadata_registry.py +36 -0
  112. helm/benchmark/multi_gpu_runner.py +133 -0
  113. helm/benchmark/presentation/create_plots.py +8 -7
  114. helm/benchmark/presentation/run_display.py +26 -10
  115. helm/benchmark/presentation/schema.py +15 -40
  116. helm/benchmark/presentation/summarize.py +119 -79
  117. helm/benchmark/presentation/table.py +8 -8
  118. helm/benchmark/presentation/test_contamination.py +2 -2
  119. helm/benchmark/presentation/test_run_entry.py +1 -2
  120. helm/benchmark/presentation/test_summarize.py +3 -3
  121. helm/benchmark/run.py +54 -26
  122. helm/benchmark/run_expander.py +214 -16
  123. helm/benchmark/run_spec.py +93 -0
  124. helm/benchmark/run_spec_factory.py +162 -0
  125. helm/benchmark/run_specs/__init__.py +0 -0
  126. helm/benchmark/run_specs/classic_run_specs.py +1510 -0
  127. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  128. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  129. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  130. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  131. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  132. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  133. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  134. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  135. helm/benchmark/runner.py +51 -57
  136. helm/benchmark/runner_config_registry.py +21 -0
  137. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  138. helm/benchmark/scenarios/bold_scenario.py +2 -2
  139. helm/benchmark/scenarios/code_scenario.py +1 -0
  140. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  141. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  142. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  143. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  144. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  145. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  146. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  147. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  148. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  149. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  150. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  151. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  152. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  153. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  154. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  155. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  156. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  157. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  158. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  159. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  160. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  161. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  162. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  163. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  164. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  165. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  166. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  167. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  168. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  169. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  170. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  171. helm/benchmark/scenarios/math_scenario.py +19 -2
  172. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  173. helm/benchmark/scenarios/numeracy_scenario.py +1 -1
  174. helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
  175. helm/benchmark/scenarios/scenario.py +4 -0
  176. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  177. helm/benchmark/scenarios/test_math_scenario.py +6 -0
  178. helm/benchmark/scenarios/test_scenario.py +6 -3
  179. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  180. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  181. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  182. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  183. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  184. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  185. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  186. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  187. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  188. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  189. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  190. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  191. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  192. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  193. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  194. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  195. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  196. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  197. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  198. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  199. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  200. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  201. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  202. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  203. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  204. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  205. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
  206. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  207. helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
  208. helm/benchmark/server.py +24 -1
  209. helm/benchmark/slurm_runner.py +70 -49
  210. helm/benchmark/static/benchmarking.js +1 -1
  211. helm/benchmark/static/schema_classic.yaml +258 -1066
  212. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  213. helm/benchmark/static/schema_lite.yaml +2 -227
  214. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  215. helm/benchmark/static/schema_unitxt.yaml +428 -0
  216. helm/benchmark/static/schema_vlm.yaml +576 -0
  217. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  218. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  219. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  220. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  221. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  222. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  223. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  224. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  225. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  226. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  227. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  228. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  229. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  230. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  231. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  232. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  233. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  234. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  235. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  236. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  237. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  238. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  239. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  240. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  241. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  242. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  243. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  244. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  245. helm/benchmark/static_build/config.js +4 -0
  246. helm/benchmark/static_build/index.html +20 -0
  247. helm/benchmark/test_data_preprocessor.py +3 -3
  248. helm/benchmark/test_model_deployment_definition.py +14 -16
  249. helm/benchmark/test_run_expander.py +1 -1
  250. helm/benchmark/window_services/ai21_window_service.py +22 -33
  251. helm/benchmark/window_services/cohere_window_service.py +1 -63
  252. helm/benchmark/window_services/default_window_service.py +2 -44
  253. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  254. helm/benchmark/window_services/ice_window_service.py +0 -34
  255. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  256. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  257. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  258. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  259. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  260. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  261. helm/benchmark/window_services/local_window_service.py +21 -4
  262. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  263. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  264. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  265. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  266. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  267. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  268. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  269. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  270. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  271. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  272. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  273. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  274. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  275. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  276. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  277. helm/benchmark/window_services/test_utils.py +3 -2
  278. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  279. helm/benchmark/window_services/window_service.py +42 -0
  280. helm/benchmark/window_services/window_service_factory.py +4 -1
  281. helm/benchmark/window_services/yalm_window_service.py +0 -27
  282. helm/clients/__init__.py +0 -0
  283. helm/{proxy/clients → clients}/ai21_client.py +3 -9
  284. helm/clients/aleph_alpha_client.py +112 -0
  285. helm/{proxy/clients → clients}/anthropic_client.py +203 -18
  286. helm/{proxy/clients → clients}/auto_client.py +59 -31
  287. helm/clients/bedrock_client.py +128 -0
  288. helm/clients/bedrock_utils.py +72 -0
  289. helm/{proxy/clients → clients}/client.py +65 -7
  290. helm/clients/clip_score_client.py +49 -0
  291. helm/clients/clip_scorers/__init__.py +0 -0
  292. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  293. helm/clients/clip_scorers/clip_scorer.py +50 -0
  294. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  295. helm/{proxy/clients → clients}/cohere_client.py +4 -11
  296. helm/clients/gcs_client.py +82 -0
  297. helm/{proxy/clients → clients}/google_client.py +5 -5
  298. helm/clients/google_translate_client.py +35 -0
  299. helm/{proxy/clients → clients}/http_model_client.py +5 -7
  300. helm/{proxy/clients → clients}/huggingface_client.py +43 -64
  301. helm/clients/image_generation/__init__.py +0 -0
  302. helm/clients/image_generation/adobe_vision_client.py +78 -0
  303. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  304. helm/clients/image_generation/cogview2/__init__.py +0 -0
  305. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  306. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  307. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  308. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  309. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  310. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  311. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  312. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  313. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  314. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  315. helm/clients/image_generation/cogview2_client.py +191 -0
  316. helm/clients/image_generation/dalle2_client.py +192 -0
  317. helm/clients/image_generation/dalle3_client.py +108 -0
  318. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  319. helm/clients/image_generation/dalle_mini/data.py +442 -0
  320. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  321. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  322. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  323. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  324. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  325. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  326. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  327. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  328. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  329. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  330. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  331. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  332. helm/clients/image_generation/dalle_mini_client.py +190 -0
  333. helm/clients/image_generation/deep_floyd_client.py +78 -0
  334. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  335. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  336. helm/clients/image_generation/lexica_client.py +86 -0
  337. helm/clients/image_generation/mindalle/__init__.py +0 -0
  338. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  339. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  340. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  341. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  342. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  343. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  344. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  345. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  346. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  347. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  348. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  349. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  350. helm/clients/image_generation/mindalle_client.py +115 -0
  351. helm/clients/image_generation/nudity_check_client.py +64 -0
  352. helm/clients/image_generation/together_image_generation_client.py +111 -0
  353. helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
  354. helm/{proxy/clients → clients}/megatron_client.py +5 -5
  355. helm/clients/mistral_client.py +134 -0
  356. helm/clients/moderation_api_client.py +109 -0
  357. helm/clients/open_lm_client.py +43 -0
  358. helm/clients/openai_client.py +302 -0
  359. helm/{proxy/clients → clients}/palmyra_client.py +6 -8
  360. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  361. helm/clients/simple_client.py +64 -0
  362. helm/{proxy/clients → clients}/test_auto_client.py +13 -15
  363. helm/clients/test_client.py +100 -0
  364. helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
  365. helm/clients/test_simple_client.py +19 -0
  366. helm/{proxy/clients → clients}/test_together_client.py +20 -8
  367. helm/{proxy/clients → clients}/together_client.py +12 -72
  368. helm/clients/vertexai_client.py +391 -0
  369. helm/clients/vision_language/__init__.py +0 -0
  370. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  371. helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
  372. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  373. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  374. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  375. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  376. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  377. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  378. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  379. helm/clients/vision_language/open_flamingo_client.py +155 -0
  380. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  381. helm/clients/vllm_client.py +46 -0
  382. helm/common/cache.py +16 -4
  383. helm/common/cache_backend_config.py +47 -0
  384. helm/common/clip_score_request.py +41 -0
  385. helm/common/file_caches/__init__.py +0 -0
  386. helm/common/file_caches/file_cache.py +16 -0
  387. helm/common/file_caches/local_file_cache.py +61 -0
  388. helm/common/file_caches/test_local_file_cache.py +25 -0
  389. helm/common/file_upload_request.py +27 -0
  390. helm/common/general.py +1 -1
  391. helm/common/image_generation_parameters.py +25 -0
  392. helm/common/images_utils.py +24 -1
  393. helm/common/key_value_store.py +35 -4
  394. helm/common/media_object.py +13 -0
  395. helm/common/moderations_api_request.py +71 -0
  396. helm/common/mongo_key_value_store.py +3 -3
  397. helm/common/multimodal_request_utils.py +31 -0
  398. helm/common/nudity_check_request.py +29 -0
  399. helm/common/request.py +15 -17
  400. helm/common/test_general.py +6 -0
  401. helm/common/tokenization_request.py +1 -1
  402. helm/config/model_deployments.yaml +1069 -546
  403. helm/config/model_metadata.yaml +753 -31
  404. helm/config/tokenizer_configs.yaml +142 -43
  405. helm/proxy/accounts.py +31 -4
  406. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  407. helm/proxy/critique/model_critique_client.py +8 -6
  408. helm/proxy/example_queries.py +29 -17
  409. helm/proxy/server.py +70 -5
  410. helm/proxy/services/remote_service.py +31 -0
  411. helm/proxy/services/server_service.py +96 -16
  412. helm/proxy/services/service.py +30 -0
  413. helm/proxy/services/test_remote_service.py +4 -3
  414. helm/proxy/services/test_service.py +0 -12
  415. helm/proxy/test_accounts.py +32 -0
  416. helm/proxy/token_counters/auto_token_counter.py +37 -37
  417. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  418. helm/proxy/token_counters/token_counter.py +3 -5
  419. helm/tokenizers/__init__.py +0 -0
  420. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  421. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
  422. helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
  423. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  424. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  425. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
  426. helm/tokenizers/simple_tokenizer.py +33 -0
  427. helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
  428. helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
  429. helm/tokenizers/test_simple_tokenizer.py +33 -0
  430. helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
  431. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  432. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  433. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  434. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  435. crfm_helm-0.4.0.dist-info/RECORD +0 -397
  436. helm/benchmark/run_specs.py +0 -2762
  437. helm/benchmark/test_model_properties.py +0 -1570
  438. helm/benchmark/vlm_run_specs.py +0 -97
  439. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  440. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  441. helm/benchmark/window_services/huggingface_window_service.py +0 -60
  442. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  443. helm/benchmark/window_services/t511b_window_service.py +0 -30
  444. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  445. helm/benchmark/window_services/ul2_window_service.py +0 -30
  446. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  447. helm/common/cache_utils.py +0 -14
  448. helm/proxy/clients/aleph_alpha_client.py +0 -95
  449. helm/proxy/clients/goose_ai_client.py +0 -99
  450. helm/proxy/clients/microsoft_client.py +0 -180
  451. helm/proxy/clients/openai_client.py +0 -206
  452. helm/proxy/clients/simple_client.py +0 -60
  453. helm/proxy/clients/test_client.py +0 -49
  454. helm/proxy/clients/vertexai_client.py +0 -115
  455. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  456. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  457. helm/proxy/token_counters/free_token_counter.py +0 -12
  458. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  459. helm/proxy/token_counters/openai_token_counter.py +0 -22
  460. helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
  461. helm/proxy/token_counters/test_openai_token_counter.py +0 -81
  462. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  463. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  464. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  465. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  466. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  467. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  468. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  469. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  470. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  471. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  472. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  473. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  474. /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
  475. /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
  476. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  477. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  478. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  479. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  480. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  481. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  482. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
@@ -0,0 +1,267 @@
1
+ """
2
+ Based on: https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import torch
6
+ from einops import rearrange, repeat
7
+ from einops_exts import rearrange_many
8
+ from torch import einsum, nn
9
+
10
+
11
def exists(val):
    """Return True when *val* is set (i.e. not None)."""
    return not (val is None)
13
+
14
+
15
def FeedForward(dim, mult=4):
    """Pre-norm two-layer MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear back to dim.

    Factory function (hence the class-like name) returning an nn.Sequential;
    both projections are bias-free, matching the Flamingo reference code.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
23
+
24
+
25
class PerceiverAttention(nn.Module):
    """Perceiver-style cross-attention: a set of latent vectors queries over
    image features (plus the latents themselves) within each time step."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5  # 1/sqrt(d_head) attention scaling
        self.heads = heads
        inner_dim = dim_head * heads

        # Separate pre-norms for the two input streams.
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        # Queries come from the latents; keys/values from media + latents (see forward).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latent (torch.Tensor): latent features
                shape (b, T, n2, D)
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)
        # Keys/values span the concatenation of media tokens and latents along the
        # token axis, so latents can also exchange information among themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        # Split heads: (b, t, n, h*d) -> (b, h, t, n, d) for all three tensors at once.
        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
        q = q * self.scale

        # attention
        sim = einsum("... i d, ... j d -> ... i j", q, k)
        # Subtract the (detached) row max before softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        # Merge heads back: (b, h, t, n, d) -> (b, t, n, h*d).
        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
        return self.to_out(out)
66
+
67
+
68
class PerceiverResampler(nn.Module):
    """Resamples a variable number of image/frame tokens down to a fixed set of
    `num_latents` learned latent tokens via a stack of Perceiver attention +
    feed-forward blocks."""

    def __init__(
        self,
        *,
        dim,
        depth=6,
        dim_head=64,
        heads=8,
        num_latents=64,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        super().__init__()
        # Learned latent queries, shared across batch and time.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        # Optional learned positional embeddings: per-frame and per-media-timestep.
        # Only created when the corresponding maximum is provided.
        self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
        self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None

        # `depth` pairs of (cross-attention, feed-forward), each used with a
        # residual connection in forward().
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
        Returns:
            shape (b, T, n, D) where n is self.num_latents
        """
        b, T, F, v = x.shape[:4]

        # frame and media time embeddings
        if exists(self.frame_embs):
            # Broadcast the first F frame embeddings over batch, time and spatial tokens.
            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
            x = x + frame_embs
        x = rearrange(x, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensions
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:T]

        # blocks
        # Expand the shared latents to one copy per (batch, time) position.
        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        return self.norm(latents)
123
+
124
+
125
+ # gated cross attention
126
class MaskedCrossAttention(nn.Module):
    """Cross-attention from text tokens to media (image) latents, with masking so
    each text token only attends to media that precede it in the sequence."""

    def __init__(
        self,
        *,
        dim,
        dim_visual,
        dim_head=64,
        heads=8,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.scale = dim_head**-0.5  # 1/sqrt(d_head) attention scaling
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        # Queries from text (dim); keys/values from visual latents (dim_visual).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether for text to only attend to immediate preceding image, or all previous images
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(self, x, media, media_locations=None, use_cached_media=False):
        """
        Args:
            x (torch.Tensor): text features
                shape (B, T_txt, D_txt)
            media (torch.Tensor): image features
                shape (B, T_img, n, D_img) where n is the dim of the latents
            media_locations: boolean mask identifying the media tokens in x
                shape (B, T_txt)
            use_cached_media: bool
                If true, treat all of x as if they occur after the last media
                registered in media_locations. T_txt does not need to exactly
                equal media_locations.shape[1] in this case
        """

        # NOTE(review): this assert dereferences media_locations, so calling with
        # media_locations=None and use_cached_media=False raises AttributeError
        # rather than skipping masking — confirm callers always pass the mask here.
        if not use_cached_media:
            assert (
                media_locations.shape[1] == x.shape[1]
            ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"

        T_txt = x.shape[1]
        _, T_img, n = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        # Flatten (media timestep, latent) into a single key/value token axis.
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        sim = einsum("... i d, ... j d -> ... i j", q, k)

        if exists(media_locations):
            # 1-indexed time of each media chunk (0 is reserved for "no media yet").
            media_time = torch.arange(T_img, device=x.device) + 1

            if use_cached_media:
                # text time is set to the last cached media location
                text_time = repeat(
                    torch.count_nonzero(media_locations, dim=1),
                    "b -> b i",
                    i=T_txt,
                )
            else:
                # at each boolean of True, increment the time counter (relative to media time)
                text_time = media_locations.cumsum(dim=-1)

            # text time must equal media time if only attending to most immediate image
            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            # Each media time covers n latent tokens, hence the (j n) repeat to
            # align with the flattened key axis above.
            text_to_media_mask = mask_op(
                rearrange(text_time, "b i -> b 1 i 1"),
                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
            )
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        # Subtract the (detached) row max before softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        if exists(media_locations) and self.only_attend_immediate_media:
            # any text without a preceding media needs to have attention zeroed out
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
            attn = attn.masked_fill(text_without_media_mask, 0.0)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
222
+
223
+
224
class GatedCrossAttentionBlock(nn.Module):
    """Flamingo gated cross-attention block: masked text-to-media cross-attention
    followed by a feed-forward layer, each behind a tanh gate initialized at 0
    so the block starts as an identity mapping."""

    def __init__(
        self,
        *,
        dim,
        dim_visual,
        dim_head=64,
        heads=8,
        ff_mult=4,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(
            dim=dim,
            dim_visual=dim_visual,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        # tanh(0) == 0, so both branches contribute nothing until trained.
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x,
        media,
        media_locations=None,
        use_cached_media=False,
    ):
        attended = self.attn(
            x,
            media,
            media_locations=media_locations,
            use_cached_media=use_cached_media,
        )
        # Gated residual connections around attention, then feed-forward.
        x = x + attended * self.attn_gate.tanh()
        x = x + self.ff(x) * self.ff_gate.tanh()

        return x
@@ -0,0 +1,47 @@
1
+ """
2
+ Source: https://github.com/mlfoundations/open_flamingo
3
+ """
4
+
5
+
6
def extend_instance(obj, mixin):
    """Apply mixins to a class instance after creation.

    Rebinds obj.__class__ to a freshly-created type with the same name whose
    bases are (mixin, original class) — mixin first so its methods win in the
    MRO, which the forward() override logic relies on.
    """
    original_cls = obj.__class__
    obj.__class__ = type(original_cls.__name__, (mixin, original_cls), {})
13
+
14
+
15
def getattr_recursive(obj, att):
    """
    Return nested attribute of obj
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    """
    # Empty path means "the object itself" (also the recursion base case).
    if att == "":
        return obj
    head, sep, tail = att.partition(".")
    if not sep:
        # No dot left: plain attribute lookup.
        return getattr(obj, att)
    # Resolve the first segment, then recurse on the remainder.
    return getattr_recursive(getattr(obj, head), tail)
27
+
28
+
29
def setattr_recursive(obj, att, val):
    """
    Set nested attribute of obj
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    # Split off the final segment; everything before it is the owner path.
    parent_path, sep, leaf = att.rpartition(".")
    if sep:
        obj = getattr_recursive(obj, parent_path)
    setattr(obj, leaf, val)
+
38
+
39
def apply_with_stopping_condition(module, apply_fn, apply_condition=None, stopping_condition=None, **other_args):
    """Pre-order recursive application of ``apply_fn`` over ``module`` and its children.

    Args:
        module: root object; must expose ``.children()`` yielding sub-modules
            (e.g. ``torch.nn.Module``).
        apply_fn: callable invoked as ``apply_fn(module, **other_args)``.
        apply_condition: optional predicate; ``apply_fn`` runs only where it is
            truthy. ``None`` means "apply everywhere".
        stopping_condition: optional predicate; when truthy for a module, that
            module and its entire subtree are skipped. ``None`` means "never stop".
        **other_args: forwarded verbatim to ``apply_fn`` at every call.
    """
    # Fix: the original called both predicates unconditionally, so leaving them
    # at their default of None raised TypeError. Treating None as "never stop" /
    # "always apply" is backward-compatible for all callers that passed both.
    if stopping_condition is not None and stopping_condition(module):
        return
    if apply_condition is None or apply_condition(module):
        apply_fn(module, **other_args)
    for child in module.children():
        apply_with_stopping_condition(
            child, apply_fn, apply_condition=apply_condition, stopping_condition=stopping_condition, **other_args
        )
@@ -0,0 +1,155 @@
1
+ from threading import Lock
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ from helm.common.cache import CacheConfig
8
+ from helm.common.hierarchical_logger import hlog, htrack_block
9
+ from helm.common.images_utils import open_image
10
+ from helm.common.gpu_utils import get_torch_device_name
11
+ from helm.common.media_object import TEXT_TYPE
12
+ from helm.common.optional_dependencies import handle_module_not_found_error
13
+ from helm.common.request import Request, RequestResult, GeneratedOutput, Token
14
+ from helm.common.request import wrap_request_time
15
+ from helm.clients.vision_language.open_flamingo import create_model_and_transforms
16
+ from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
17
+
18
+ try:
19
+ from PIL import Image
20
+ except ModuleNotFoundError as e:
21
+ handle_module_not_found_error(e, ["images"])
22
+
23
+
24
class OpenFlamingoClient(CachingClient):
    """
    OpenFlamingo is an open source implementation of DeepMind's Flamingo models.
    Implementation following:
    https://github.com/mlfoundations/open_flamingo
    https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b
    """

    END_OF_CHUNK_TOKEN: str = "<|endofchunk|>"
    IMAGE_TOKEN: str = "<image>"

    # Guards the one-time lazy model initialization across threads.
    _model_lock: Lock = Lock()

    def __init__(
        self,
        cache_config: CacheConfig,
        checkpoint_path: Optional[str] = None,
        tokenizer_name: Optional[str] = None,
        cross_attn_every_n_layers: int = 4,
    ):
        super().__init__(cache_config)
        self._device: str = get_torch_device_name()
        self._checkpoint_path: Optional[str] = checkpoint_path
        self._tokenizer_name: Optional[str] = tokenizer_name
        self._cross_attn_every_n_layers: int = cross_attn_every_n_layers

        # Model
        # The model is only initialized when the first request is made
        # This is to avoid loading the model if it is not used
        self._model: Optional[torch.nn.Module] = None

    def _get_model(self):
        """Load the OpenFlamingo model, image processor and tokenizer (once).

        Raises:
            ValueError: when checkpoint path or tokenizer name were not configured.
        """
        if not self._checkpoint_path:
            raise ValueError("OpenFlamingoClient requires a checkpoint path")
        if not self._tokenizer_name:
            raise ValueError("OpenFlamingoClient requires a tokenizer name")
        with htrack_block("Initializing OpenFlamingo model"):
            with self._model_lock:
                self._model, self.image_processor, self.tokenizer = create_model_and_transforms(
                    clip_vision_encoder_path="ViT-L-14",
                    clip_vision_encoder_pretrained="openai",
                    lang_encoder_path=self._tokenizer_name,
                    tokenizer_path=self._tokenizer_name,
                    cross_attn_every_n_layers=self._cross_attn_every_n_layers,
                )
                self.tokenizer.padding_side = "left"
                # Download the fine-tuned checkpoint from the Hugging Face Hub
                # and load weights non-strictly on top of the assembled model.
                checkpoint_path = hf_hub_download(self._checkpoint_path, "checkpoint.pt")
                self._model.load_state_dict(torch.load(checkpoint_path), strict=False)
                self._model = self._model.to(self._device)
                hlog(f"Loaded model to {self._device}.")

    def make_request(self, request: Request) -> RequestResult:
        """Run (cached) multimodal generation for ``request``.

        Builds an interleaved <image>/text prompt from the multimodal prompt,
        runs beam-search generation, and strips the prompt from the outputs.
        """
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"

        # Load model if needed
        if self._model is None:
            self._get_model()

        # Build the prompt
        prompt_text: str = ""
        images: List[Image.Image] = []
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
                images.append(open_image(media_object.location))
                # Each image is represented in the text stream by a sentinel token.
                prompt_text += self.IMAGE_TOKEN
            elif media_object.is_type(TEXT_TYPE):
                if media_object.text is None:
                    raise ValueError("MediaObject of text type has missing text field value")
                prompt_text += media_object.text
            else:
                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")

        # Preprocess
        # Stack images into shape expected by the model: (b=1, T_img, F=1, C, H, W).
        vision_x: torch.Tensor = torch.cat([self.image_processor(image).unsqueeze(0) for image in images], dim=0)
        vision_x = vision_x.unsqueeze(1).unsqueeze(0)
        lang_x = self.tokenizer([prompt_text], return_tensors="pt")

        # Generate
        try:
            generation_args = {
                "max_new_tokens": request.max_tokens,
                "n": request.num_completions,
            }

            def do_it():
                # num_beams == num_return_sequences == n: one beam per completion.
                tensors = self._model.generate(
                    vision_x=vision_x.to(self._device),
                    lang_x=lang_x["input_ids"].to(self._device),
                    attention_mask=lang_x["attention_mask"].to(self._device),
                    max_new_tokens=generation_args["max_new_tokens"],
                    num_beams=generation_args["n"],
                    num_return_sequences=generation_args["n"],
                )
                generated_completions: List[Tuple[str, List[str]]] = []
                for tensor in tensors:
                    generated_text: str = self.tokenizer.decode(tensor)
                    raw_tokens: List[str] = self.tokenizer.tokenize(generated_text)
                    generated_completions.append((generated_text, raw_tokens))

                return {"output": generated_completions}

            cache_key = CachingClient.make_cache_key(
                raw_request={
                    "model": request.model,
                    "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
                    **generation_args,
                },
                request=request,
            )
            result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        except RuntimeError as ex:
            return RequestResult(success=False, cached=False, error=str(ex), completions=[], embedding=[])

        completions: List[GeneratedOutput] = []
        for text, tokens in result["output"]:
            # Remove the prompt from the generated text
            # NOTE(review): when the decoded text is shorter than the prompt this
            # keeps only the LAST CHARACTER (text[-1]) — looks suspicious; confirm
            # whether the intent was to keep the full text instead.
            text = (
                text[len(prompt_text) :].replace(self.END_OF_CHUNK_TOKEN, "").strip()
                if len(text) >= len(prompt_text)
                else text[-1]
            )
            completions.append(
                GeneratedOutput(text=text, logprob=0, tokens=[Token(text=token, logprob=0) for token in tokens])
            )

        return RequestResult(
            success=True,
            cached=cached,
            request_time=result["request_time"],
            completions=completions,
            embedding=[],
        )
@@ -0,0 +1,171 @@
1
+ from threading import Lock
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from dataclasses import dataclass
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ from transformers.generation import GenerationConfig
7
+
8
+ from helm.common.cache import CacheConfig
9
+ from helm.common.gpu_utils import get_torch_device_name
10
+ from helm.common.hierarchical_logger import hlog, htrack_block
11
+ from helm.common.media_object import TEXT_TYPE
12
+ from helm.common.request import Request, RequestResult, GeneratedOutput, Token
13
+ from helm.common.request import wrap_request_time
14
+ from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
15
+
16
+
17
@dataclass(frozen=True)
class LoadedQwenModelProcessor:
    """Loaded model and processor for Qwen."""

    # Hugging Face causal LM (Qwen-VL or Qwen-VL-Chat), loaded with
    # trust_remote_code and already moved to the target device.
    model: AutoModelForCausalLM
    # Matching tokenizer for the same checkpoint.
    tokenizer: AutoTokenizer
+
24
+
25
# Guards lazy, one-time model loading across threads (see QwenVLMClient._get_model).
_models_lock: Lock = Lock()
# Process-wide cache of loaded models keyed by Hugging Face model name;
# each entry stays None until the model is first requested.
_models: Dict[str, Optional[LoadedQwenModelProcessor]] = {
    "Qwen/Qwen-VL": None,
    "Qwen/Qwen-VL-Chat": None,
}
+ }
30
+
31
+
32
class QwenVLMClient(CachingClient):
    """
    From https://huggingface.co/Qwen/Qwen-VL,
    Qwen-VL (Qwen Large Vision Language Model) is the visual multimodal version of the large model series,
    Qwen (abbr. Tongyi Qianwen), proposed by Alibaba Cloud. Qwen-VL accepts image, text, and bounding box
    as inputs, outputs text and bounding box.
    Alibaba released Qwen-VL and Qwen-VL-Chat, which is a chatbot model based on Qwen-VL.

    Paper: https://arxiv.org/abs/2308.12966
    """

    END_OF_TEXT_TOKEN: str = "<|endoftext|>"

    def __init__(self, cache_config: CacheConfig):
        super().__init__(cache_config=cache_config)
        self._device: str = get_torch_device_name()

    def _get_model(self, helm_model_name: str) -> LoadedQwenModelProcessor:
        """Return the (cached) model+tokenizer for a HELM engine name, loading it on first use.

        Raises:
            ValueError: for engine names other than "qwen-vl" / "qwen-vl-chat".
        """
        global _models_lock
        global _models

        # Map HELM engine name to the Hugging Face repository name.
        model_name: str
        if helm_model_name == "qwen-vl-chat":
            model_name = "Qwen/Qwen-VL-Chat"
        elif helm_model_name == "qwen-vl":
            model_name = "Qwen/Qwen-VL"
        else:
            raise ValueError(f"Unhandled model name: {helm_model_name}")

        # Ensure that only one thread is loading the model at a time
        with _models_lock:
            loaded_model_processor = _models[model_name]
            if loaded_model_processor is None:
                hlog(f"Loading model {model_name} and caching in memory...")
                model = AutoModelForCausalLM.from_pretrained(
                    model_name, device_map=self._device, trust_remote_code=True, bf16=True
                ).eval()
                if model_name == "Qwen/Qwen-VL-Chat":
                    # Chat variant ships its own generation config (sampling params etc.).
                    model.generation_config = GenerationConfig.from_pretrained(model_name, trust_remote_code=True)
                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                _models[model_name] = LoadedQwenModelProcessor(model, tokenizer)
                loaded_model_processor = _models[model_name]

        assert loaded_model_processor is not None
        return loaded_model_processor

    def make_request(self, request: Request) -> RequestResult:
        """Run (cached) multimodal generation, one model call per requested completion."""
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"

        loaded_model_processor: LoadedQwenModelProcessor = self._get_model(request.model_engine)
        model = loaded_model_processor.model
        tokenizer = loaded_model_processor.tokenizer

        generation_args = {
            "max_length": request.max_tokens,
        }

        # `query` is the Qwen list-format input; `prompt_text` mirrors it as a flat
        # string, used for cache keys/logging and for prompt-stripping below.
        query: List[Dict[str, str]] = []
        prompt_text: str = ""

        image_index: int = 1
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
                query.append({"image": media_object.location})
                # Qwen's expected textual representation of an inline image.
                prompt_text += f"Picture {image_index}: <img>{media_object.location}</img>\n"
                image_index += 1
            elif media_object.is_type(TEXT_TYPE):
                if media_object.text is None:
                    raise ValueError("MediaObject of text type has missing text field value")

                query.append({"text": media_object.text})
                prompt_text += media_object.text
            else:
                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")

        completions: List[GeneratedOutput] = []
        request_time: float = 0
        request_datetime: Optional[int] = None
        all_cached: bool = True

        with htrack_block(f"Generating for prompt: {prompt_text}"):
            for completion_index in range(request.num_completions):
                try:

                    def do_it() -> Dict[str, Any]:
                        # Chat variant uses the model.chat() helper; base variant
                        # goes through tokenize -> generate -> decode.
                        if request.model_engine == "qwen-vl-chat":
                            completion, _ = model.chat(tokenizer, query=tokenizer.from_list_format(query), history=None)
                        else:
                            inputs = tokenizer(tokenizer.from_list_format(query), return_tensors="pt")
                            inputs = inputs.to(self._device)
                            pred = model.generate(**inputs, **generation_args)
                            completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)

                        tokens: List[str] = tokenizer.tokenize(completion)
                        return {"output": (completion, tokens)}

                    # Include the prompt and model name in the cache key
                    cache_key = CachingClient.make_cache_key(
                        raw_request={
                            "completion_index": completion_index,
                            "model": request.model,
                            "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
                            **generation_args,
                        },
                        request=request,
                    )
                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
                except RuntimeError as model_error:
                    return RequestResult(
                        success=False, cached=False, error=str(model_error), completions=[], embedding=[]
                    )

                text, tokens = result["output"]

                # Truncate the output text as the original Qwen includes the prompt in the output sequence
                if request.model_engine == "qwen-vl":
                    text = text[len(prompt_text) :]
                    text = text.replace(self.END_OF_TEXT_TOKEN, "")
                    hlog(f"Truncated: {text}")

                # Tokenize truncated text to get the list of tokens
                completions.append(
                    GeneratedOutput(
                        text=text, logprob=0, tokens=[Token(text=str(token), logprob=0) for token in tokens]
                    )
                )

                request_time += result["request_time"]
                # Use the datetime from the first completion because that's when the request was fired
                request_datetime = request_datetime or result.get("request_datetime")
                all_cached = all_cached and cached

        return RequestResult(
            success=True,
            cached=all_cached,
            request_time=request_time,
            request_datetime=request_datetime,
            completions=completions,
            embedding=[],
        )
@@ -0,0 +1,46 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ from helm.common.cache import CacheConfig
4
+ from helm.common.request import Request
5
+ from helm.clients.openai_client import OpenAIClient
6
+ from helm.tokenizers.tokenizer import Tokenizer
7
+
8
+
9
class VLLMClient(OpenAIClient):
    """Sends request to a vLLM server using the OpenAI-compatible API.

    See: https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server"""

    def __init__(
        self,
        tokenizer: Tokenizer,
        tokenizer_name: str,
        cache_config: CacheConfig,
        base_url: Optional[str] = None,
    ):
        # vLLM's OpenAI-compatible server does not check credentials,
        # so a placeholder API key and no organization are used.
        super().__init__(
            tokenizer=tokenizer,
            tokenizer_name=tokenizer_name,
            cache_config=cache_config,
            api_key="EMPTY",
            org_id=None,
            base_url=base_url,
        )
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name

    def _is_chat_model_engine(self, model_engine: str) -> bool:
        """Only vLLM completion models are supported for now, so never treat as chat."""
        return False

    def _get_model_for_request(self, request: Request) -> str:
        """vLLM expects the full model name including the creator organization,
        unlike OpenAI which only uses the model engine."""
        return request.model

    def _to_raw_completion_request(self, request: Request) -> Dict[str, Any]:
        raw_request = super()._to_raw_completion_request(request)
        # vLLM errors with "best_of must be 1 when using greedy sampling",
        # so clamp any larger value down to 1.
        if raw_request.get("best_of", 1) > 1:
            raw_request["best_of"] = 1
        return raw_request