crfm-helm 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of crfm-helm was flagged as potentially problematic.

Files changed (499)
  1. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +138 -31
  2. crfm_helm-0.5.1.dist-info/RECORD +654 -0
  3. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +31 -3
  5. helm/benchmark/adaptation/adapters/adapter.py +2 -2
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
  9. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
  10. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  11. helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
  12. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  13. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
  14. helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
  15. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  16. helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
  17. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
  18. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
  19. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  20. helm/benchmark/adaptation/request_state.py +6 -1
  21. helm/benchmark/adaptation/scenario_state.py +6 -2
  22. helm/benchmark/annotation/annotator.py +43 -0
  23. helm/benchmark/annotation/annotator_factory.py +61 -0
  24. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  25. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  26. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  27. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  28. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  29. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  30. helm/benchmark/annotation_executor.py +124 -0
  31. helm/benchmark/augmentations/data_augmenter.py +0 -2
  32. helm/benchmark/augmentations/gender_perturbation.py +1 -1
  33. helm/benchmark/augmentations/perturbation.py +25 -3
  34. helm/benchmark/augmentations/perturbation_description.py +1 -1
  35. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  36. helm/benchmark/augmentations/test_perturbation.py +41 -7
  37. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  38. helm/benchmark/config_registry.py +7 -1
  39. helm/benchmark/executor.py +46 -16
  40. helm/benchmark/huggingface_registration.py +20 -7
  41. helm/benchmark/metrics/basic_metrics.py +169 -664
  42. helm/benchmark/metrics/bbq_metrics.py +3 -4
  43. helm/benchmark/metrics/bias_metrics.py +6 -6
  44. helm/benchmark/metrics/classification_metrics.py +11 -8
  45. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  46. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  47. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  48. helm/benchmark/metrics/common_metric_specs.py +167 -0
  49. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  50. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  51. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  52. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  53. helm/benchmark/metrics/disinformation_metrics.py +4 -110
  54. helm/benchmark/metrics/dry_run_metrics.py +2 -2
  55. helm/benchmark/metrics/efficiency_metrics.py +213 -0
  56. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  57. helm/benchmark/metrics/evaluate_reference_metrics.py +392 -0
  58. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  59. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  60. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  61. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  62. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  63. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  64. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  65. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  66. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  67. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  68. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  69. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  70. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  71. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  72. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  73. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  74. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  75. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  76. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  77. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  78. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  79. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  80. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  81. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  82. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  83. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  84. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  85. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  86. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  87. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  88. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  89. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  90. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  91. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  92. helm/benchmark/metrics/machine_translation_metrics.py +89 -0
  93. helm/benchmark/metrics/metric.py +93 -172
  94. helm/benchmark/metrics/metric_name.py +0 -1
  95. helm/benchmark/metrics/metric_service.py +16 -0
  96. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  97. helm/benchmark/metrics/ranking_metrics.py +2 -2
  98. helm/benchmark/metrics/reference_metric.py +148 -0
  99. helm/benchmark/metrics/summac/model_summac.py +0 -2
  100. helm/benchmark/metrics/summarization_metrics.py +2 -2
  101. helm/benchmark/metrics/test_classification_metrics.py +8 -5
  102. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  103. helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
  104. helm/benchmark/metrics/test_metric.py +2 -2
  105. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
  106. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  107. helm/benchmark/metrics/toxicity_utils.py +23 -0
  108. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  109. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  110. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  111. helm/benchmark/metrics/vision_language/image_metrics.py +575 -0
  112. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  113. helm/benchmark/model_deployment_registry.py +74 -0
  114. helm/benchmark/model_metadata_registry.py +41 -1
  115. helm/benchmark/multi_gpu_runner.py +133 -0
  116. helm/benchmark/presentation/create_plots.py +8 -7
  117. helm/benchmark/presentation/run_display.py +26 -10
  118. helm/benchmark/presentation/schema.py +15 -40
  119. helm/benchmark/presentation/summarize.py +119 -79
  120. helm/benchmark/presentation/table.py +8 -8
  121. helm/benchmark/presentation/test_contamination.py +2 -2
  122. helm/benchmark/presentation/test_run_entry.py +1 -2
  123. helm/benchmark/presentation/test_summarize.py +3 -3
  124. helm/benchmark/run.py +54 -26
  125. helm/benchmark/run_expander.py +205 -35
  126. helm/benchmark/run_spec.py +93 -0
  127. helm/benchmark/run_spec_factory.py +163 -0
  128. helm/benchmark/run_specs/__init__.py +0 -0
  129. helm/benchmark/run_specs/classic_run_specs.py +1510 -0
  130. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  131. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  132. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  133. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  134. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  135. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  136. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  137. helm/benchmark/run_specs/vlm_run_specs.py +757 -0
  138. helm/benchmark/runner.py +51 -57
  139. helm/benchmark/runner_config_registry.py +21 -0
  140. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  141. helm/benchmark/scenarios/bold_scenario.py +2 -2
  142. helm/benchmark/scenarios/code_scenario.py +1 -0
  143. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  144. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  145. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  146. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  147. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  148. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  149. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  150. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  151. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  152. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  153. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  154. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  155. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  156. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  157. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  158. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  159. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  160. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  161. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  162. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  163. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  164. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  165. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  166. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  167. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  168. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  169. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  170. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  171. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  172. helm/benchmark/scenarios/legalbench_scenario.py +6 -2
  173. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  174. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  175. helm/benchmark/scenarios/math_scenario.py +19 -2
  176. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  177. helm/benchmark/scenarios/numeracy_scenario.py +1 -1
  178. helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
  179. helm/benchmark/scenarios/scenario.py +4 -0
  180. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  181. helm/benchmark/scenarios/test_math_scenario.py +6 -0
  182. helm/benchmark/scenarios/test_scenario.py +6 -3
  183. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  184. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  185. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  186. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  187. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  188. helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
  189. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  190. helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
  191. helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
  192. helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
  193. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +94 -0
  194. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  195. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  196. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  197. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  198. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  199. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  200. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  201. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  202. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  203. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  204. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  205. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  206. helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
  207. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  208. helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
  209. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  210. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  211. helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
  212. helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
  213. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  214. helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
  215. helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
  216. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  217. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  218. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  219. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +3 -4
  220. helm/benchmark/scenarios/vision_language/vqa_scenario.py +5 -3
  221. helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
  222. helm/benchmark/server.py +24 -1
  223. helm/benchmark/slurm_runner.py +70 -49
  224. helm/benchmark/static/benchmarking.js +1 -1
  225. helm/benchmark/static/schema_classic.yaml +258 -1066
  226. helm/benchmark/static/schema_image2structure.yaml +304 -0
  227. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  228. helm/benchmark/static/schema_lite.yaml +2 -227
  229. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  230. helm/benchmark/static/schema_unitxt.yaml +428 -0
  231. helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
  232. helm/benchmark/static/schema_vlm.yaml +823 -0
  233. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  234. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  235. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  236. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  237. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  238. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  239. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  240. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  241. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  242. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  243. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  244. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  245. helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
  246. helm/benchmark/static_build/assets/index-878a1094.css +1 -0
  247. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  248. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  249. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  250. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  251. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  252. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  253. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  254. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  255. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  256. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  257. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  258. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  259. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  260. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  261. helm/benchmark/static_build/config.js +4 -0
  262. helm/benchmark/static_build/index.html +20 -0
  263. helm/benchmark/test_data_preprocessor.py +3 -3
  264. helm/benchmark/test_run_expander.py +1 -1
  265. helm/benchmark/window_services/ai21_window_service.py +22 -33
  266. helm/benchmark/window_services/cohere_window_service.py +1 -63
  267. helm/benchmark/window_services/default_window_service.py +2 -44
  268. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  269. helm/benchmark/window_services/ice_window_service.py +0 -34
  270. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  271. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  272. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  273. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  274. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  275. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  276. helm/benchmark/window_services/local_window_service.py +21 -4
  277. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  278. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  279. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  280. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  281. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  282. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  283. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  284. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  285. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  286. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  287. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  288. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  289. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  290. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  291. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  292. helm/benchmark/window_services/test_utils.py +3 -2
  293. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  294. helm/benchmark/window_services/window_service.py +42 -0
  295. helm/benchmark/window_services/window_service_factory.py +4 -1
  296. helm/benchmark/window_services/yalm_window_service.py +0 -27
  297. helm/clients/__init__.py +0 -0
  298. helm/{proxy/clients → clients}/ai21_client.py +3 -9
  299. helm/clients/aleph_alpha_client.py +112 -0
  300. helm/{proxy/clients → clients}/anthropic_client.py +233 -18
  301. helm/{proxy/clients → clients}/auto_client.py +59 -31
  302. helm/clients/bedrock_client.py +128 -0
  303. helm/clients/bedrock_utils.py +72 -0
  304. helm/{proxy/clients → clients}/client.py +65 -7
  305. helm/clients/clip_score_client.py +49 -0
  306. helm/clients/clip_scorers/__init__.py +0 -0
  307. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  308. helm/clients/clip_scorers/clip_scorer.py +50 -0
  309. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  310. helm/{proxy/clients → clients}/cohere_client.py +4 -11
  311. helm/clients/gcs_client.py +82 -0
  312. helm/{proxy/clients → clients}/google_client.py +5 -5
  313. helm/clients/google_translate_client.py +35 -0
  314. helm/{proxy/clients → clients}/http_model_client.py +5 -7
  315. helm/{proxy/clients → clients}/huggingface_client.py +43 -64
  316. helm/clients/image_generation/__init__.py +0 -0
  317. helm/clients/image_generation/adobe_vision_client.py +78 -0
  318. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  319. helm/clients/image_generation/cogview2/__init__.py +0 -0
  320. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  321. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  322. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  323. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  324. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  325. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  326. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  327. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  328. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  329. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  330. helm/clients/image_generation/cogview2_client.py +191 -0
  331. helm/clients/image_generation/dalle2_client.py +192 -0
  332. helm/clients/image_generation/dalle3_client.py +108 -0
  333. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  334. helm/clients/image_generation/dalle_mini/data.py +442 -0
  335. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  336. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  337. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  338. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  339. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  340. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  341. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  342. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  343. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  344. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  345. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  346. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  347. helm/clients/image_generation/dalle_mini_client.py +190 -0
  348. helm/clients/image_generation/deep_floyd_client.py +78 -0
  349. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  350. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  351. helm/clients/image_generation/lexica_client.py +86 -0
  352. helm/clients/image_generation/mindalle/__init__.py +0 -0
  353. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  354. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  355. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  356. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  357. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  358. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  359. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  360. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  361. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  362. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  363. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  364. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  365. helm/clients/image_generation/mindalle_client.py +115 -0
  366. helm/clients/image_generation/nudity_check_client.py +64 -0
  367. helm/clients/image_generation/together_image_generation_client.py +111 -0
  368. helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
  369. helm/{proxy/clients → clients}/megatron_client.py +5 -5
  370. helm/clients/mistral_client.py +134 -0
  371. helm/clients/moderation_api_client.py +109 -0
  372. helm/clients/open_lm_client.py +43 -0
  373. helm/clients/openai_client.py +301 -0
  374. helm/{proxy/clients → clients}/palmyra_client.py +6 -8
  375. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  376. helm/clients/simple_client.py +64 -0
  377. helm/{proxy/clients → clients}/test_auto_client.py +13 -15
  378. helm/clients/test_client.py +100 -0
  379. helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
  380. helm/clients/test_simple_client.py +19 -0
  381. helm/{proxy/clients → clients}/test_together_client.py +20 -8
  382. helm/{proxy/clients → clients}/together_client.py +104 -73
  383. helm/clients/vertexai_client.py +400 -0
  384. helm/clients/vision_language/__init__.py +0 -0
  385. helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
  386. helm/clients/vision_language/huggingface_vlm_client.py +111 -0
  387. helm/{proxy/clients → clients}/vision_language/idefics_client.py +54 -49
  388. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  389. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  390. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  391. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  392. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  393. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  394. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  395. helm/clients/vision_language/open_flamingo_client.py +155 -0
  396. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  397. helm/clients/vllm_client.py +46 -0
  398. helm/common/cache.py +16 -4
  399. helm/common/cache_backend_config.py +47 -0
  400. helm/common/clip_score_request.py +41 -0
  401. helm/common/file_caches/__init__.py +0 -0
  402. helm/common/file_caches/file_cache.py +16 -0
  403. helm/common/file_caches/local_file_cache.py +61 -0
  404. helm/common/file_caches/test_local_file_cache.py +25 -0
  405. helm/common/file_upload_request.py +27 -0
  406. helm/common/general.py +1 -1
  407. helm/common/image_generation_parameters.py +25 -0
  408. helm/common/images_utils.py +33 -3
  409. helm/common/key_value_store.py +35 -4
  410. helm/common/media_object.py +13 -0
  411. helm/common/moderations_api_request.py +71 -0
  412. helm/common/mongo_key_value_store.py +3 -3
  413. helm/common/multimodal_request_utils.py +31 -0
  414. helm/common/nudity_check_request.py +29 -0
  415. helm/common/request.py +15 -17
  416. helm/common/test_general.py +6 -0
  417. helm/common/tokenization_request.py +1 -1
  418. helm/config/model_deployments.yaml +1159 -538
  419. helm/config/model_metadata.yaml +868 -41
  420. helm/config/tokenizer_configs.yaml +149 -43
  421. helm/proxy/accounts.py +31 -4
  422. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  423. helm/proxy/critique/model_critique_client.py +8 -6
  424. helm/proxy/example_queries.py +29 -17
  425. helm/proxy/server.py +70 -5
  426. helm/proxy/services/remote_service.py +31 -0
  427. helm/proxy/services/server_service.py +96 -16
  428. helm/proxy/services/service.py +30 -0
  429. helm/proxy/services/test_remote_service.py +4 -3
  430. helm/proxy/services/test_service.py +0 -12
  431. helm/proxy/test_accounts.py +32 -0
  432. helm/proxy/token_counters/auto_token_counter.py +37 -37
  433. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  434. helm/proxy/token_counters/token_counter.py +3 -5
  435. helm/tokenizers/__init__.py +0 -0
  436. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  437. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
  438. helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
  439. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  440. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  441. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
  442. helm/tokenizers/simple_tokenizer.py +33 -0
  443. helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
  444. helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
  445. helm/tokenizers/test_simple_tokenizer.py +33 -0
  446. helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
  447. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  448. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  449. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  450. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  451. crfm_helm-0.4.0.dist-info/RECORD +0 -397
  452. helm/benchmark/run_specs.py +0 -2762
  453. helm/benchmark/test_model_deployment_definition.py +0 -92
  454. helm/benchmark/test_model_properties.py +0 -1570
  455. helm/benchmark/vlm_run_specs.py +0 -97
  456. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  457. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  458. helm/benchmark/window_services/huggingface_window_service.py +0 -60
  459. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  460. helm/benchmark/window_services/t511b_window_service.py +0 -30
  461. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  462. helm/benchmark/window_services/ul2_window_service.py +0 -30
  463. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  464. helm/common/cache_utils.py +0 -14
  465. helm/proxy/clients/aleph_alpha_client.py +0 -95
  466. helm/proxy/clients/goose_ai_client.py +0 -99
  467. helm/proxy/clients/microsoft_client.py +0 -180
  468. helm/proxy/clients/openai_client.py +0 -206
  469. helm/proxy/clients/simple_client.py +0 -60
  470. helm/proxy/clients/test_client.py +0 -49
  471. helm/proxy/clients/vertexai_client.py +0 -115
  472. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  473. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  474. helm/proxy/token_counters/free_token_counter.py +0 -12
  475. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  476. helm/proxy/token_counters/openai_token_counter.py +0 -22
  477. helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
  478. helm/proxy/token_counters/test_openai_token_counter.py +0 -81
  479. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  480. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
  481. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
  482. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
  483. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  484. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  485. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  486. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  487. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  488. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  489. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  490. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  491. /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
  492. /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
  493. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  494. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  495. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  496. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  497. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  498. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  499. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
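
Two package-level moves dominate this file list: helm/proxy/clients → helm/clients and helm/proxy/tokenizers → helm/tokenizers. As a minimal sketch of the migration for downstream code (the module names are taken from the entries above; the class names follow HELM's usual naming and should be verified against 0.5.1):

    # 0.4.0: client and tokenizer modules lived under helm.proxy
    # from helm.proxy.clients.together_client import TogetherClient
    # from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer

    # 0.5.1: they are top-level packages
    from helm.clients.together_client import TogetherClient
    from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer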
helm/{proxy/clients → clients}/vision_language/idefics_client.py

@@ -1,5 +1,5 @@
 from threading import Lock
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from dataclasses import dataclass
@@ -8,17 +8,14 @@ from transformers import IdeficsForVisionText2Text, AutoProcessor, IdeficsProces
 from helm.common.cache import CacheConfig
 from helm.common.images_utils import open_image
 from helm.common.gpu_utils import get_torch_device_name
-from helm.common.hierarchical_logger import hlog
+from helm.common.hierarchical_logger import hlog, htrack_block
 from helm.common.media_object import TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
-from helm.common.request import Request, RequestResult, Sequence, Token
-from helm.common.tokenization_request import (
-    TokenizationRequest,
-    TokenizationRequestResult,
-)
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.common.tokenization_request import TokenizationRequest
 from helm.common.request import wrap_request_time
-from helm.proxy.clients.client import CachingClient, generate_uid_for_multimodal_prompt
-from helm.proxy.tokenizers.tokenizer import Tokenizer
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+from helm.tokenizers.tokenizer import Tokenizer
 
 try:
     from PIL import Image
@@ -54,6 +51,8 @@ class IDEFICSClient(CachingClient):
     END_OF_UTTERANCE_TOKEN: str = "<end_of_utterance>"
     BAD_WORD_TOKENS: List[str] = ["<image>", "<fake_token_around_image>"]
 
+    ASSISTANT_PREFIX: str = "Assistant: "
+
     def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
         super().__init__(cache_config=cache_config)
         self.tokenizer = tokenizer
@@ -69,8 +68,8 @@ class IDEFICSClient(CachingClient):
         loaded_model_processor = _models[checkpoint]
         if loaded_model_processor is None:
             hlog(f"Loading model {checkpoint} and caching in memory...")
-            model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(
-                self._device
+            model = IdeficsForVisionText2Text.from_pretrained(
+                checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
             )
             processor = AutoProcessor.from_pretrained(checkpoint)
             _models[checkpoint] = LoadedIDEFICSModelProcessor(model, processor)
@@ -89,7 +88,7 @@ class IDEFICSClient(CachingClient):
 
         input_args: Dict[str, Union[str, bool]] = {"return_tensors": "pt"}
         generation_args = {
-            "max_length": request.max_tokens,
+            "max_new_tokens": request.max_tokens,
             "bad_words_ids": processor.tokenizer(self.BAD_WORD_TOKENS, add_special_tokens=False).input_ids,
         }
 
@@ -112,43 +111,49 @@
                 raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
         prompt_text: str = request.multimodal_prompt.text.replace(self.END_OF_UTTERANCE_TOKEN, " ")
 
-        try:
-
-            def do_it():
-                inputs = processor(multimodal_prompt, **input_args).to(self._device)
-                generated_ids = model.generate(**inputs, **generation_args)
-                generated_text: str = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-                assert generated_text.startswith(
-                    prompt_text
-                ), f"Generated text: {generated_text} does not start with prompt: {prompt_text}"
-
-                # Remove the prompt from the generated text
-                generated_text = generated_text[len(prompt_text) :].strip()
-                return {"output": generated_text}
-
-            # Include the prompt and model name in the cache key
-            cache_key = CachingClient.make_cache_key(
-                raw_request={
-                    "model": request.model,
-                    "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
-                    **generation_args,
-                },
-                request=request,
-            )
-            result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except RuntimeError as e:
-            return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])
-
-        # TODO: Support multiple completions and figure out how get the log probs
-        # TODO: Does it make sense to support echo? Include these params in the cache key.
-        # TODO: Together might support this model so use the TogetherClient
-        tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
-            TokenizationRequest(result["output"], tokenizer=self.tokenizer_name)
-        )
-        tokens: List[Token] = [
-            Token(text=str(text), logprob=0, top_logprobs={}) for text in tokenization_result.raw_tokens
-        ]
-        completions: List[Sequence] = [Sequence(text=result["output"], logprob=0, tokens=tokens)]
+        completions: List[GeneratedOutput] = []
+        with htrack_block(f"Generating for prompt: {prompt_text}"):
+            try:
+
+                def do_it() -> Dict[str, Any]:
+                    inputs = processor([multimodal_prompt] * request.num_completions, **input_args).to(self._device)
+                    generated_ids = model.generate(**inputs, **generation_args)
+                    generated_text: List[str] = processor.batch_decode(generated_ids, skip_special_tokens=True)
+                    return {"output": generated_text}
+
+                # Include the prompt and model name in the cache key
+                cache_key = CachingClient.make_cache_key(
+                    raw_request={
+                        "n": request.num_completions,
+                        "model": request.model,
+                        "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                        **generation_args,
+                    },
+                    request=request,
+                )
+                result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            except RuntimeError as model_error:
+                return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
+
+            for text in result["output"]:
+                hlog(f"Generated text: {text}")
+
+                # Truncate the output text as IDEFICS outputs the entire sequence including the prompt
+                if "instruct" in request.model:
+                    assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output: {text}"
+                    text = text.rpartition(self.ASSISTANT_PREFIX)[-1]
+                else:
+                    # Best we can do is to remove the text portion of the prompt from the output
+                    text = text[len(prompt_text) :]
+
+                # Tokenize truncated text to get the list of tokens
+                hlog(f"Truncated: {text}")
+                tokenization_result = self.tokenizer.tokenize(
+                    TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
+                )
+                tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
+                completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
+
         return RequestResult(
             success=True,
             cached=cached,
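
Besides the helm.proxy.clients → helm.clients move, the hunk above records two API renames in helm.common.request between 0.4.0 and 0.5.1: completions are now GeneratedOutput objects rather than Sequence, and Token no longer carries a top_logprobs field. A minimal sketch of constructing a completion against the new types, using only the keyword arguments visible in the diff:

    from helm.common.request import GeneratedOutput, Token

    # Token(text=..., logprob=...) replaces Token(text=..., logprob=..., top_logprobs=...);
    # GeneratedOutput replaces Sequence and keeps the same text/logprob/tokens fields.
    tokens = [Token(text="Hello", logprob=0), Token(text="!", logprob=0)]
    completion = GeneratedOutput(text="Hello!", logprob=0, tokens=tokens)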
helm/clients/vision_language/open_flamingo/__init__.py (new file)

@@ -0,0 +1,2 @@
+from .src.flamingo import Flamingo
+from .src.factory import create_model_and_transforms
helm/clients/vision_language/open_flamingo/src/factory.py (new file)

@@ -0,0 +1,147 @@
+"""
+Source: https://github.com/mlfoundations/open_flamingo
+"""
+
+from typing import Optional
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from helm.common.general import handle_module_not_found_error
+from .flamingo import Flamingo
+from .flamingo_lm import FlamingoLMMixin
+from .utils import extend_instance
+
+
+def create_model_and_transforms(
+    clip_vision_encoder_path: str,
+    clip_vision_encoder_pretrained: str,
+    lang_encoder_path: str,
+    tokenizer_path: str,
+    cross_attn_every_n_layers: int = 1,
+    use_local_files: bool = False,
+    decoder_layers_attr_name: str = None,
+    freeze_lm_embeddings: bool = False,
+    cache_dir: Optional[str] = None,
+    **flamingo_kwargs,
+):
+    """
+    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
+    Appends special tokens to the tokenizer and freezes backbones.
+
+    Args:
+        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
+        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
+        lang_encoder_path (str): path to pretrained language encoder
+        tokenizer_path (str): path to pretrained tokenizer
+        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
+        use_local_files (bool, optional): whether to use local files. Defaults to False.
+        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+        freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
+        cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
+    Returns:
+        Flamingo: Flamingo model from pretrained vision and language encoders
+        Image processor: Pipeline to preprocess input images
+        Tokenizer: A tokenizer for the language model
+    """
+    try:
+        import open_clip
+    except ModuleNotFoundError as e:
+        handle_module_not_found_error(e, ["vlm"])
+
+    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
+        clip_vision_encoder_path,
+        pretrained=clip_vision_encoder_pretrained,
+        cache_dir=cache_dir,
+    )
+    # set the vision encoder to output the visual features
+    vision_encoder.visual.output_tokens = True
+
+    text_tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_path,
+        local_files_only=use_local_files,
+        trust_remote_code=True,
+        cache_dir=cache_dir,
+    )
+    # add Flamingo special tokens to the tokenizer
+    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
+    if text_tokenizer.pad_token is None:
+        # Issue: GPT models don't have a pad token, which we use to
+        # modify labels for the loss.
+        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+
+    lang_encoder = AutoModelForCausalLM.from_pretrained(
+        lang_encoder_path,
+        local_files_only=use_local_files,
+        trust_remote_code=True,
+        cache_dir=cache_dir,
+    )
+
+    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
+    if "mpt-1b-redpajama-200b" in lang_encoder_path:
+
+        class EmbeddingFnMixin:
+            def get_input_embeddings(self):
+                return self.transformer.wte
+
+            def set_input_embeddings(self, new_embeddings):
+                self.transformer.wte = new_embeddings
+
+        extend_instance(lang_encoder, EmbeddingFnMixin)
+
+    # convert LM to FlamingoLM
+    extend_instance(lang_encoder, FlamingoLMMixin)
+
+    if decoder_layers_attr_name is None:
+        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
+    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
+    lang_encoder.resize_token_embeddings(len(text_tokenizer))
+
+    model = Flamingo(
+        vision_encoder,
+        lang_encoder,
+        text_tokenizer.encode("<|endofchunk|>")[-1],
+        text_tokenizer.encode("<image>")[-1],
+        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
+        cross_attn_every_n_layers=cross_attn_every_n_layers,
+        **flamingo_kwargs,
+    )
+
+    # Freeze all parameters
+    model.requires_grad_(False)
+    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
+
+    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
+    model.perceiver.requires_grad_(True)
+    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
+    if not freeze_lm_embeddings:
+        model.lang_encoder.get_input_embeddings().requires_grad_(True)
+        # TODO: investigate also training the output embeddings when untied
+
+    print(
+        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
+    )
+
+    return model, image_processor, text_tokenizer
+
+
+def _infer_decoder_layers_attr_name(model):
+    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
+        if k.lower() in model.__class__.__name__.lower():
+            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
+
+    raise ValueError(
+        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. "
+        "Please supply this string manually."
+    )
+
+
+__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
+    "opt": "model.decoder.layers",
+    "gptj": "transformer.h",
+    "gpt-j": "transformer.h",
+    "pythia": "gpt_neox.layers",
+    "llama": "model.layers",
+    "gptneoxforcausallm": "gpt_neox.layers",
+    "mpt": "transformer.blocks",
+    "mosaicgpt": "transformer.blocks",
+}
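
For orientation, the sketch below calls create_model_and_transforms as defined above. The CLIP encoder pair ("ViT-B-32" / "laion2b_s32b_b79k") comes from the function's own docstring examples; the language-model and tokenizer paths are assumed Hugging Face checkpoint names, not something this diff prescribes:

    from helm.clients.vision_language.open_flamingo import create_model_and_transforms

    # "mosaicml/mpt-1b-redpajama-200b" is an assumed checkpoint name; note that the
    # factory special-cases paths containing "mpt-1b-redpajama-200b" (see above).
    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-B-32",
        clip_vision_encoder_pretrained="laion2b_s32b_b79k",
        lang_encoder_path="mosaicml/mpt-1b-redpajama-200b",
        tokenizer_path="mosaicml/mpt-1b-redpajama-200b",
        cross_attn_every_n_layers=1,
    )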
helm/clients/vision_language/open_flamingo/src/flamingo.py (new file)

@@ -0,0 +1,337 @@
+"""
+Source: https://github.com/mlfoundations/open_flamingo
+"""
+
+import torch
+from einops import rearrange
+from torch import nn
+from .helpers import PerceiverResampler
+from torch.distributed.fsdp.wrap import (
+    enable_wrap,
+    wrap,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+
+from .utils import apply_with_stopping_condition
+
+
+class Flamingo(nn.Module):
+    def __init__(
+        self,
+        vision_encoder: nn.Module,
+        lang_encoder: nn.Module,
+        eoc_token_id: int,
+        media_token_id: int,
+        vis_dim: int,
+        cross_attn_every_n_layers: int = 1,
+        gradient_checkpointing: bool = False,
+    ):
+        """
+        Args:
+            vision_encoder (nn.Module): HF CLIPModel
+            lang_encoder (nn.Module): HF causal language model
+            eoc_token_id (int): Token id for <|endofchunk|>
+            media_token_id (int): Token id for <image>
+            vis_dim (int): Dimension of the visual features.
+                Visual features are projected to match this shape along the last dimension.
+            cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
+        """
+        super().__init__()
+        self.eoc_token_id = eoc_token_id
+        self.media_token_id = media_token_id
+        self.vis_dim = vis_dim
+        if hasattr(lang_encoder.config, "d_model"):
+            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
+        else:
+            self.lang_dim = lang_encoder.config.hidden_size
+
+        self.vision_encoder = vision_encoder.visual
+        self.perceiver = PerceiverResampler(dim=self.vis_dim)
+        self.lang_encoder = lang_encoder
+        self.lang_encoder.init_flamingo(
+            media_token_id=media_token_id,
+            lang_hidden_size=self.lang_dim,
+            vis_hidden_size=self.vis_dim,
+            cross_attn_every_n_layers=cross_attn_every_n_layers,
+            gradient_checkpointing=gradient_checkpointing,
+        )
+        self._use_gradient_checkpointing = gradient_checkpointing
+        self.perceiver._use_gradient_checkpointing = gradient_checkpointing
+
+    def forward(
+        self,
+        vision_x: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        labels: torch.Tensor = None,
+        clear_conditioned_layers: bool = True,
+        past_key_values=None,
+        use_cache: bool = False,
+    ):
+        """
+        Forward pass of Flamingo.
+
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W) with F=1
+            lang_x (torch.Tensor): Language input ids
+                shape (B, T_txt)
+            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+            labels (torch.Tensor, optional): Labels. Defaults to None.
+            clear_conditioned_layers: if True, clear the conditioned layers
+                once the foward pass is completed. Set this to false if the
+                same set of images will be reused in another subsequent
+                forward pass.
+            past_key_values: pre-computed values to pass to language model.
+                See past_key_values documentation in Hugging Face
+                CausalLM models.
+            use_cache: whether to use cached key values. See use_cache
+                documentation in Hugging Face CausalLM models.
+        """
+        assert (
+            self.lang_encoder.initialized_flamingo
+        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
+
+        assert (
+            self.lang_encoder._use_cached_vision_x or vision_x is not None
+        ), "Must provide either vision_x or have precached media using cache_media()."
+
+        if self.lang_encoder._use_cached_vision_x:
+            # Case: use cached; vision_x should be cached and other
+            # vision-related inputs should not be provided.
+            assert (
+                vision_x is None
+            ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
+            assert self.lang_encoder.is_conditioned()
+
+        else:
+            # Case: do not use caching (i.e. this is a standard forward pass);
+            self._encode_vision_x(vision_x=vision_x)
+            self._condition_media_locations(input_ids=lang_x)
+
+        output = self.lang_encoder(
+            input_ids=lang_x,
+            attention_mask=attention_mask,
+            labels=labels,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+        )
+
+        if clear_conditioned_layers:
+            self.lang_encoder.clear_conditioned_layers()
+
+        return output
+
+    def generate(
+        self,
+        vision_x: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        **kwargs,
+    ):
+        """
+        Generate text conditioned on vision and language inputs.
+
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                images in the same chunk are collated along T_img, and frames are collated along F
+                currently only F=1 is supported (single-frame videos)
+            lang_x (torch.Tensor): Language input
+                shape (B, T_txt)
+            **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
+                max_length (int, optional): Maximum length of the output. Defaults to None.
+                attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+                num_beams (int, optional): Number of beams. Defaults to 1.
+                max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
+                temperature (float, optional): Temperature. Defaults to 1.0.
+                top_k (int, optional): Top k. Defaults to 50.
+                top_p (float, optional): Top p. Defaults to 1.0.
+                no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
+                length_penalty (float, optional): Length penalty. Defaults to 1.0.
+                num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
+                do_sample (bool, optional): Do sample. Defaults to False.
+                early_stopping (bool, optional): Early stopping. Defaults to False.
+        Returns:
+            torch.Tensor: lang_x with generated tokens appended to it
+        """
+        num_beams = kwargs.pop("num_beams", 1)
+        if num_beams > 1:
+            vision_x = vision_x.repeat_interleave(num_beams, dim=0)
+
+        self.lang_encoder._use_cached_vision_x = True
+        self._encode_vision_x(vision_x=vision_x)
+
+        eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
+        output = self.lang_encoder.generate(
+            input_ids=lang_x,
+            attention_mask=attention_mask,
+            eos_token_id=eos_token_id,
+            num_beams=num_beams,
+            **kwargs,
+        )
+
+        self.lang_encoder.clear_conditioned_layers()
+        self.lang_encoder._use_cached_vision_x = False
+        return output
+
+    def _encode_vision_x(self, vision_x: torch.Tensor):
+        """
+        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                Images in the same chunk are collated along T_img, and frames are collated along F
+                Currently only F=1 is supported (single-frame videos)
+
+        rearrange code based on https://github.com/dhansmair/flamingo-mini
+        """
+
+        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
+        b, T, F = vision_x.shape[:3]
+        assert F == 1, "Only single frame supported"
+
+        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
+        with torch.no_grad():
+            vision_x = self.vision_encoder(vision_x)[1]
+        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
+        vision_x = self.perceiver(vision_x)
+
+        for layer in self.lang_encoder._get_decoder_layers():
+            layer.condition_vis_x(vision_x)
+
+    def wrap_fsdp(self, wrapper_kwargs, device_id):
+        """
+        Manually wraps submodules for FSDP and move other parameters to device_id.
+
+        Why manually wrap?
+        - all parameters within the FSDP wrapper must have the same requires_grad.
+            We have a mix of frozen and unfrozen parameters.
+        - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors
+            See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344
+
+        The rough wrapping structure is:
+        - FlamingoModel
+            - FSDP(FSDP(vision_encoder))
+            - FSDP(FSDP(perceiver))
+            - lang_encoder
+                - FSDP(FSDP(input_embeddings))
+                - FlamingoLayers
+                    - FSDP(FSDP(gated_cross_attn_layer))
+                    - FSDP(FSDP(decoder_layer))
+                - FSDP(FSDP(output_embeddings))
+                - other parameters
+
+        Known issues:
+        - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied,
+            train with DDP or set the --freeze_lm_embeddings flag to true.
+        - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound.
+            Although the training curves look okay, we found that downstream performance dramatically
+            degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).
+
+        FAQs about our FSDP wrapping strategy:
+        Why double wrap?
+        As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook
+        only free gathered parameters if the module is NOT FSDP root.
+
+        Why unfreeze the decoder_layers?
+        See https://github.com/pytorch/pytorch/issues/95805
+        As of torch==2.0.1, FSDP's _post_backward_hook is only registed if the flat param
+        requires_grad=True. We need the postback to fire to avoid OOM.
+        To effectively freeze the decoder layers, we exclude them from the optimizer.
+
+        What is assumed to be frozen v. unfrozen?
+        We assume that the model is being trained under normal Flamingo settings
+        with these lines being called in factory.py:
+        ```
+        # Freeze all parameters
+        model.requires_grad_(False)
+        assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
+
+        # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
+        model.perceiver.requires_grad_(True)
+        model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
+        [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True)
+        ```
+        """
+        # unfreeze the decoder layers
+        for block in self.lang_encoder.old_decoder_blocks:
+            block.requires_grad_(True)
+
+        # wrap in FSDP
+        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
+            self.perceiver = wrap(wrap(self.perceiver))
+            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
+                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
+            )
+            self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
+                wrap(wrap(layer)) if layer is not None else None for layer in self.lang_encoder.gated_cross_attn_layers
+            )
+            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
+            self.lang_encoder.set_input_embeddings(wrap(wrap(self.lang_encoder.get_input_embeddings())))
+            self.lang_encoder.set_output_embeddings(wrap(wrap(self.lang_encoder.get_output_embeddings())))
+            self.vision_encoder = wrap(wrap(self.vision_encoder))  # frozen
+
+        # manually move non-FSDP managed parameters to device_id
+        # these are all in lang_encoder
+        apply_with_stopping_condition(
+            module=self.lang_encoder,
+            apply_fn=lambda m: m.to(device_id),
+            apply_condition=lambda m: len(list(m.children())) == 0,
+            stopping_condition=lambda m: isinstance(m, FSDP),
+        )
+
+        # exclude the original decoder layers from the optimizer
+        for block in self.lang_encoder.old_decoder_blocks:
+            for p in block.parameters():
+                p.exclude_from_optimizer = True
+
+        # set up clip_grad_norm_ function
+        def clip_grad_norm_(max_norm):
+            self.perceiver.clip_grad_norm_(max_norm)
+            for layer in self.lang_encoder.gated_cross_attn_layers:
+                if layer is not None:
+                    layer.clip_grad_norm_(max_norm)
+            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)
+
+        self.clip_grad_norm_ = clip_grad_norm_
+
+    def _condition_media_locations(self, input_ids: torch.Tensor):
+        """
+        Compute the media token locations from lang_x and condition the language model on these.
+        Args:
+            input_ids (torch.Tensor): Language input
+                shape (B, T_txt)
+        """
+        media_locations = input_ids == self.media_token_id
+
+        for layer in self.lang_encoder._get_decoder_layers():
+            layer.condition_media_locations(media_locations)
+
+    def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
+        """
+        Pre-cache a prompt/sequence of images / text for log-likelihood evaluations.
+        All subsequent calls to forward() will generate attending to the LAST
+        image in vision_x.
+        This is not meant to be used to cache things for generate().
+        Args:
+            input_ids (torch.Tensor): Language input
+                shape (B, T_txt)
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                Images in the same chunk are collated along T_img, and frames are collated along F
+                Currently only F=1 is supported (single-frame videos)
+        """
+        self._encode_vision_x(vision_x=vision_x)
+        self._condition_media_locations(input_ids=input_ids)
+        self.lang_encoder._use_cached_vision_x = True
+
+    def uncache_media(self):
+        """
+        Clear all conditioning.
+        """
+        self.lang_encoder.clear_conditioned_layers()
+        self.lang_encoder._use_cached_vision_x = False
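
Continuing the factory sketch above, a hedged example of Flamingo.generate using the tensor shapes documented in its docstring (vision_x is (B, T_img, F, C, H, W) with F=1). The 224x224 resolution and the prompt text are illustrative assumptions, not values fixed by this diff:

    import torch

    # One image (T_img=1), single frame (F=1); 224x224 is an assumed CLIP input size.
    vision_x = torch.randn(1, 1, 1, 3, 224, 224)
    # The prompt uses the <image> special token that the factory registered.
    lang_inputs = tokenizer(["<image>A photo of"], return_tensors="pt")

    generated = model.generate(
        vision_x=vision_x,
        lang_x=lang_inputs["input_ids"],
        attention_mask=lang_inputs["attention_mask"],
        max_new_tokens=20,
    )
    print(tokenizer.decode(generated[0]))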