crfm-helm 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of crfm-helm has been flagged as potentially problematic.

Files changed (499)
  1. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +138 -31
  2. crfm_helm-0.5.1.dist-info/RECORD +654 -0
  3. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +31 -3
  5. helm/benchmark/adaptation/adapters/adapter.py +2 -2
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
  9. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
  10. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  11. helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
  12. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  13. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
  14. helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
  15. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  16. helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
  17. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
  18. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
  19. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  20. helm/benchmark/adaptation/request_state.py +6 -1
  21. helm/benchmark/adaptation/scenario_state.py +6 -2
  22. helm/benchmark/annotation/annotator.py +43 -0
  23. helm/benchmark/annotation/annotator_factory.py +61 -0
  24. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  25. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  26. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  27. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  28. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  29. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  30. helm/benchmark/annotation_executor.py +124 -0
  31. helm/benchmark/augmentations/data_augmenter.py +0 -2
  32. helm/benchmark/augmentations/gender_perturbation.py +1 -1
  33. helm/benchmark/augmentations/perturbation.py +25 -3
  34. helm/benchmark/augmentations/perturbation_description.py +1 -1
  35. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  36. helm/benchmark/augmentations/test_perturbation.py +41 -7
  37. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  38. helm/benchmark/config_registry.py +7 -1
  39. helm/benchmark/executor.py +46 -16
  40. helm/benchmark/huggingface_registration.py +20 -7
  41. helm/benchmark/metrics/basic_metrics.py +169 -664
  42. helm/benchmark/metrics/bbq_metrics.py +3 -4
  43. helm/benchmark/metrics/bias_metrics.py +6 -6
  44. helm/benchmark/metrics/classification_metrics.py +11 -8
  45. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  46. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  47. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  48. helm/benchmark/metrics/common_metric_specs.py +167 -0
  49. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  50. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  51. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  52. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  53. helm/benchmark/metrics/disinformation_metrics.py +4 -110
  54. helm/benchmark/metrics/dry_run_metrics.py +2 -2
  55. helm/benchmark/metrics/efficiency_metrics.py +213 -0
  56. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  57. helm/benchmark/metrics/evaluate_reference_metrics.py +392 -0
  58. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  59. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  60. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  61. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  62. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  63. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  64. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  65. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  66. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  67. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  68. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  69. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  70. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  71. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  72. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  73. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  74. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  75. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  76. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  77. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  78. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  79. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  80. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  81. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  82. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  83. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  84. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  85. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  86. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  87. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  88. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  89. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  90. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  91. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  92. helm/benchmark/metrics/machine_translation_metrics.py +89 -0
  93. helm/benchmark/metrics/metric.py +93 -172
  94. helm/benchmark/metrics/metric_name.py +0 -1
  95. helm/benchmark/metrics/metric_service.py +16 -0
  96. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  97. helm/benchmark/metrics/ranking_metrics.py +2 -2
  98. helm/benchmark/metrics/reference_metric.py +148 -0
  99. helm/benchmark/metrics/summac/model_summac.py +0 -2
  100. helm/benchmark/metrics/summarization_metrics.py +2 -2
  101. helm/benchmark/metrics/test_classification_metrics.py +8 -5
  102. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  103. helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
  104. helm/benchmark/metrics/test_metric.py +2 -2
  105. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
  106. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  107. helm/benchmark/metrics/toxicity_utils.py +23 -0
  108. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  109. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  110. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  111. helm/benchmark/metrics/vision_language/image_metrics.py +575 -0
  112. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  113. helm/benchmark/model_deployment_registry.py +74 -0
  114. helm/benchmark/model_metadata_registry.py +41 -1
  115. helm/benchmark/multi_gpu_runner.py +133 -0
  116. helm/benchmark/presentation/create_plots.py +8 -7
  117. helm/benchmark/presentation/run_display.py +26 -10
  118. helm/benchmark/presentation/schema.py +15 -40
  119. helm/benchmark/presentation/summarize.py +119 -79
  120. helm/benchmark/presentation/table.py +8 -8
  121. helm/benchmark/presentation/test_contamination.py +2 -2
  122. helm/benchmark/presentation/test_run_entry.py +1 -2
  123. helm/benchmark/presentation/test_summarize.py +3 -3
  124. helm/benchmark/run.py +54 -26
  125. helm/benchmark/run_expander.py +205 -35
  126. helm/benchmark/run_spec.py +93 -0
  127. helm/benchmark/run_spec_factory.py +163 -0
  128. helm/benchmark/run_specs/__init__.py +0 -0
  129. helm/benchmark/run_specs/classic_run_specs.py +1510 -0
  130. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  131. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  132. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  133. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  134. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  135. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  136. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  137. helm/benchmark/run_specs/vlm_run_specs.py +757 -0
  138. helm/benchmark/runner.py +51 -57
  139. helm/benchmark/runner_config_registry.py +21 -0
  140. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  141. helm/benchmark/scenarios/bold_scenario.py +2 -2
  142. helm/benchmark/scenarios/code_scenario.py +1 -0
  143. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  144. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  145. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  146. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  147. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  148. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  149. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  150. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  151. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  152. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  153. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  154. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  155. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  156. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  157. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  158. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  159. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  160. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  161. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  162. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  163. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  164. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  165. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  166. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  167. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  168. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  169. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  170. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  171. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  172. helm/benchmark/scenarios/legalbench_scenario.py +6 -2
  173. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  174. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  175. helm/benchmark/scenarios/math_scenario.py +19 -2
  176. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  177. helm/benchmark/scenarios/numeracy_scenario.py +1 -1
  178. helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
  179. helm/benchmark/scenarios/scenario.py +4 -0
  180. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  181. helm/benchmark/scenarios/test_math_scenario.py +6 -0
  182. helm/benchmark/scenarios/test_scenario.py +6 -3
  183. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  184. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  185. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  186. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  187. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  188. helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
  189. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  190. helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
  191. helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
  192. helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
  193. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +94 -0
  194. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  195. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  196. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  197. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  198. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  199. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  200. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  201. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  202. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  203. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  204. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  205. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  206. helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
  207. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  208. helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
  209. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  210. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  211. helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
  212. helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
  213. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  214. helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
  215. helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
  216. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  217. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  218. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  219. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +3 -4
  220. helm/benchmark/scenarios/vision_language/vqa_scenario.py +5 -3
  221. helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
  222. helm/benchmark/server.py +24 -1
  223. helm/benchmark/slurm_runner.py +70 -49
  224. helm/benchmark/static/benchmarking.js +1 -1
  225. helm/benchmark/static/schema_classic.yaml +258 -1066
  226. helm/benchmark/static/schema_image2structure.yaml +304 -0
  227. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  228. helm/benchmark/static/schema_lite.yaml +2 -227
  229. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  230. helm/benchmark/static/schema_unitxt.yaml +428 -0
  231. helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
  232. helm/benchmark/static/schema_vlm.yaml +823 -0
  233. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  234. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  235. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  236. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  237. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  238. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  239. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  240. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  241. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  242. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  243. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  244. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  245. helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
  246. helm/benchmark/static_build/assets/index-878a1094.css +1 -0
  247. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  248. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  249. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  250. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  251. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  252. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  253. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  254. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  255. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  256. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  257. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  258. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  259. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  260. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  261. helm/benchmark/static_build/config.js +4 -0
  262. helm/benchmark/static_build/index.html +20 -0
  263. helm/benchmark/test_data_preprocessor.py +3 -3
  264. helm/benchmark/test_run_expander.py +1 -1
  265. helm/benchmark/window_services/ai21_window_service.py +22 -33
  266. helm/benchmark/window_services/cohere_window_service.py +1 -63
  267. helm/benchmark/window_services/default_window_service.py +2 -44
  268. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  269. helm/benchmark/window_services/ice_window_service.py +0 -34
  270. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  271. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  272. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  273. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  274. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  275. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  276. helm/benchmark/window_services/local_window_service.py +21 -4
  277. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  278. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  279. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  280. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  281. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  282. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  283. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  284. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  285. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  286. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  287. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  288. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  289. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  290. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  291. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  292. helm/benchmark/window_services/test_utils.py +3 -2
  293. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  294. helm/benchmark/window_services/window_service.py +42 -0
  295. helm/benchmark/window_services/window_service_factory.py +4 -1
  296. helm/benchmark/window_services/yalm_window_service.py +0 -27
  297. helm/clients/__init__.py +0 -0
  298. helm/{proxy/clients → clients}/ai21_client.py +3 -9
  299. helm/clients/aleph_alpha_client.py +112 -0
  300. helm/{proxy/clients → clients}/anthropic_client.py +233 -18
  301. helm/{proxy/clients → clients}/auto_client.py +59 -31
  302. helm/clients/bedrock_client.py +128 -0
  303. helm/clients/bedrock_utils.py +72 -0
  304. helm/{proxy/clients → clients}/client.py +65 -7
  305. helm/clients/clip_score_client.py +49 -0
  306. helm/clients/clip_scorers/__init__.py +0 -0
  307. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  308. helm/clients/clip_scorers/clip_scorer.py +50 -0
  309. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  310. helm/{proxy/clients → clients}/cohere_client.py +4 -11
  311. helm/clients/gcs_client.py +82 -0
  312. helm/{proxy/clients → clients}/google_client.py +5 -5
  313. helm/clients/google_translate_client.py +35 -0
  314. helm/{proxy/clients → clients}/http_model_client.py +5 -7
  315. helm/{proxy/clients → clients}/huggingface_client.py +43 -64
  316. helm/clients/image_generation/__init__.py +0 -0
  317. helm/clients/image_generation/adobe_vision_client.py +78 -0
  318. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  319. helm/clients/image_generation/cogview2/__init__.py +0 -0
  320. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  321. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  322. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  323. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  324. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  325. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  326. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  327. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  328. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  329. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  330. helm/clients/image_generation/cogview2_client.py +191 -0
  331. helm/clients/image_generation/dalle2_client.py +192 -0
  332. helm/clients/image_generation/dalle3_client.py +108 -0
  333. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  334. helm/clients/image_generation/dalle_mini/data.py +442 -0
  335. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  336. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  337. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  338. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  339. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  340. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  341. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  342. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  343. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  344. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  345. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  346. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  347. helm/clients/image_generation/dalle_mini_client.py +190 -0
  348. helm/clients/image_generation/deep_floyd_client.py +78 -0
  349. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  350. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  351. helm/clients/image_generation/lexica_client.py +86 -0
  352. helm/clients/image_generation/mindalle/__init__.py +0 -0
  353. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  354. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  355. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  356. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  357. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  358. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  359. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  360. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  361. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  362. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  363. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  364. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  365. helm/clients/image_generation/mindalle_client.py +115 -0
  366. helm/clients/image_generation/nudity_check_client.py +64 -0
  367. helm/clients/image_generation/together_image_generation_client.py +111 -0
  368. helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
  369. helm/{proxy/clients → clients}/megatron_client.py +5 -5
  370. helm/clients/mistral_client.py +134 -0
  371. helm/clients/moderation_api_client.py +109 -0
  372. helm/clients/open_lm_client.py +43 -0
  373. helm/clients/openai_client.py +301 -0
  374. helm/{proxy/clients → clients}/palmyra_client.py +6 -8
  375. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  376. helm/clients/simple_client.py +64 -0
  377. helm/{proxy/clients → clients}/test_auto_client.py +13 -15
  378. helm/clients/test_client.py +100 -0
  379. helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
  380. helm/clients/test_simple_client.py +19 -0
  381. helm/{proxy/clients → clients}/test_together_client.py +20 -8
  382. helm/{proxy/clients → clients}/together_client.py +104 -73
  383. helm/clients/vertexai_client.py +400 -0
  384. helm/clients/vision_language/__init__.py +0 -0
  385. helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
  386. helm/clients/vision_language/huggingface_vlm_client.py +111 -0
  387. helm/{proxy/clients → clients}/vision_language/idefics_client.py +54 -49
  388. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  389. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  390. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  391. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  392. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  393. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  394. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  395. helm/clients/vision_language/open_flamingo_client.py +155 -0
  396. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  397. helm/clients/vllm_client.py +46 -0
  398. helm/common/cache.py +16 -4
  399. helm/common/cache_backend_config.py +47 -0
  400. helm/common/clip_score_request.py +41 -0
  401. helm/common/file_caches/__init__.py +0 -0
  402. helm/common/file_caches/file_cache.py +16 -0
  403. helm/common/file_caches/local_file_cache.py +61 -0
  404. helm/common/file_caches/test_local_file_cache.py +25 -0
  405. helm/common/file_upload_request.py +27 -0
  406. helm/common/general.py +1 -1
  407. helm/common/image_generation_parameters.py +25 -0
  408. helm/common/images_utils.py +33 -3
  409. helm/common/key_value_store.py +35 -4
  410. helm/common/media_object.py +13 -0
  411. helm/common/moderations_api_request.py +71 -0
  412. helm/common/mongo_key_value_store.py +3 -3
  413. helm/common/multimodal_request_utils.py +31 -0
  414. helm/common/nudity_check_request.py +29 -0
  415. helm/common/request.py +15 -17
  416. helm/common/test_general.py +6 -0
  417. helm/common/tokenization_request.py +1 -1
  418. helm/config/model_deployments.yaml +1159 -538
  419. helm/config/model_metadata.yaml +868 -41
  420. helm/config/tokenizer_configs.yaml +149 -43
  421. helm/proxy/accounts.py +31 -4
  422. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  423. helm/proxy/critique/model_critique_client.py +8 -6
  424. helm/proxy/example_queries.py +29 -17
  425. helm/proxy/server.py +70 -5
  426. helm/proxy/services/remote_service.py +31 -0
  427. helm/proxy/services/server_service.py +96 -16
  428. helm/proxy/services/service.py +30 -0
  429. helm/proxy/services/test_remote_service.py +4 -3
  430. helm/proxy/services/test_service.py +0 -12
  431. helm/proxy/test_accounts.py +32 -0
  432. helm/proxy/token_counters/auto_token_counter.py +37 -37
  433. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  434. helm/proxy/token_counters/token_counter.py +3 -5
  435. helm/tokenizers/__init__.py +0 -0
  436. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  437. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
  438. helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
  439. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  440. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  441. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
  442. helm/tokenizers/simple_tokenizer.py +33 -0
  443. helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
  444. helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
  445. helm/tokenizers/test_simple_tokenizer.py +33 -0
  446. helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
  447. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  448. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  449. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  450. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  451. crfm_helm-0.4.0.dist-info/RECORD +0 -397
  452. helm/benchmark/run_specs.py +0 -2762
  453. helm/benchmark/test_model_deployment_definition.py +0 -92
  454. helm/benchmark/test_model_properties.py +0 -1570
  455. helm/benchmark/vlm_run_specs.py +0 -97
  456. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  457. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  458. helm/benchmark/window_services/huggingface_window_service.py +0 -60
  459. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  460. helm/benchmark/window_services/t511b_window_service.py +0 -30
  461. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  462. helm/benchmark/window_services/ul2_window_service.py +0 -30
  463. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  464. helm/common/cache_utils.py +0 -14
  465. helm/proxy/clients/aleph_alpha_client.py +0 -95
  466. helm/proxy/clients/goose_ai_client.py +0 -99
  467. helm/proxy/clients/microsoft_client.py +0 -180
  468. helm/proxy/clients/openai_client.py +0 -206
  469. helm/proxy/clients/simple_client.py +0 -60
  470. helm/proxy/clients/test_client.py +0 -49
  471. helm/proxy/clients/vertexai_client.py +0 -115
  472. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  473. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  474. helm/proxy/token_counters/free_token_counter.py +0 -12
  475. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  476. helm/proxy/token_counters/openai_token_counter.py +0 -22
  477. helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
  478. helm/proxy/token_counters/test_openai_token_counter.py +0 -81
  479. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  480. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
  481. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
  482. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
  483. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  484. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  485. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  486. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  487. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  488. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  489. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  490. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  491. /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
  492. /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
  493. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  494. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  495. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  496. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  497. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  498. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  499. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/clients/image_generation/dalle_mini/data.py
@@ -0,0 +1,442 @@
+import random
+from dataclasses import dataclass, field
+from functools import partial
+from pathlib import Path
+
+import numpy as np
+from datasets import Dataset, load_dataset
+
+from .model.text import TextNormalizer
+from helm.common.optional_dependencies import handle_module_not_found_error
+
+try:
+    import jax
+    import jax.numpy as jnp
+    from braceexpand import braceexpand
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["heim"])
+
+
+@dataclass
+class Dataset:
+    dataset_repo_or_path: str
+    train_file: str = None
+    validation_file: str = None
+    streaming: bool = True
+    use_auth_token: bool = False
+    text_column: str = "caption"
+    encoding_column: str = "encoding"
+    max_train_samples: int = None
+    max_eval_samples: int = None
+    preprocessing_num_workers: int = None
+    overwrite_cache: bool = False
+    do_train: bool = False
+    do_eval: bool = True
+    seed_dataset: int = None
+    shard_by_host: bool = False
+    blank_caption_prob: float = 0.0
+    clip_score_column: str = "clip_score"
+    min_clip_score: float = None
+    max_clip_score: float = None
+    filter_column: str = None
+    filter_value: str = None
+    multi_eval_ds: bool = False
+    train_dataset: Dataset = field(init=False)
+    eval_dataset: Dataset = field(init=False)
+    other_eval_datasets: list = field(init=False)
+    rng_dataset: jnp.ndarray = field(init=False)
+    multi_hosts: bool = field(init=False)
+
+    def __post_init__(self):
+        if self.seed_dataset is None:
+            # create a random seed
+            self.seed_dataset = random.randint(0, 2**32 - 1)
+        # set numpy rng
+        self.np_rng = np.random.default_rng(self.seed_dataset)
+        self.multi_hosts = jax.process_count() > 1
+        # feed blank captions only in streaming mode for now
+        # otherwise dataset could be cached with same blanked captions
+        if self.blank_caption_prob:
+            assert self.streaming is True, "blank_caption_prob can only be used in streaming mode"
+        # define data_files
+        if self.train_file is not None or self.validation_file is not None:
+            # accept braceexpand notation
+            for k in ["train_file", "validation_file"]:
+                f = getattr(self, k)
+                if isinstance(f, str):
+                    setattr(self, k, list(braceexpand(f)))
+            # for list of files, split training data shards by host
+            if isinstance(self.train_file, list) and self.multi_hosts and self.shard_by_host:
+                self.train_file = self.train_file[jax.process_index() :: jax.process_count()]
+            data_files = {
+                "train": self.train_file,
+                "validation": self.validation_file,
+            }
+        else:
+            data_files = None
+
+        # multiple validation datasets
+        if self.multi_eval_ds:
+            assert Path(
+                self.dataset_repo_or_path
+            ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
+            data_files = {
+                split.name: [str(f) for f in split.glob("*.parquet")]
+                for split in Path(self.dataset_repo_or_path).glob("*")
+            }
+            # rename "valid" to "validation" if present for consistency
+            if "valid" in data_files:
+                data_files["validation"] = data_files["valid"]
+                del data_files["valid"]
+            self.dataset_repo_or_path = "parquet"
+
+        # load dataset
+        dataset = load_dataset(
+            self.dataset_repo_or_path,
+            data_files=data_files,
+            streaming=self.streaming,
+            use_auth_token=self.use_auth_token,
+        )
+        if self.do_train:
+            if "train" not in dataset:
+                raise ValueError("Training requires a training dataset")
+            self.train_dataset = dataset["train"]
+            if self.max_train_samples is not None:
+                self.train_dataset = (
+                    self.train_dataset.take(self.max_train_samples)
+                    if self.streaming
+                    else self.train_dataset.select(range(self.max_train_samples))
+                )
+        if self.do_eval:
+            if "validation" not in dataset:
+                raise ValueError("Evaluating requires a validation dataset")
+            self.eval_dataset = dataset["validation"]
+            if self.max_eval_samples is not None:
+                self.eval_dataset = (
+                    self.eval_dataset.take(self.max_eval_samples)
+                    if self.streaming
+                    else self.eval_dataset.select(range(self.max_eval_samples))
+                )
+            # other eval datasets
+            other_eval_splits = dataset.keys() - {"train", "validation"}
+            self.other_eval_datasets = {split: dataset[split] for split in other_eval_splits}
+
+    def preprocess(self, tokenizer, config):
+        # get required config variables
+        decoder_start_token_id = config.decoder_start_token_id
+        normalize_text = config.normalize_text
+        max_length = config.max_text_length
+
+        if self.streaming:
+            # we need to shuffle early in streaming mode
+            if hasattr(self, "train_dataset"):
+                self.train_dataset = self.train_dataset.shuffle(buffer_size=5000, seed=self.seed_dataset)
+        else:
+            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
+
+        # filter data
+        partial_filter_function = partial(
+            filter_function,
+            filter_column=self.filter_column,
+            filter_value=self.filter_value,
+            clip_score_column=self.clip_score_column,
+            min_clip_score=self.min_clip_score,
+            max_clip_score=self.max_clip_score,
+        )
+        for ds in ["train_dataset", "eval_dataset"]:
+            if hasattr(self, ds):
+                setattr(
+                    self,
+                    ds,
+                    (
+                        getattr(self, ds).filter(partial_filter_function)
+                        if self.streaming
+                        else getattr(self, ds).filter(
+                            partial_filter_function,
+                            num_proc=self.preprocessing_num_workers,
+                            load_from_cache_file=not self.overwrite_cache,
+                            desc="Filtering datasets",
+                        )
+                    ),
+                )
+        if hasattr(self, "other_eval_datasets"):
+            self.other_eval_datasets = {
+                split: (
+                    ds.filter(partial_filter_function)
+                    if self.streaming
+                    else ds.filter(
+                        partial_filter_function,
+                        num_proc=self.preprocessing_num_workers,
+                        load_from_cache_file=not self.overwrite_cache,
+                        desc="Filtering datasets",
+                    )
+                )
+                for split, ds in self.other_eval_datasets.items()
+            }
+
+        # normalize text
+        if normalize_text:
+            text_normalizer = TextNormalizer()
+            partial_normalize_function = partial(
+                normalize_function,
+                text_column=self.text_column,
+                text_normalizer=text_normalizer,
+            )
+            for ds in ["train_dataset", "eval_dataset"]:
+                if hasattr(self, ds):
+                    setattr(
+                        self,
+                        ds,
+                        (
+                            getattr(self, ds).map(partial_normalize_function)
+                            if self.streaming
+                            else getattr(self, ds).map(
+                                partial_normalize_function,
+                                num_proc=self.preprocessing_num_workers,
+                                load_from_cache_file=not self.overwrite_cache,
+                                desc="Normalizing datasets",
+                            )
+                        ),
+                    )
+            if hasattr(self, "other_eval_datasets"):
+                self.other_eval_datasets = {
+                    split: (
+                        ds.map(partial_normalize_function)
+                        if self.streaming
+                        else ds.map(
+                            partial_normalize_function,
+                            num_proc=self.preprocessing_num_workers,
+                            load_from_cache_file=not self.overwrite_cache,
+                            desc="Normalizing datasets",
+                        )
+                    )
+                    for split, ds in self.other_eval_datasets.items()
+                }
+
+        # blank captions
+        if self.blank_caption_prob:
+            partial_blank_caption_function = partial(
+                blank_caption_function,
+                text_column=self.text_column,
+                blank_caption_prob=self.blank_caption_prob,
+                rng=self.np_rng,
+            )
+            if hasattr(self, "train_dataset"):
+                self.train_dataset = (
+                    self.train_dataset.map(partial_blank_caption_function)
+                    if self.streaming
+                    else self.train_dataset.map(
+                        partial_blank_caption_function,
+                        num_proc=None if self.seed_dataset else self.preprocessing_num_workers,
+                        load_from_cache_file=False,
+                        desc="Blanking some captions",
+                    )
+                )
+
+        # preprocess
+        partial_preprocess_function = partial(
+            preprocess_function,
+            tokenizer=tokenizer,
+            text_column=self.text_column,
+            encoding_column=self.encoding_column,
+            max_length=max_length,
+            decoder_start_token_id=decoder_start_token_id,
+        )
+        for ds in ["train_dataset", "eval_dataset"]:
+            if hasattr(self, ds):
+                setattr(
+                    self,
+                    ds,
+                    (
+                        getattr(self, ds).map(
+                            partial_preprocess_function,
+                            batched=True,
+                            remove_columns=[
+                                self.text_column,
+                                self.encoding_column,
+                            ],
+                        )
+                        if self.streaming
+                        else getattr(self, ds).map(
+                            partial_preprocess_function,
+                            batched=True,
+                            remove_columns=getattr(ds, "column_names"),
+                            num_proc=self.preprocessing_num_workers,
+                            load_from_cache_file=not self.overwrite_cache,
+                            desc="Preprocessing datasets",
+                        )
+                    ),
+                )
+        if hasattr(self, "other_eval_datasets"):
+            self.other_eval_datasets = {
+                split: (
+                    ds.map(
+                        partial_preprocess_function,
+                        batched=True,
+                        remove_columns=[
+                            self.text_column,
+                            self.encoding_column,
+                        ],
+                    )
+                    if self.streaming
+                    else ds.map(
+                        partial_preprocess_function,
+                        batched=True,
+                        remove_columns=getattr(ds, "column_names"),
+                        num_proc=self.preprocessing_num_workers,
+                        load_from_cache_file=not self.overwrite_cache,
+                        desc="Preprocessing datasets",
+                    )
+                )
+                for split, ds in self.other_eval_datasets.items()
+            }
+
+    def dataloader(self, split, batch_size, epoch=None):
+        def _dataloader_datasets_non_streaming(
+            dataset: Dataset,
+            rng: jax.random.PRNGKey = None,
+        ):
+            """
+            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
+            Shuffle batches if rng is set.
+            """
+            steps_per_epoch = len(dataset) // batch_size
+
+            if rng is not None:
+                batch_idx = jax.random.permutation(rng, len(dataset))
+            else:
+                batch_idx = jnp.arange(len(dataset))
+
+            batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+
+            for idx in batch_idx:
+                batch = dataset[idx]
+                batch = {k: jnp.array(v) for k, v in batch.items()}
+                yield batch
+
+        def _dataloader_datasets_streaming(
+            dataset: Dataset,
+            epoch: int,
+        ):
+            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
+            batch = {k: [] for k in keys}
+            first_loop = True  # stop after one loop in some cases
+            while (self.multi_hosts and split == "train") or first_loop:
+                # in multi-host, we run forever (no epoch) as hosts need to stop
+                # at the same time and training data may not be split equally
+                # For validation data we put the entire batch on each host and then
+                # keep only the one specific to each host (could be improved but not necessary)
+                if epoch is not None:
+                    assert split == "train"
+                    # reshuffle training data at each epoch
+                    dataset.set_epoch(epoch)
+                    epoch += 1
+                for item in dataset:
+                    for k in keys:
+                        batch[k].append(item[k])
+                    if len(batch[keys[0]]) == batch_size:
+                        batch = {k: jnp.array(v) for k, v in batch.items()}
+                        yield batch
+                        batch = {k: [] for k in keys}
+                first_loop = False
+
+        if split == "train":
+            ds = self.train_dataset
+        elif split == "eval":
+            ds = self.eval_dataset
+        else:
+            ds = self.other_eval_datasets[split]
+
+        if self.streaming:
+            return _dataloader_datasets_streaming(ds, epoch)
+        else:
+            if split == "train":
+                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
+            return _dataloader_datasets_non_streaming(ds, input_rng)
+
+    @property
+    def length(self):
+        len_train_dataset, len_eval_dataset = None, None
+        if self.streaming:
+            # we don't know the length, let's just assume max_samples if defined
+            if self.max_train_samples is not None:
+                len_train_dataset = self.max_train_samples
+            if self.max_eval_samples is not None:
+                len_eval_dataset = self.max_eval_samples
+        else:
+            len_train_dataset = len(self.train_dataset) if hasattr(self, "train_dataset") else None
+            len_eval_dataset = len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
+        return len_train_dataset, len_eval_dataset
+
+
+def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = np.zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1]
+    shifted_input_ids[:, 0] = decoder_start_token_id
+    return shifted_input_ids
+
+
+def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
+    if blank_caption_prob and (rng.random() if rng is not None else np.random.random()) < blank_caption_prob:
+        example[text_column] = ""
+    return example
+
+
+def normalize_function(example, text_column, text_normalizer):
+    example[text_column] = text_normalizer(example[text_column])
+    return example
+
+
+def filter_function(
+    example,
+    min_clip_score,
+    max_clip_score,
+    clip_score_column,
+    filter_column,
+    filter_value,
+):
+    if min_clip_score is not None and example[clip_score_column] < min_clip_score:
+        return False
+    if max_clip_score is not None and example[clip_score_column] > max_clip_score:
+        return False
+    if filter_column is not None and example[filter_column] != filter_value:
+        return False
+    return True
+
+
+def preprocess_function(
+    examples,
+    tokenizer,
+    text_column,
+    encoding_column,
+    max_length,
+    decoder_start_token_id,
+):
+    inputs = examples[text_column]
+    # Setting padding="max_length" as we need fixed length inputs for jitted functions
+    model_inputs = tokenizer(
+        inputs,
+        max_length=max_length,
+        padding="max_length",
+        truncation=True,
+        return_tensors="np",
+    )
+
+    # set up targets
+    # Note: labels correspond to our target indices
+    # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
+    labels = examples[encoding_column]
+    labels = np.asarray(labels)
+
+    # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
+    model_inputs["labels"] = labels
+
+    # In our case, this prepends the bos token and removes the last one
+    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
+    model_inputs["decoder_input_ids"] = decoder_input_ids
+
+    return model_inputs
helm/clients/image_generation/dalle_mini/model/__init__.py
@@ -0,0 +1,5 @@
+from .configuration import DalleBartConfig
+from .modeling import DalleBart
+from .partitions import set_partitions
+from .processor import DalleBartProcessor
+from .tokenizer import DalleBartTokenizer
helm/clients/image_generation/dalle_mini/model/configuration.py
@@ -0,0 +1,175 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" DalleBart model configuration """
+import warnings
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+from .utils import PretrainedFromWandbMixin
+
+logger = logging.get_logger(__name__)
+
+
+class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
+    model_type = "dallebart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "num_attention_heads": "encoder_attention_heads",
+        "hidden_size": "d_model",
+    }
+
+    def __init__(
+        self,
+        normalize_text=False,
+        encoder_vocab_size=50264,
+        image_vocab_size=16384,  # encoded image token space
+        image_length=256,  # number of encoded tokens
+        max_text_length=64,  # max number of text tokens
+        encoder_layers=12,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=12,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        activation_function="gelu",
+        d_model=1024,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        scale_embedding=False,
+        gradient_checkpointing=True,
+        use_scan=None,
+        use_cache=True,
+        is_encoder_decoder=True,
+        forced_eos_token_id=None,
+        tie_word_embeddings=False,  # different modalities and sizes
+        do_sample=True,
+        # transformer variants
+        use_bias=False,  # use bias in attention and dense layers (except for lm_head)
+        ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
+        ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln), "subln"
+        use_head_scale=False,  # used in NormFormer
+        use_cosine_attention=False,  # used in Swin v2
+        tau_init=0.05,  # used only in cosine attention (Swin v2)
+        use_absolute_position_embeddings=True,  # default
+        use_swin_position_embeddings=False,  # used in Swin v1/v2
+        use_deepnet_scaling=False,  # used in Deepnet
+        use_subln_init=False,
+        use_glu=True,  # "GLU Variants Improve Transformer"
+        use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
+        sinkhorn_iters=1,  # used in SinkFormers
+        use_final_ln_encoder=True,  # final layer normalization in encoder
+        use_final_ln_decoder=True,  # final layer normalization in decoder
+        # parameters that should not be necessary but could affect results
+        force_ln_scale=False,  # force scale in layernorm even when followed by dense layers
+        **kwargs,
+    ):
+        # text normalizer
+        self.normalize_text = normalize_text
+
+        # transformer variants
+        self.use_bias = use_bias
+        assert ln_type in [
+            "rmsnorm",
+            "layernorm",
+        ], "ln_type must be 'rmsnorm' or 'layernorm'"
+        self.ln_type = ln_type
+        if ln_positions == "deepnet":
+            ln_positions = "postln"
+        assert ln_positions in [
+            "normformer",
+            "swinv2",
+            "cogview",
+            "postln",
+            "preln",
+            "subln",
+        ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln', 'subln'"
+        self.use_head_scale = use_head_scale
+        assert use_alibi is False, "use_alibi is not supported yet"
+        self.ln_positions = ln_positions
+        self.use_cosine_attention = use_cosine_attention
+        self.tau_init = tau_init
+        self.use_absolute_position_embeddings = use_absolute_position_embeddings
+        self.use_swin_position_embeddings = use_swin_position_embeddings
+        self.use_deepnet_scaling = use_deepnet_scaling
+        self.use_subln_init = use_subln_init
+        self.use_glu = use_glu
+        self.use_alibi = use_alibi
+        self.sinkhorn_iters = sinkhorn_iters
+        if ln_positions == "postln":
+            assert use_final_ln_encoder, "use_final_ln_encoder must be True when ln_positions is 'postln'"
+            assert use_final_ln_decoder, "use_final_ln_decoder must be True when ln_positions is 'postln'"
+        self.use_final_ln_encoder = use_final_ln_encoder
+        self.use_final_ln_decoder = use_final_ln_decoder
+        self.force_ln_scale = force_ln_scale
+
+        # common parameters
+        self.encoder_vocab_size = encoder_vocab_size
+        self.image_vocab_size = image_vocab_size
+        self.image_length = image_length
+        self.max_text_length = max_text_length
+        self.d_model = d_model
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.use_cache = use_cache
+        self.gradient_checkpointing = gradient_checkpointing
+        # all layers are the same in most configurations
+        self.use_scan = use_scan if use_scan is not None else ln_positions != "swinv2"
+        assert not (self.use_scan and ln_positions == "swinv2"), "scan cannot be used with 'swinv2'"
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
+        # special token id's are appended to vocab if not provided
+        decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
+        bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
+        pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
+        eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)
+
+        # we generate to image_length + 1 (for bos) by default
+        min_length = kwargs.pop("min_length", image_length + 1)
+        max_length = kwargs.pop("max_length", image_length + 1)
+
+        super().__init__(
+            # args required in parent class
+            is_encoder_decoder=is_encoder_decoder,
+            tie_word_embeddings=tie_word_embeddings,
+            forced_eos_token_id=forced_eos_token_id,
+            decoder_start_token_id=decoder_start_token_id,
+            bos_token_id=bos_token_id,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            min_length=min_length,
+            max_length=max_length,
+            do_sample=do_sample,
+            **kwargs,
+        )
+
+        # ensure backward compatibility for BART CNN models
+        if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
+            self.forced_bos_token_id = self.bos_token_id
+            warnings.warn(
+                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
+                "The config can simply be saved and uploaded again to be fixed."
+            )
+ )