crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (546)
  1. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +37 -2
  5. helm/benchmark/adaptation/adapters/adapter.py +4 -42
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
  9. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
  10. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
  11. helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
  12. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  13. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  14. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
  15. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
  16. helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
  17. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  18. helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
  19. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
  20. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
  21. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  22. helm/benchmark/adaptation/prompt.py +7 -1
  23. helm/benchmark/adaptation/request_state.py +6 -1
  24. helm/benchmark/adaptation/scenario_state.py +6 -2
  25. helm/benchmark/annotation/annotator.py +43 -0
  26. helm/benchmark/annotation/annotator_factory.py +61 -0
  27. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  28. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  29. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  30. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  31. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  32. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  33. helm/benchmark/annotation_executor.py +124 -0
  34. helm/benchmark/augmentations/cleva_perturbation.py +7 -14
  35. helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
  36. helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
  37. helm/benchmark/augmentations/data_augmenter.py +0 -2
  38. helm/benchmark/augmentations/dialect_perturbation.py +2 -2
  39. helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
  40. helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
  41. helm/benchmark/augmentations/gender_perturbation.py +3 -3
  42. helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
  43. helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
  44. helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
  45. helm/benchmark/augmentations/person_name_perturbation.py +0 -7
  46. helm/benchmark/augmentations/perturbation.py +20 -7
  47. helm/benchmark/augmentations/perturbation_description.py +1 -1
  48. helm/benchmark/augmentations/space_perturbation.py +2 -2
  49. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  50. helm/benchmark/augmentations/synonym_perturbation.py +2 -2
  51. helm/benchmark/augmentations/test_perturbation.py +11 -7
  52. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  53. helm/benchmark/augmentations/typos_perturbation.py +2 -2
  54. helm/benchmark/config_registry.py +38 -0
  55. helm/benchmark/executor.py +46 -16
  56. helm/benchmark/huggingface_registration.py +37 -7
  57. helm/benchmark/metrics/basic_metrics.py +172 -641
  58. helm/benchmark/metrics/bbq_metrics.py +3 -4
  59. helm/benchmark/metrics/bias_metrics.py +6 -6
  60. helm/benchmark/metrics/classification_metrics.py +11 -8
  61. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  62. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  63. helm/benchmark/metrics/code_metrics.py +4 -3
  64. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  65. helm/benchmark/metrics/common_metric_specs.py +167 -0
  66. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  67. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  68. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  69. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  70. helm/benchmark/metrics/disinformation_metrics.py +6 -112
  71. helm/benchmark/metrics/dry_run_metrics.py +5 -3
  72. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  73. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  74. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  75. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  76. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  77. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  78. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  79. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  80. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  81. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  82. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  83. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  84. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  85. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  86. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  87. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  88. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  89. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  90. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  91. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  92. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  93. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  94. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  95. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  96. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  97. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  98. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  99. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  100. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  101. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  102. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  103. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  104. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  105. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  106. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  107. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  108. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  109. helm/benchmark/metrics/machine_translation_metrics.py +5 -5
  110. helm/benchmark/metrics/metric.py +93 -172
  111. helm/benchmark/metrics/metric_name.py +0 -1
  112. helm/benchmark/metrics/metric_service.py +16 -0
  113. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  114. helm/benchmark/metrics/ranking_metrics.py +6 -7
  115. helm/benchmark/metrics/reference_metric.py +148 -0
  116. helm/benchmark/metrics/summac/model_summac.py +0 -2
  117. helm/benchmark/metrics/summarization_metrics.py +8 -8
  118. helm/benchmark/metrics/test_classification_metrics.py +9 -6
  119. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  120. helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
  121. helm/benchmark/metrics/test_metric.py +2 -2
  122. helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
  123. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
  124. helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
  125. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
  126. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
  127. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  128. helm/benchmark/metrics/toxicity_utils.py +23 -0
  129. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  130. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  131. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  132. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  133. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  134. helm/benchmark/model_deployment_registry.py +164 -41
  135. helm/benchmark/model_metadata_registry.py +181 -35
  136. helm/benchmark/multi_gpu_runner.py +133 -0
  137. helm/benchmark/presentation/contamination.py +3 -3
  138. helm/benchmark/presentation/create_plots.py +8 -7
  139. helm/benchmark/presentation/run_display.py +50 -17
  140. helm/benchmark/presentation/schema.py +28 -46
  141. helm/benchmark/presentation/summarize.py +213 -96
  142. helm/benchmark/presentation/table.py +8 -8
  143. helm/benchmark/presentation/test_contamination.py +2 -2
  144. helm/benchmark/presentation/test_run_entry.py +14 -9
  145. helm/benchmark/presentation/test_summarize.py +5 -0
  146. helm/benchmark/run.py +66 -54
  147. helm/benchmark/run_expander.py +342 -31
  148. helm/benchmark/run_spec.py +93 -0
  149. helm/benchmark/run_spec_factory.py +162 -0
  150. helm/benchmark/run_specs/__init__.py +0 -0
  151. helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
  152. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  153. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  154. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  155. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  156. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  157. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  158. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  159. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  160. helm/benchmark/runner.py +116 -69
  161. helm/benchmark/runner_config_registry.py +21 -0
  162. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  163. helm/benchmark/scenarios/bold_scenario.py +2 -2
  164. helm/benchmark/scenarios/cleva_scenario.py +43 -46
  165. helm/benchmark/scenarios/code_scenario.py +3 -2
  166. helm/benchmark/scenarios/commonsense_scenario.py +171 -191
  167. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  168. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  169. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  170. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  171. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  172. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  173. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  174. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  175. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  176. helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
  177. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  178. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  179. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  180. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  181. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  182. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  183. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  184. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  185. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  186. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  187. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  188. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  189. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  190. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  191. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  192. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  193. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  194. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  195. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  196. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  197. helm/benchmark/scenarios/legalbench_scenario.py +123 -0
  198. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  199. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  200. helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
  201. helm/benchmark/scenarios/math_scenario.py +19 -2
  202. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  203. helm/benchmark/scenarios/numeracy_scenario.py +3 -3
  204. helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
  205. helm/benchmark/scenarios/raft_scenario.py +2 -6
  206. helm/benchmark/scenarios/scenario.py +14 -2
  207. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  208. helm/benchmark/scenarios/test_math_scenario.py +22 -0
  209. helm/benchmark/scenarios/test_scenario.py +6 -3
  210. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  211. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  212. helm/benchmark/scenarios/the_pile_scenario.py +6 -7
  213. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  214. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  215. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  216. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  217. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  218. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  219. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  220. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  221. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  222. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  223. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  224. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  225. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  226. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  227. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  228. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  229. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  230. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  231. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  232. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  233. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  234. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  235. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  236. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  237. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
  238. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  239. helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
  240. helm/benchmark/server.py +59 -2
  241. helm/benchmark/slurm_jobs.py +12 -0
  242. helm/benchmark/slurm_runner.py +79 -51
  243. helm/benchmark/static/benchmarking.js +3 -4
  244. helm/benchmark/static/contamination.yaml +1 -1
  245. helm/benchmark/static/images/organizations/together.png +0 -0
  246. helm/benchmark/static/json-urls.js +4 -0
  247. helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
  248. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  249. helm/benchmark/static/schema_lite.yaml +824 -0
  250. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  251. helm/benchmark/static/schema_unitxt.yaml +428 -0
  252. helm/benchmark/static/schema_vlm.yaml +576 -0
  253. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  254. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  255. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  256. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  257. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  258. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  259. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  260. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  261. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  262. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  263. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  264. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  265. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  266. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  267. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  268. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  269. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  270. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  271. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  272. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  273. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  274. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  275. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  276. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  277. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  278. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  279. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  280. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  281. helm/benchmark/static_build/config.js +4 -0
  282. helm/benchmark/static_build/index.html +20 -0
  283. helm/benchmark/test_data_preprocessor.py +3 -3
  284. helm/benchmark/test_model_deployment_definition.py +90 -0
  285. helm/benchmark/test_run_expander.py +1 -1
  286. helm/benchmark/tokenizer_config_registry.py +10 -14
  287. helm/benchmark/window_services/ai21_window_service.py +22 -33
  288. helm/benchmark/window_services/cohere_window_service.py +1 -63
  289. helm/benchmark/window_services/default_window_service.py +2 -35
  290. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  291. helm/benchmark/window_services/ice_window_service.py +0 -34
  292. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  293. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  294. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  295. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  296. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  297. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  298. helm/benchmark/window_services/local_window_service.py +21 -4
  299. helm/benchmark/window_services/no_decoding_window_service.py +32 -0
  300. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  301. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  302. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  303. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  304. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  305. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  306. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  307. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  308. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  309. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  310. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  311. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  312. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  313. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  314. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  315. helm/benchmark/window_services/test_utils.py +3 -2
  316. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  317. helm/benchmark/window_services/window_service.py +42 -0
  318. helm/benchmark/window_services/window_service_factory.py +24 -269
  319. helm/benchmark/window_services/yalm_window_service.py +0 -27
  320. helm/clients/__init__.py +0 -0
  321. helm/{proxy/clients → clients}/ai21_client.py +5 -12
  322. helm/clients/aleph_alpha_client.py +112 -0
  323. helm/{proxy/clients → clients}/anthropic_client.py +213 -24
  324. helm/clients/auto_client.py +215 -0
  325. helm/clients/bedrock_client.py +128 -0
  326. helm/clients/bedrock_utils.py +72 -0
  327. helm/{proxy/clients → clients}/client.py +67 -55
  328. helm/clients/clip_score_client.py +49 -0
  329. helm/clients/clip_scorers/__init__.py +0 -0
  330. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  331. helm/clients/clip_scorers/clip_scorer.py +50 -0
  332. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  333. helm/{proxy/clients → clients}/cohere_client.py +6 -17
  334. helm/clients/gcs_client.py +82 -0
  335. helm/{proxy/clients → clients}/google_client.py +7 -8
  336. helm/clients/google_translate_client.py +35 -0
  337. helm/{proxy/clients → clients}/http_model_client.py +6 -10
  338. helm/{proxy/clients → clients}/huggingface_client.py +134 -92
  339. helm/clients/image_generation/__init__.py +0 -0
  340. helm/clients/image_generation/adobe_vision_client.py +78 -0
  341. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  342. helm/clients/image_generation/cogview2/__init__.py +0 -0
  343. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  344. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  345. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  346. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  347. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  348. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  349. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  350. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  351. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  352. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  353. helm/clients/image_generation/cogview2_client.py +191 -0
  354. helm/clients/image_generation/dalle2_client.py +192 -0
  355. helm/clients/image_generation/dalle3_client.py +108 -0
  356. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  357. helm/clients/image_generation/dalle_mini/data.py +442 -0
  358. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  359. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  360. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  361. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  362. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  363. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  364. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  365. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  366. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  367. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  368. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  369. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  370. helm/clients/image_generation/dalle_mini_client.py +190 -0
  371. helm/clients/image_generation/deep_floyd_client.py +78 -0
  372. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  373. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  374. helm/clients/image_generation/lexica_client.py +86 -0
  375. helm/clients/image_generation/mindalle/__init__.py +0 -0
  376. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  377. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  378. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  379. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  380. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  381. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  382. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  383. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  384. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  385. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  386. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  387. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  388. helm/clients/image_generation/mindalle_client.py +115 -0
  389. helm/clients/image_generation/nudity_check_client.py +64 -0
  390. helm/clients/image_generation/together_image_generation_client.py +111 -0
  391. helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
  392. helm/{proxy/clients → clients}/megatron_client.py +13 -7
  393. helm/clients/mistral_client.py +134 -0
  394. helm/clients/moderation_api_client.py +109 -0
  395. helm/clients/open_lm_client.py +43 -0
  396. helm/clients/openai_client.py +302 -0
  397. helm/{proxy/clients → clients}/palmyra_client.py +15 -12
  398. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  399. helm/clients/simple_client.py +64 -0
  400. helm/{proxy/clients → clients}/test_auto_client.py +15 -15
  401. helm/clients/test_client.py +100 -0
  402. helm/clients/test_huggingface_client.py +70 -0
  403. helm/clients/test_simple_client.py +19 -0
  404. helm/{proxy/clients → clients}/test_together_client.py +23 -12
  405. helm/{proxy/clients → clients}/together_client.py +18 -71
  406. helm/clients/vertexai_client.py +391 -0
  407. helm/clients/vision_language/__init__.py +0 -0
  408. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  409. helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
  410. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  411. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  412. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  413. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  414. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  415. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  416. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  417. helm/clients/vision_language/open_flamingo_client.py +155 -0
  418. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  419. helm/clients/vllm_client.py +46 -0
  420. helm/common/cache.py +24 -179
  421. helm/common/cache_backend_config.py +47 -0
  422. helm/common/clip_score_request.py +41 -0
  423. helm/common/concurrency.py +32 -0
  424. helm/common/credentials_utils.py +28 -0
  425. helm/common/file_caches/__init__.py +0 -0
  426. helm/common/file_caches/file_cache.py +16 -0
  427. helm/common/file_caches/local_file_cache.py +61 -0
  428. helm/common/file_caches/test_local_file_cache.py +25 -0
  429. helm/common/file_upload_request.py +27 -0
  430. helm/common/general.py +29 -10
  431. helm/common/image_generation_parameters.py +25 -0
  432. helm/common/images_utils.py +24 -1
  433. helm/common/key_value_store.py +113 -0
  434. helm/common/media_object.py +13 -0
  435. helm/common/moderations_api_request.py +71 -0
  436. helm/common/mongo_key_value_store.py +88 -0
  437. helm/common/multimodal_request_utils.py +31 -0
  438. helm/common/nudity_check_request.py +29 -0
  439. helm/common/object_spec.py +2 -2
  440. helm/common/request.py +36 -27
  441. helm/common/test_general.py +6 -0
  442. helm/common/tokenization_request.py +6 -3
  443. helm/config/__init__.py +0 -0
  444. helm/config/model_deployments.yaml +1942 -0
  445. helm/config/model_metadata.yaml +2201 -0
  446. helm/config/tokenizer_configs.yaml +362 -0
  447. helm/proxy/accounts.py +31 -4
  448. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  449. helm/proxy/critique/model_critique_client.py +13 -5
  450. helm/proxy/example_queries.py +29 -17
  451. helm/proxy/retry.py +8 -2
  452. helm/proxy/server.py +77 -5
  453. helm/proxy/services/remote_service.py +31 -0
  454. helm/proxy/services/server_service.py +103 -20
  455. helm/proxy/services/service.py +34 -2
  456. helm/proxy/services/test_remote_service.py +7 -6
  457. helm/proxy/services/test_service.py +27 -18
  458. helm/proxy/test_accounts.py +32 -0
  459. helm/proxy/token_counters/auto_token_counter.py +37 -37
  460. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  461. helm/proxy/token_counters/token_counter.py +3 -5
  462. helm/py.typed +0 -0
  463. helm/tokenizers/__init__.py +0 -0
  464. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  465. helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
  466. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
  467. helm/tokenizers/auto_tokenizer.py +93 -0
  468. helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
  469. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  470. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  471. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
  472. helm/tokenizers/simple_tokenizer.py +33 -0
  473. helm/tokenizers/test_anthropic_tokenizer.py +82 -0
  474. helm/tokenizers/test_huggingface_tokenizer.py +136 -0
  475. helm/tokenizers/test_simple_tokenizer.py +33 -0
  476. helm/tokenizers/vertexai_tokenizer.py +97 -0
  477. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  478. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  479. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  480. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  481. crfm_helm-0.3.0.dist-info/RECORD +0 -396
  482. helm/benchmark/vlm_run_specs.py +0 -71
  483. helm/benchmark/window_services/anthropic_window_service.py +0 -68
  484. helm/benchmark/window_services/bloom_window_service.py +0 -35
  485. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  486. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  487. helm/benchmark/window_services/gptj_window_service.py +0 -38
  488. helm/benchmark/window_services/gptneox_window_service.py +0 -41
  489. helm/benchmark/window_services/http_model_window_service.py +0 -28
  490. helm/benchmark/window_services/huggingface_window_service.py +0 -59
  491. helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
  492. helm/benchmark/window_services/llama_window_service.py +0 -28
  493. helm/benchmark/window_services/luminous_window_service.py +0 -67
  494. helm/benchmark/window_services/megatron_window_service.py +0 -10
  495. helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
  496. helm/benchmark/window_services/openai_window_service.py +0 -13
  497. helm/benchmark/window_services/opt_window_service.py +0 -35
  498. helm/benchmark/window_services/palmyra_window_service.py +0 -45
  499. helm/benchmark/window_services/remote_window_service.py +0 -48
  500. helm/benchmark/window_services/santacoder_window_service.py +0 -27
  501. helm/benchmark/window_services/starcoder_window_service.py +0 -27
  502. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  503. helm/benchmark/window_services/t511b_window_service.py +0 -30
  504. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  505. helm/benchmark/window_services/ul2_window_service.py +0 -30
  506. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  507. helm/benchmark/window_services/wider_openai_window_service.py +0 -52
  508. helm/proxy/clients/aleph_alpha_client.py +0 -99
  509. helm/proxy/clients/auto_client.py +0 -461
  510. helm/proxy/clients/goose_ai_client.py +0 -100
  511. helm/proxy/clients/microsoft_client.py +0 -182
  512. helm/proxy/clients/openai_client.py +0 -206
  513. helm/proxy/clients/remote_model_registry.py +0 -28
  514. helm/proxy/clients/simple_client.py +0 -61
  515. helm/proxy/clients/test_anthropic_client.py +0 -63
  516. helm/proxy/clients/test_client.py +0 -31
  517. helm/proxy/clients/test_huggingface_client.py +0 -87
  518. helm/proxy/models.py +0 -963
  519. helm/proxy/test_models.py +0 -27
  520. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  521. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  522. helm/proxy/token_counters/free_token_counter.py +0 -12
  523. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  524. helm/proxy/token_counters/openai_token_counter.py +0 -22
  525. helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
  526. helm/proxy/token_counters/test_openai_token_counter.py +0 -79
  527. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  528. helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
  529. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  530. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  531. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  532. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  533. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  534. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  535. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  536. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  537. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  538. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  539. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  540. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  541. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  542. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  543. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  544. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  545. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  546. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
@@ -0,0 +1,84 @@
1
+ import re
2
+
3
+ from helm.common.optional_dependencies import handle_module_not_found_error
4
+
5
+ try:
6
+ from flax.core.frozen_dict import freeze
7
+ from flax.traverse_util import flatten_dict, unflatten_dict
8
+ from jax.experimental import PartitionSpec as P
9
+ except ModuleNotFoundError as e:
10
+ handle_module_not_found_error(e, ["heim"])
11
+
12
+
13
+ # utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
14
+ # Sentinels
15
+ _unmatched = object()
16
+
17
+ # For specifying empty leaf dict `{}`
18
+ empty_dict = object()
19
+
20
+
21
+ def _match(qs, ks):
22
+ """Return True if regexes in qs match any window of strings in tuple ks."""
23
+ # compile regexes and force complete match
24
+ qts = tuple(map(lambda x: re.compile(x + "$"), qs))
25
+ for i in range(len(ks) - len(qs) + 1):
26
+ matches = [x.match(y) for x, y in zip(qts, ks[i:])]
27
+ if matches and all(matches):
28
+ return True
29
+ return False
30
+
31
+
32
+ def _replacement_rules(rules):
33
+ def replace(key, val):
34
+ for rule, replacement in rules:
35
+ if _match(rule, key):
36
+ return replacement
37
+ return val
38
+
39
+ return replace
40
+
41
+
42
def _get_partition_rules():
    """Return the ordered ``(key regexes, partition spec)`` rule table.

    Each rule pairs a tuple of regexes (matched against a window of a flattened
    parameter path) with the spec to assign; ``None`` means no spec is assigned.
    """
    first_axis_mp = P("mp", None)
    second_axis_mp = P(None, "mp")
    return [
        # embeddings
        (("embed_positions", "embedding"), first_axis_mp),
        (("embed_tokens", "embedding"), first_axis_mp),
        (("rel_bias", "embedding"), second_axis_mp),
        # attention
        (("(q_proj|k_proj|v_proj)", "kernel"), second_axis_mp),
        (("out_proj", "kernel"), first_axis_mp),
        # FFN
        (("Dense_0", "kernel"), second_axis_mp),
        (("GLU.*", "Dense_1", "kernel"), second_axis_mp),
        (("GLU.*", "Dense_2", "kernel"), first_axis_mp),
        (("FFN.*", "Dense_1", "kernel"), first_axis_mp),
        # layer norms
        (("(bias|scale)",), None),
        (("lm_head", "kernel"), second_axis_mp),
        # head scale and tau
        (("(head_scale|tau)",), None),
    ]
62
+
63
+
64
def set_partitions(in_dict, use_scan):
    """Assign a partition spec to every leaf of the nested dict ``in_dict``.

    The dict is flattened, each flattened key is run through the partition
    rules, and the result is unflattened and frozen.

    Args:
        in_dict: nested parameter dict whose flattened keys are tuples of names.
        use_scan: when true, specs of keys containing ``FlaxBartEncoderLayers``
            or ``FlaxBartDecoderLayers`` get an extra leading ``None`` dimension.

    Returns:
        A frozen nested dict mapping each parameter to its spec (or ``None``).

    Raises:
        AssertionError: if any key is matched by no rule. Each such key is
            printed first.
    """
    replace = _replacement_rules(_get_partition_rules())
    result = {key: replace(key, _unmatched) for key in flatten_dict(in_dict)}

    # The sentinel is compared by identity; `==` against spec objects is not
    # what is meant here.
    unmatched_keys = [key for key, spec in result.items() if spec is _unmatched]
    for key in unmatched_keys:
        print(f"Unmatched -> {key}")
    # Check completeness before touching the specs: prepending a dimension to
    # the sentinel would otherwise fail with an unrelated TypeError.
    assert not unmatched_keys, "Incomplete partition spec."

    if use_scan:
        # add None dimension to layers
        scanned_layers = ("FlaxBartEncoderLayers", "FlaxBartDecoderLayers")
        result = {
            key: (
                P(*(None,) + spec)
                if spec is not None and any(name in key for name in scanned_layers)
                else spec
            )
            for key, spec in result.items()
        }
    return freeze(unflatten_dict(result))
@@ -0,0 +1,63 @@
1
+ """ DalleBart processor """
2
+
3
+ from typing import List
4
+
5
+ from .configuration import DalleBartConfig
6
+ from .text import TextNormalizer
7
+ from .tokenizer import DalleBartTokenizer
8
+ from .utils import PretrainedFromWandbMixin
9
+ from helm.common.optional_dependencies import handle_module_not_found_error
10
+
11
+
12
class DalleBartProcessorBase:
    """Tokenizes prompts (optionally normalized) and attaches unconditional tokens."""

    def __init__(self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int):
        self.tokenizer = tokenizer
        self.normalize_text = normalize_text
        self.max_text_length = max_text_length
        if normalize_text:
            self.text_processor = TextNormalizer()
        # Tokenize the empty prompt once; it is repeated per batch item in __call__.
        empty_prompt = self._tokenize("")
        self.input_ids_uncond = empty_prompt["input_ids"]
        self.attention_mask_uncond = empty_prompt["attention_mask"]

    def _tokenize(self, text):
        """Run the tokenizer with max-length padding/truncation and return its ``.data``."""
        return self.tokenizer(
            text,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        ).data

    def __call__(self, text: List[str] = None):
        """Tokenize a list of prompts and add the repeated unconditional inputs."""
        try:
            import jax.numpy as jnp
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        # A bare string would be iterated character by character below.
        assert not isinstance(text, str), "text must be a list of strings"

        if self.normalize_text:
            text = [self.text_processor(prompt) for prompt in text]
        encoded = self._tokenize(text)

        # tokens used only with super conditioning
        batch_size = len(text)
        encoded["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, batch_size, axis=0)
        encoded["attention_mask_uncond"] = jnp.repeat(self.attention_mask_uncond, batch_size, axis=0)
        return encoded

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build a processor from the pretrained tokenizer and config at the same location."""
        tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
        config = DalleBartConfig.from_pretrained(*args, **kwargs)
        return cls(tokenizer, config.normalize_text, config.max_text_length)
60
+
61
+
62
class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
    """``DalleBartProcessorBase`` combined with ``PretrainedFromWandbMixin.from_pretrained``."""
@@ -0,0 +1,251 @@
1
+ """
2
+ Utilities for processing text.
3
+ """
4
+
5
+ import html
6
+ import math
7
+ import random
8
+ import re
9
+ from pathlib import Path
10
+
11
+ import emoji
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ from helm.common.optional_dependencies import handle_module_not_found_error
15
+
16
+ try:
17
+ import ftfy
18
+ from unidecode import unidecode
19
+ except ModuleNotFoundError as e:
20
+ handle_module_not_found_error(e, ["heim"])
21
+
22
+
23
+ # based on wiki word occurrence
24
+ person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
25
+ temp_token = "xtokx" # avoid repeating chars
26
+
27
+
28
class HashtagProcessor:
    """Splits a string without spaces into words using a word-frequency cost table.

    Adapted from the wordninja library; word costs come from a Wikipedia word
    frequency file, where a word's cost is the log of its 1-based line number.
    """

    def __init__(self):
        frequency_file = hf_hub_download("dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt")
        lines = Path(frequency_file).read_text(encoding="utf8").splitlines()
        self._word_cost = {}
        for rank, line in enumerate(lines):
            self._word_cost[str(line.split()[0])] = math.log(float(rank + 1))
        self._max_word = max(len(word) for word in self._word_cost)
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Uses dynamic programming to infer the location of spaces in a string without spaces."""
        words = []
        for chunk in self._SPLIT_RE.split(s):
            words.extend(self._split(chunk))
        return " ".join(words)

    def _split(self, s):
        """Return an iterator over the minimal-cost segmentation of ``s``."""

        def best_match(i):
            # Best (cost, length) for the word ending at position i, given that
            # `cost` is already filled in for every shorter prefix.
            window = cost[max(0, i - self._max_word) : i]
            options = []
            for k, c in enumerate(reversed(window)):
                word = s[i - k - 1 : i].lower()
                options.append((c + self._word_cost.get(word, 9e999), k + 1))
            return min(options)

        # Forward pass: cost[i] is the best cost of segmenting s[:i].
        cost = [0]
        for i in range(1, len(s) + 1):
            cost.append(best_match(i)[0])

        # Backward pass: walk from the end, peeling off the best last word.
        tokens = []
        end = len(s)
        while end > 0:
            c, k = best_match(end)
            assert c == cost[end]
            piece = s[end - k : end]
            # A lone apostrophe is never glued onto the previous token; otherwise
            # re-attach a split 's and keep runs of digits together.
            glue = (
                piece != "'"
                and len(tokens) > 0
                and (tokens[-1] == "'s" or (s[end - 1].isdigit() and tokens[-1][0].isdigit()))
            )
            if glue:
                tokens[-1] = piece + tokens[-1]
            else:
                tokens.append(piece)
            end -= k

        return reversed(tokens)
77
+
78
+
79
def replace_person_token(t):
    """Replace ``<person>`` placeholders with natural wording (used for CC12M).

    Runs of several ``<person>`` tokens joined by commas, whitespace or "and"
    become " people ". Each remaining single token is replaced by a phrase drawn
    at random from ``person_token``, weighted by the listed occurrence counts.
    """
    # Raw string: "\s" in a plain literal is an invalid escape sequence.
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        t = t.replace("<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1)
    return t
85
+
86
+
87
def fix_html(t):
    """Unescape HTML entities twice, so doubly-escaped text is decoded too (from OpenAI CLIP)."""
    once = html.unescape(t)
    return html.unescape(once)
90
+
91
+
92
def replace_punctuation_with_commas(t):
    """Replace each punctuation character of the set ``()[].,|:;?!=+~-/{}`` with a comma."""
    # Raw string: "\]", "\-" and "\/" are invalid escapes in a plain literal.
    return re.sub(r"[()[\].,|:;?!=+~\-\/{}]", ",", t)
94
+
95
+
96
def simplify_quotes(t):
    """Turn every single quote, double quote or backtick into a space-padded double quote."""
    any_quote = """['"`]"""
    return re.sub(any_quote, ' " ', t)
98
+
99
+
100
def merge_quotes(t):
    """Collapse any run of double quotes and surrounding whitespace into one ``' " '``."""
    # Raw string: "\s" is an invalid escape in a plain literal.
    return re.sub(r'(\s*"+\s*)+', ' " ', t)
102
+
103
+
104
def remove_comma_numbers(t):
    """Remove thousands separators from numbers, e.g. ``1,000,000`` -> ``1000000``.

    The substitution runs twice because one pass cannot rewrite adjacent groups:
    the digit before the second comma is consumed by the first match.
    """
    for _ in range(2):
        # Raw string: "\d" is an invalid escape in a plain literal.
        t = re.sub(r"(\d),(\d{3})", r"\1\2", t)
    return t
109
+
110
+
111
def pre_process_dot_numbers(t):
    """Replace a dot between two word characters with a ``temp_token``-wrapped "dot" placeholder."""
    # Raw string: "\w" and "\." are invalid escapes in a plain literal.
    return re.sub(r"(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)
113
+
114
+
115
def post_process_dot_numbers(t):
    """Turn each ``temp_token``-wrapped "dot" placeholder back into a literal dot."""
    placeholder = f"{temp_token}dot{temp_token}"
    return re.sub(placeholder, ".", t)
117
+
118
+
119
def pre_process_quotes(t):
    """Replace apostrophes that start 's, 't, 'd, 'm, 'll, 're or 've with a placeholder."""
    # Lookahead only: the suffix itself is left in place.
    contraction_apostrophe = r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)"
    placeholder = rf"{temp_token}quote{temp_token}"
    return re.sub(contraction_apostrophe, placeholder, t)
122
+
123
+
124
def post_process_quotes(t):
    """Turn each ``temp_token``-wrapped "quote" placeholder back into an apostrophe."""
    placeholder = f"{temp_token}quote{temp_token}"
    return re.sub(placeholder, "'", t)
126
+
127
+
128
def pre_process_dates(t):
    """Replace a slash between two digits with a ``temp_token``-wrapped "slash" placeholder."""
    # Raw string: "\d" is an invalid escape in a plain literal.
    return re.sub(r"(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)
130
+
131
+
132
def post_process_dates(t):
    """Turn each ``temp_token``-wrapped "slash" placeholder back into a slash."""
    placeholder = f"{temp_token}slash{temp_token}"
    return re.sub(placeholder, "/", t)
134
+
135
+
136
def merge_commas(t):
    """Collapse any run of commas and surrounding whitespace into a single ``", "``."""
    # Raw string: "\s" is an invalid escape in a plain literal.
    return re.sub(r"(\s*,+\s*)+", ", ", t)
138
+
139
+
140
def add_space_after_commas(t):
    """Replace every comma with a comma followed by a space."""
    comma = ","
    return re.sub(comma, ", ", t)
142
+
143
+
144
def handle_special_chars(t):
    """Handle special characters.

    A hyphen directly between two word characters becomes a space, then each of
    ``% & / $ *`` is padded with a space on both sides.
    """
    # Raw strings: "\w" and "\/" are invalid escapes in a plain literal.
    t = re.sub(r"(\w)-(\w)", r"\1 \2", t)
    # always add space around some characters
    return re.sub(r"([%&\/$*])", r" \1 ", t)
150
+
151
+
152
def expand_hashtags(t, hashtag_processor):
    """Remove ``#`` and replace each hashtag body with ``hashtag_processor(body)``.

    Args:
        t: text to process.
        hashtag_processor: callable taking the word characters after ``#`` and
            returning the replacement string (intended to split them into words).
    """
    # Raw string: "\w" is an invalid escape in a plain literal.
    return re.sub(r"#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
155
+
156
+
157
+ _re_ignore_chars = r"[_#\\]"
158
+
159
+
160
+ def ignore_chars(t):
161
+ "Ignore useless characters"
162
+ return re.sub(_re_ignore_chars, " ", t)
163
+
164
+
165
def remove_extra_spaces(t):
    """Collapse every run of whitespace (spaces, tabs, newlines) into a single space."""
    # Raw string: "\s" is an invalid escape in a plain literal.
    return re.sub(r"\s+", " ", t)
168
+
169
+
170
def remove_repeating_chars(t):
    """Reduce a non-digit character repeated 4+ times in a row to a single instance.

    Three repeats are left alone (e.g. the roman numeral 'VIII').
    """
    four_or_more = r"(\D)(\1{3,})"
    return re.sub(four_or_more, r"\1", t)
173
+
174
+
175
def remove_urls(t):
    """Delete every run of non-whitespace characters that starts with ``http``."""
    url_like = r"http\S+"
    return re.sub(url_like, "", t)
177
+
178
+
179
def remove_html_tags(t):
    """Replace each ``<...>`` tag (shortest match, no nested ``<``) with a space."""
    tag = "<[^<]+?>"
    return re.sub(tag, " ", t)
181
+
182
+
183
def remove_first_last_commas(t):
    """Strip whitespace, drop one trailing and one leading comma, then strip again."""
    t = t.strip()
    if t.endswith(","):
        t = t[:-1]
    if t.startswith(","):
        t = t[1:]
    return t.strip()
188
+
189
+
190
def remove_wiki_ref(t):
    """Remove a bracketed numeric reference such as ``[12]`` from the start and end of the text."""
    without_leading = re.sub(r"\A\s*\[\d+\]", "", t)
    without_trailing = re.sub(r"\[\d+\]\s*\Z", "", without_leading)
    return without_trailing
193
+
194
+
195
class TextNormalizer:
    """Normalize text by running it through a fixed sequence of clean-up steps."""

    def __init__(self):
        self._hashtag_processor = HashtagProcessor()

    def __call__(self, t):
        """Apply every step in order and return the result prefixed with one space."""
        # Order matters: the pre_process_* steps protect dots, apostrophes and
        # slashes with placeholders before punctuation is turned into commas,
        # and the post_process_* steps restore them afterwards.
        steps = (
            ftfy.fix_text,  # fix some characters
            fix_html,  # fix html
            emoji.demojize,  # decode emojis (would be removed by unidecode)
            unidecode,  # decode and simplify text: see unidecode library
            lambda text: text.lower(),  # lower case
            replace_person_token,  # replace <PERSON> (for CC12M)
            remove_wiki_ref,  # remove wiki reference (for WIT)
            remove_html_tags,  # remove html tags
            remove_urls,  # remove urls
            remove_comma_numbers,  # remove commas in numbers
            # handle dots in numbers and quotes - Part 1
            pre_process_dot_numbers,
            pre_process_quotes,
            pre_process_dates,
            handle_special_chars,  # handle special characters
            lambda text: expand_hashtags(text, self._hashtag_processor),  # handle hashtags
            ignore_chars,  # ignore useless characters
            simplify_quotes,  # simplify quotes
            replace_punctuation_with_commas,  # all punctuation becomes commas
            # handle dots in numbers and quotes - Part 2
            post_process_dot_numbers,
            post_process_quotes,
            post_process_dates,
            remove_repeating_chars,  # handle repeating characters
            merge_quotes,  # merge quotes
            merge_commas,  # merge commas
            remove_extra_spaces,  # remove multiple spaces
            remove_first_last_commas,  # remove first and last comma
        )
        for step in steps:
            t = step(t)
        # always start with a space
        return f" {t}"
@@ -0,0 +1,9 @@
1
+ """ DalleBart tokenizer """
2
+
3
+ from transformers import BartTokenizerFast
4
+
5
+ from .utils import PretrainedFromWandbMixin
6
+
7
+
8
class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast):
    """``BartTokenizerFast`` combined with ``PretrainedFromWandbMixin.from_pretrained``."""
@@ -0,0 +1,29 @@
1
+ import os
2
+ import tempfile
3
+
4
+ from helm.common.optional_dependencies import handle_module_not_found_error
5
+
6
+
7
class PretrainedFromWandbMixin:
    """Mixin whose ``from_pretrained`` also accepts a wandb artifact reference."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact or delegates loading to the superclass.

        A name that contains ":" and is not an existing directory is treated as
        a wandb artifact and downloaded to a temporary directory first.
        """
        try:
            import wandb
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        # The superclass load happens inside the `with`, while the temporary
        # download directory still exists (avoids multiple artifact copies).
        with tempfile.TemporaryDirectory() as tmp_dir:
            is_artifact = ":" in pretrained_model_name_or_path and not os.path.isdir(
                pretrained_model_name_or_path
            )
            if is_artifact:
                if wandb.run is None:
                    artifact = wandb.Api().artifact(pretrained_model_name_or_path)
                else:
                    artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
                pretrained_model_name_or_path = artifact.download(tmp_dir)

            parent = super(PretrainedFromWandbMixin, cls)
            return parent.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
@@ -0,0 +1 @@
1
+ from . import *
@@ -0,0 +1,40 @@
1
+ from typing import Tuple
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
class VQGANConfig(PretrainedConfig):
    """Configuration holding the hyperparameters of a VQGAN model.

    Every constructor argument is stored as an attribute of the same name;
    ``ch_mult`` and ``attn_resolutions`` are stored as lists, and
    ``num_resolutions`` is derived as ``len(ch_mult)``. Extra keyword arguments
    are forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple = (1, 1, 2, 2, 4),
        # Annotated as a tuple: the default is `(16,)` and the value is passed
        # to `list(...)` below, so `int` was incorrect.
        attn_resolutions: Tuple = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        # One resolution level per channel multiplier.
        self.num_resolutions = len(ch_mult)
@@ -0,0 +1,107 @@
1
+ import re
2
+
3
+ import torch
4
+
5
+ from .modeling_flax_vqgan import VQModel
6
+ from .configuration_vqgan import VQGANConfig
7
+ from helm.common.optional_dependencies import handle_module_not_found_error
8
+
9
+ try:
10
+ import jax.numpy as jnp
11
+ from flax.traverse_util import flatten_dict, unflatten_dict
12
+ except ModuleNotFoundError as e:
13
+ handle_module_not_found_error(e, ["heim"])
14
+
15
+
16
# A run of word characters, a dot, then digits (e.g. "block.1").
regex = r"\w+[.]\d+"


def rename_key(key):
    """Rewrite each ``name.<digits>`` segment of ``key`` as ``name_<digits>``."""
    for segment in re.findall(regex, key):
        key = key.replace(segment, segment.replace(".", "_"))
    return key
24
+
25
+
26
+ # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
27
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    """Convert a PyTorch state dict into a nested Flax parameter dict.

    Keys are split on "." and renamed to the names found in the flattened
    ``flax_model.params``; conv and linear weights are transposed on the way.

    Args:
        pt_state_dict: mapping from dotted parameter name to a tensor exposing
            ``.numpy()``.
        flax_model: model providing ``params`` and ``base_model_prefix``.

    Returns:
        The converted parameters, unflattened into a nested dict.

    Raises:
        ValueError: if a converted tensor's shape differs from the shape of the
            Flax parameter with the same key.
    """
    # convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # Flattened view of the model's current params: used only for key lookup
    # and shape checks, never copied into the output.
    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    # Decide whether the base-model prefix must be stripped from, or added to,
    # the PyTorch keys so that they line up with the Flax params.
    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )

    # Need to change some parameters name to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

        # Correctly rename weight parameters
        # First chain: norm-style parameters whose Flax counterpart is "scale".
        if (
            "norm" in pt_key
            and (pt_tuple_key[-1] == "bias")
            and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
            and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
        ):
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        # Second, independent chain: runs on the key as left by the first chain.
        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
            # conv layer
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            # NOTE(review): presumably PyTorch (out, in, h, w) -> Flax (h, w, in, out); confirm.
            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            # linear layer
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        # Shapes are validated only for keys the Flax model already has.
        if pt_tuple_key in random_flax_state_dict:
            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)
89
+
90
+
91
def convert_model(config_path, pt_state_dict_path, save_path):
    """Convert a PyTorch VQGAN checkpoint to Flax, save it, and return the model.

    Args:
        config_path: location passed to ``VQGANConfig.from_pretrained``.
        pt_state_dict_path: checkpoint file whose ``"state_dict"`` entry holds the weights.
        save_path: destination passed to ``model.save_pretrained``.
    """
    model = VQModel(VQGANConfig.from_pretrained(config_path))

    checkpoint = torch.load(pt_state_dict_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    # Iterate over a snapshot of the keys because the dict is edited in place.
    for old_key in list(state_dict.keys()):
        tensor = state_dict.pop(old_key)
        if old_key.startswith("loss"):
            # Entries whose key starts with "loss" are dropped.
            continue
        state_dict[rename_key(old_key)] = tensor

    model.params = convert_pytorch_state_dict_to_flax(state_dict, model)
    model.save_pretrained(save_path)
    return model