crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546)
  1. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +37 -2
  5. helm/benchmark/adaptation/adapters/adapter.py +4 -42
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
  9. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
  10. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
  11. helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
  12. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  13. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  14. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
  15. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
  16. helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
  17. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  18. helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
  19. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
  20. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
  21. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  22. helm/benchmark/adaptation/prompt.py +7 -1
  23. helm/benchmark/adaptation/request_state.py +6 -1
  24. helm/benchmark/adaptation/scenario_state.py +6 -2
  25. helm/benchmark/annotation/annotator.py +43 -0
  26. helm/benchmark/annotation/annotator_factory.py +61 -0
  27. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  28. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  29. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  30. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  31. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  32. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  33. helm/benchmark/annotation_executor.py +124 -0
  34. helm/benchmark/augmentations/cleva_perturbation.py +7 -14
  35. helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
  36. helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
  37. helm/benchmark/augmentations/data_augmenter.py +0 -2
  38. helm/benchmark/augmentations/dialect_perturbation.py +2 -2
  39. helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
  40. helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
  41. helm/benchmark/augmentations/gender_perturbation.py +3 -3
  42. helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
  43. helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
  44. helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
  45. helm/benchmark/augmentations/person_name_perturbation.py +0 -7
  46. helm/benchmark/augmentations/perturbation.py +20 -7
  47. helm/benchmark/augmentations/perturbation_description.py +1 -1
  48. helm/benchmark/augmentations/space_perturbation.py +2 -2
  49. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  50. helm/benchmark/augmentations/synonym_perturbation.py +2 -2
  51. helm/benchmark/augmentations/test_perturbation.py +11 -7
  52. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  53. helm/benchmark/augmentations/typos_perturbation.py +2 -2
  54. helm/benchmark/config_registry.py +38 -0
  55. helm/benchmark/executor.py +46 -16
  56. helm/benchmark/huggingface_registration.py +37 -7
  57. helm/benchmark/metrics/basic_metrics.py +172 -641
  58. helm/benchmark/metrics/bbq_metrics.py +3 -4
  59. helm/benchmark/metrics/bias_metrics.py +6 -6
  60. helm/benchmark/metrics/classification_metrics.py +11 -8
  61. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  62. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  63. helm/benchmark/metrics/code_metrics.py +4 -3
  64. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  65. helm/benchmark/metrics/common_metric_specs.py +167 -0
  66. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  67. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  68. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  69. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  70. helm/benchmark/metrics/disinformation_metrics.py +6 -112
  71. helm/benchmark/metrics/dry_run_metrics.py +5 -3
  72. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  73. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  74. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  75. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  76. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  77. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  78. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  79. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  80. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  81. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  82. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  83. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  84. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  85. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  86. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  87. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  88. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  89. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  90. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  91. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  92. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  93. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  94. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  95. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  96. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  97. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  98. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  99. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  100. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  101. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  102. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  103. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  104. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  105. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  106. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  107. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  108. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  109. helm/benchmark/metrics/machine_translation_metrics.py +5 -5
  110. helm/benchmark/metrics/metric.py +93 -172
  111. helm/benchmark/metrics/metric_name.py +0 -1
  112. helm/benchmark/metrics/metric_service.py +16 -0
  113. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  114. helm/benchmark/metrics/ranking_metrics.py +6 -7
  115. helm/benchmark/metrics/reference_metric.py +148 -0
  116. helm/benchmark/metrics/summac/model_summac.py +0 -2
  117. helm/benchmark/metrics/summarization_metrics.py +8 -8
  118. helm/benchmark/metrics/test_classification_metrics.py +9 -6
  119. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  120. helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
  121. helm/benchmark/metrics/test_metric.py +2 -2
  122. helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
  123. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
  124. helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
  125. helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
  126. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
  127. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  128. helm/benchmark/metrics/toxicity_utils.py +23 -0
  129. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  130. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  131. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  132. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  133. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  134. helm/benchmark/model_deployment_registry.py +164 -41
  135. helm/benchmark/model_metadata_registry.py +181 -35
  136. helm/benchmark/multi_gpu_runner.py +133 -0
  137. helm/benchmark/presentation/contamination.py +3 -3
  138. helm/benchmark/presentation/create_plots.py +8 -7
  139. helm/benchmark/presentation/run_display.py +50 -17
  140. helm/benchmark/presentation/schema.py +28 -46
  141. helm/benchmark/presentation/summarize.py +213 -96
  142. helm/benchmark/presentation/table.py +8 -8
  143. helm/benchmark/presentation/test_contamination.py +2 -2
  144. helm/benchmark/presentation/test_run_entry.py +14 -9
  145. helm/benchmark/presentation/test_summarize.py +5 -0
  146. helm/benchmark/run.py +66 -54
  147. helm/benchmark/run_expander.py +342 -31
  148. helm/benchmark/run_spec.py +93 -0
  149. helm/benchmark/run_spec_factory.py +162 -0
  150. helm/benchmark/run_specs/__init__.py +0 -0
  151. helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
  152. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  153. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  154. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  155. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  156. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  157. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  158. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  159. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  160. helm/benchmark/runner.py +116 -69
  161. helm/benchmark/runner_config_registry.py +21 -0
  162. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  163. helm/benchmark/scenarios/bold_scenario.py +2 -2
  164. helm/benchmark/scenarios/cleva_scenario.py +43 -46
  165. helm/benchmark/scenarios/code_scenario.py +3 -2
  166. helm/benchmark/scenarios/commonsense_scenario.py +171 -191
  167. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  168. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  169. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  170. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  171. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  172. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  173. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  174. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  175. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  176. helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
  177. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  178. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  179. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  180. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  181. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  182. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  183. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  184. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  185. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  186. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  187. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  188. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  189. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  190. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  191. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  192. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  193. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  194. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  195. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  196. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  197. helm/benchmark/scenarios/legalbench_scenario.py +123 -0
  198. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  199. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  200. helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
  201. helm/benchmark/scenarios/math_scenario.py +19 -2
  202. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  203. helm/benchmark/scenarios/numeracy_scenario.py +3 -3
  204. helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
  205. helm/benchmark/scenarios/raft_scenario.py +2 -6
  206. helm/benchmark/scenarios/scenario.py +14 -2
  207. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  208. helm/benchmark/scenarios/test_math_scenario.py +22 -0
  209. helm/benchmark/scenarios/test_scenario.py +6 -3
  210. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  211. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  212. helm/benchmark/scenarios/the_pile_scenario.py +6 -7
  213. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  214. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  215. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  216. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  217. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  218. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  219. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  220. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  221. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  222. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  223. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  224. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  225. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  226. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  227. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  228. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  229. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  230. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  231. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  232. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  233. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  234. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  235. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  236. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  237. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
  238. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  239. helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
  240. helm/benchmark/server.py +59 -2
  241. helm/benchmark/slurm_jobs.py +12 -0
  242. helm/benchmark/slurm_runner.py +79 -51
  243. helm/benchmark/static/benchmarking.js +3 -4
  244. helm/benchmark/static/contamination.yaml +1 -1
  245. helm/benchmark/static/images/organizations/together.png +0 -0
  246. helm/benchmark/static/json-urls.js +4 -0
  247. helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
  248. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  249. helm/benchmark/static/schema_lite.yaml +824 -0
  250. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  251. helm/benchmark/static/schema_unitxt.yaml +428 -0
  252. helm/benchmark/static/schema_vlm.yaml +576 -0
  253. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  254. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  255. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  256. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  257. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  258. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  259. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  260. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  261. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  262. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  263. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  264. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  265. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  266. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  267. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  268. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  269. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  270. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  271. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  272. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  273. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  274. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  275. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  276. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  277. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  278. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  279. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  280. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  281. helm/benchmark/static_build/config.js +4 -0
  282. helm/benchmark/static_build/index.html +20 -0
  283. helm/benchmark/test_data_preprocessor.py +3 -3
  284. helm/benchmark/test_model_deployment_definition.py +90 -0
  285. helm/benchmark/test_run_expander.py +1 -1
  286. helm/benchmark/tokenizer_config_registry.py +10 -14
  287. helm/benchmark/window_services/ai21_window_service.py +22 -33
  288. helm/benchmark/window_services/cohere_window_service.py +1 -63
  289. helm/benchmark/window_services/default_window_service.py +2 -35
  290. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  291. helm/benchmark/window_services/ice_window_service.py +0 -34
  292. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  293. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  294. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  295. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  296. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  297. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  298. helm/benchmark/window_services/local_window_service.py +21 -4
  299. helm/benchmark/window_services/no_decoding_window_service.py +32 -0
  300. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  301. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  302. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  303. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  304. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  305. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  306. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  307. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  308. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  309. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  310. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  311. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  312. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  313. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  314. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  315. helm/benchmark/window_services/test_utils.py +3 -2
  316. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  317. helm/benchmark/window_services/window_service.py +42 -0
  318. helm/benchmark/window_services/window_service_factory.py +24 -269
  319. helm/benchmark/window_services/yalm_window_service.py +0 -27
  320. helm/clients/__init__.py +0 -0
  321. helm/{proxy/clients → clients}/ai21_client.py +5 -12
  322. helm/clients/aleph_alpha_client.py +112 -0
  323. helm/{proxy/clients → clients}/anthropic_client.py +213 -24
  324. helm/clients/auto_client.py +215 -0
  325. helm/clients/bedrock_client.py +128 -0
  326. helm/clients/bedrock_utils.py +72 -0
  327. helm/{proxy/clients → clients}/client.py +67 -55
  328. helm/clients/clip_score_client.py +49 -0
  329. helm/clients/clip_scorers/__init__.py +0 -0
  330. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  331. helm/clients/clip_scorers/clip_scorer.py +50 -0
  332. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  333. helm/{proxy/clients → clients}/cohere_client.py +6 -17
  334. helm/clients/gcs_client.py +82 -0
  335. helm/{proxy/clients → clients}/google_client.py +7 -8
  336. helm/clients/google_translate_client.py +35 -0
  337. helm/{proxy/clients → clients}/http_model_client.py +6 -10
  338. helm/{proxy/clients → clients}/huggingface_client.py +134 -92
  339. helm/clients/image_generation/__init__.py +0 -0
  340. helm/clients/image_generation/adobe_vision_client.py +78 -0
  341. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  342. helm/clients/image_generation/cogview2/__init__.py +0 -0
  343. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  344. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  345. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  346. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  347. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  348. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  349. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  350. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  351. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  352. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  353. helm/clients/image_generation/cogview2_client.py +191 -0
  354. helm/clients/image_generation/dalle2_client.py +192 -0
  355. helm/clients/image_generation/dalle3_client.py +108 -0
  356. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  357. helm/clients/image_generation/dalle_mini/data.py +442 -0
  358. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  359. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  360. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  361. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  362. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  363. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  364. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  365. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  366. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  367. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  368. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  369. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  370. helm/clients/image_generation/dalle_mini_client.py +190 -0
  371. helm/clients/image_generation/deep_floyd_client.py +78 -0
  372. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  373. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  374. helm/clients/image_generation/lexica_client.py +86 -0
  375. helm/clients/image_generation/mindalle/__init__.py +0 -0
  376. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  377. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  378. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  379. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  380. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  381. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  382. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  383. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  384. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  385. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  386. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  387. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  388. helm/clients/image_generation/mindalle_client.py +115 -0
  389. helm/clients/image_generation/nudity_check_client.py +64 -0
  390. helm/clients/image_generation/together_image_generation_client.py +111 -0
  391. helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
  392. helm/{proxy/clients → clients}/megatron_client.py +13 -7
  393. helm/clients/mistral_client.py +134 -0
  394. helm/clients/moderation_api_client.py +109 -0
  395. helm/clients/open_lm_client.py +43 -0
  396. helm/clients/openai_client.py +302 -0
  397. helm/{proxy/clients → clients}/palmyra_client.py +15 -12
  398. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  399. helm/clients/simple_client.py +64 -0
  400. helm/{proxy/clients → clients}/test_auto_client.py +15 -15
  401. helm/clients/test_client.py +100 -0
  402. helm/clients/test_huggingface_client.py +70 -0
  403. helm/clients/test_simple_client.py +19 -0
  404. helm/{proxy/clients → clients}/test_together_client.py +23 -12
  405. helm/{proxy/clients → clients}/together_client.py +18 -71
  406. helm/clients/vertexai_client.py +391 -0
  407. helm/clients/vision_language/__init__.py +0 -0
  408. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  409. helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
  410. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  411. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  412. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  413. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  414. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  415. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  416. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  417. helm/clients/vision_language/open_flamingo_client.py +155 -0
  418. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  419. helm/clients/vllm_client.py +46 -0
  420. helm/common/cache.py +24 -179
  421. helm/common/cache_backend_config.py +47 -0
  422. helm/common/clip_score_request.py +41 -0
  423. helm/common/concurrency.py +32 -0
  424. helm/common/credentials_utils.py +28 -0
  425. helm/common/file_caches/__init__.py +0 -0
  426. helm/common/file_caches/file_cache.py +16 -0
  427. helm/common/file_caches/local_file_cache.py +61 -0
  428. helm/common/file_caches/test_local_file_cache.py +25 -0
  429. helm/common/file_upload_request.py +27 -0
  430. helm/common/general.py +29 -10
  431. helm/common/image_generation_parameters.py +25 -0
  432. helm/common/images_utils.py +24 -1
  433. helm/common/key_value_store.py +113 -0
  434. helm/common/media_object.py +13 -0
  435. helm/common/moderations_api_request.py +71 -0
  436. helm/common/mongo_key_value_store.py +88 -0
  437. helm/common/multimodal_request_utils.py +31 -0
  438. helm/common/nudity_check_request.py +29 -0
  439. helm/common/object_spec.py +2 -2
  440. helm/common/request.py +36 -27
  441. helm/common/test_general.py +6 -0
  442. helm/common/tokenization_request.py +6 -3
  443. helm/config/__init__.py +0 -0
  444. helm/config/model_deployments.yaml +1942 -0
  445. helm/config/model_metadata.yaml +2201 -0
  446. helm/config/tokenizer_configs.yaml +362 -0
  447. helm/proxy/accounts.py +31 -4
  448. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  449. helm/proxy/critique/model_critique_client.py +13 -5
  450. helm/proxy/example_queries.py +29 -17
  451. helm/proxy/retry.py +8 -2
  452. helm/proxy/server.py +77 -5
  453. helm/proxy/services/remote_service.py +31 -0
  454. helm/proxy/services/server_service.py +103 -20
  455. helm/proxy/services/service.py +34 -2
  456. helm/proxy/services/test_remote_service.py +7 -6
  457. helm/proxy/services/test_service.py +27 -18
  458. helm/proxy/test_accounts.py +32 -0
  459. helm/proxy/token_counters/auto_token_counter.py +37 -37
  460. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  461. helm/proxy/token_counters/token_counter.py +3 -5
  462. helm/py.typed +0 -0
  463. helm/tokenizers/__init__.py +0 -0
  464. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  465. helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
  466. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
  467. helm/tokenizers/auto_tokenizer.py +93 -0
  468. helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
  469. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  470. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  471. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
  472. helm/tokenizers/simple_tokenizer.py +33 -0
  473. helm/tokenizers/test_anthropic_tokenizer.py +82 -0
  474. helm/tokenizers/test_huggingface_tokenizer.py +136 -0
  475. helm/tokenizers/test_simple_tokenizer.py +33 -0
  476. helm/tokenizers/vertexai_tokenizer.py +97 -0
  477. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  478. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  479. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  480. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  481. crfm_helm-0.3.0.dist-info/RECORD +0 -396
  482. helm/benchmark/vlm_run_specs.py +0 -71
  483. helm/benchmark/window_services/anthropic_window_service.py +0 -68
  484. helm/benchmark/window_services/bloom_window_service.py +0 -35
  485. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  486. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  487. helm/benchmark/window_services/gptj_window_service.py +0 -38
  488. helm/benchmark/window_services/gptneox_window_service.py +0 -41
  489. helm/benchmark/window_services/http_model_window_service.py +0 -28
  490. helm/benchmark/window_services/huggingface_window_service.py +0 -59
  491. helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
  492. helm/benchmark/window_services/llama_window_service.py +0 -28
  493. helm/benchmark/window_services/luminous_window_service.py +0 -67
  494. helm/benchmark/window_services/megatron_window_service.py +0 -10
  495. helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
  496. helm/benchmark/window_services/openai_window_service.py +0 -13
  497. helm/benchmark/window_services/opt_window_service.py +0 -35
  498. helm/benchmark/window_services/palmyra_window_service.py +0 -45
  499. helm/benchmark/window_services/remote_window_service.py +0 -48
  500. helm/benchmark/window_services/santacoder_window_service.py +0 -27
  501. helm/benchmark/window_services/starcoder_window_service.py +0 -27
  502. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  503. helm/benchmark/window_services/t511b_window_service.py +0 -30
  504. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  505. helm/benchmark/window_services/ul2_window_service.py +0 -30
  506. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  507. helm/benchmark/window_services/wider_openai_window_service.py +0 -52
  508. helm/proxy/clients/aleph_alpha_client.py +0 -99
  509. helm/proxy/clients/auto_client.py +0 -461
  510. helm/proxy/clients/goose_ai_client.py +0 -100
  511. helm/proxy/clients/microsoft_client.py +0 -182
  512. helm/proxy/clients/openai_client.py +0 -206
  513. helm/proxy/clients/remote_model_registry.py +0 -28
  514. helm/proxy/clients/simple_client.py +0 -61
  515. helm/proxy/clients/test_anthropic_client.py +0 -63
  516. helm/proxy/clients/test_client.py +0 -31
  517. helm/proxy/clients/test_huggingface_client.py +0 -87
  518. helm/proxy/models.py +0 -963
  519. helm/proxy/test_models.py +0 -27
  520. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  521. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  522. helm/proxy/token_counters/free_token_counter.py +0 -12
  523. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  524. helm/proxy/token_counters/openai_token_counter.py +0 -22
  525. helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
  526. helm/proxy/token_counters/test_openai_token_counter.py +0 -79
  527. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  528. helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
  529. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  530. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  531. {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  532. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  533. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  534. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  535. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  536. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  537. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  538. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  539. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  540. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  541. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  542. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  543. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  544. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  545. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  546. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
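
Note on the renames above: the helm/{proxy/clients → clients} and helm/{proxy/tokenizers → tokenizers} entries move client and tokenizer modules to new top-level packages, so downstream code that imports them directly must update its import paths. A minimal compatibility sketch (Python; the HuggingFaceClient name is taken from the huggingface_client.py module listed above, and the try/except fallback is illustrative, not part of the package):

    # Prefer the 0.5.0 package layout; fall back to the 0.3.0 layout.
    try:
        from helm.clients.huggingface_client import HuggingFaceClient
    except ModuleNotFoundError:
        from helm.proxy.clients.huggingface_client import HuggingFaceClient
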
helm/clients/image_generation/dalle_mini/model/modeling.py
@@ -0,0 +1,1834 @@
+# coding=utf-8
+# Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and the DALL·E Mini team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" DalleBart model. """
+
+import math
+from functools import partial
+from typing import Any, Dict, Optional, Tuple
+
+from transformers.modeling_flax_outputs import (
+    FlaxBaseModelOutput,
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    FlaxCausalLMOutputWithCrossAttentions,
+    FlaxSeq2SeqLMOutput,
+)
+from transformers.modeling_flax_utils import ACT2FN
+from transformers.models.bart.modeling_flax_bart import (
+    FlaxBartAttention,
+    FlaxBartForConditionalGeneration,
+    FlaxBartForConditionalGenerationModule,
+    FlaxBartModule,
+)
+from transformers.utils import ModelOutput, logging
+from transformers.generation.configuration_utils import GenerationConfig
+
+from helm.common.optional_dependencies import handle_module_not_found_error
+from .configuration import DalleBartConfig
+from .utils import PretrainedFromWandbMixin
+
+try:
+    import flax
+    import flax.linen as nn
+    import jax
+    import jax.numpy as jnp
+    from einops import rearrange
+    from flax.core.frozen_dict import unfreeze
+    from flax.linen import combine_masks, make_causal_mask
+    from flax.linen import partitioning as nn_partitioning
+    from flax.linen.linear import PrecisionLike
+    from flax.traverse_util import flatten_dict, unflatten_dict
+    from jax import custom_jvp, lax
+    from jax.random import PRNGKey
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["heim"])
+
+
+logger = logging.get_logger(__name__)
+
+remat = nn_partitioning.remat
+
+
+def smelu(beta: Any = 1.0):
+    """
+    Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
+    https://arxiv.org/abs/2202.06499
+    """
+
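+    # piecewise: 0 for x <= -beta, the quadratic (x + beta)^2 / (4 * beta) on (-beta, beta), identity for x >= beta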
+    @custom_jvp
+    @jax.jit
+    def _smelu(x: Any) -> Any:
+        x = jnp.where(x <= -beta, 0.0, x)
+        return jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta))
+
+    _smelu.defjvps(
+        lambda g, ans, x: lax.select(
+            x == -beta,
+            lax.full_like(g, 0),
+            lax.select(x == beta, lax.full_like(g, 1), g),
+        )
+    )
+    return _smelu
+
+
+ACT2FN.update({"smelu": smelu()})
+
+
+# deepnet initialization
+def deepnet_init(init_std, gain=1):
+    init = jax.nn.initializers.normal(init_std)
+
+    def _init(*args, **kwargs):
+        return gain * init(*args, **kwargs)
+
+    return _init
+
+
+# deepnet gain
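+# residual (alpha) and initialization (beta) scaling from "DeepNet: Scaling Transformers to 1,000 Layers", https://arxiv.org/abs/2203.00555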
+deepnet_gain = {
+    "encoder": {
+        "alpha": lambda config: 0.81 * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
+        "beta": lambda config: 0.87 * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
+    },
+    "decoder": {
+        "alpha": lambda config: (3 * config.decoder_layers) ** 0.25,
+        "beta": lambda config: (12 * config.decoder_layers) ** -0.25,
+    },
+}
+
+# subln gain
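+# sub-LN gain from "Foundation Transformers" (Magneto), https://arxiv.org/abs/2210.06423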
+subln_gain = {
+    "encoder": lambda config: math.sqrt(
+        1.0 / 3.0 * math.log(3 * config.decoder_layers) * math.log(2 * config.encoder_layers)
+    ),
+    "decoder": lambda config: math.sqrt(math.log(3 * config.decoder_layers)),
+}
+
+
+class RMSNorm(nn.Module):
+    """
+    From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
+
+    Adapted from flax.linen.LayerNorm
+    """
+
+    epsilon: float = 1e-6
+    dtype: Any = jnp.float32
+    param_dtype: Any = jnp.float32
+    use_scale: bool = True
+    scale_init: Any = jax.nn.initializers.ones
+
+    @nn.compact
+    def __call__(self, x):
+        reduction_axes = (-1,)
+        feature_axes = (-1,)
+
+        rms_sq = self._compute_rms_sq(x, reduction_axes)
+
+        return self._normalize(
+            self,
+            x,
+            rms_sq,
+            reduction_axes,
+            feature_axes,
+            self.dtype,
+            self.param_dtype,
+            self.epsilon,
+            self.use_scale,
+            self.scale_init,
+        )
+
+    def _compute_rms_sq(self, x, axes):
+        x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
+        rms_sq = jnp.mean(jax.lax.square(x), axes)
+        return rms_sq
+
+    def _normalize(
+        self,
+        mdl,
+        x,
+        rms_sq,
+        reduction_axes,
+        feature_axes,
+        dtype,
+        param_dtype,
+        epsilon,
+        use_scale,
+        scale_init,
+    ):
+        reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
+        feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
+        stats_shape = list(x.shape)
+        for axis in reduction_axes:
+            stats_shape[axis] = 1
+        rms_sq = rms_sq.reshape(stats_shape)
+        feature_shape = [1] * x.ndim
+        reduced_feature_shape = []
+        for ax in feature_axes:
+            feature_shape[ax] = x.shape[ax]
+            reduced_feature_shape.append(x.shape[ax])
+        mul = lax.rsqrt(rms_sq + epsilon)
+        if use_scale:
+            scale = mdl.param("scale", scale_init, reduced_feature_shape, param_dtype).reshape(feature_shape)
+            mul *= scale
+        y = mul * x
+        return jnp.asarray(y, dtype)
+
+
+def norm(type, *args, **kwargs):
+    if type == "rmsnorm":
+        return RMSNorm(*args, **kwargs)
+    elif type == "layernorm":
+        return nn.LayerNorm(*args, **kwargs)
+    else:
+        raise ValueError(f"Unknown norm type {type}")
+
+
+def dot_product_attention_weights(
+    query: Any,
+    key: Any,
+    bias: Optional[Any] = None,
+    mask: Optional[Any] = None,
+    embed_pos: Optional[Any] = None,
+    broadcast_dropout: bool = True,
+    dropout_rng: Optional[PRNGKey] = None,
+    dropout_rate: float = 0.0,
+    deterministic: bool = False,
+    dtype: Any = jnp.float32,
+    precision: PrecisionLike = None,
+    sinkhorn_iters: int = 1,
+    is_encoder: bool = False,
+    tau=None,
+):
+    """
+    Computes dot-product attention weights given query and key.
+    mask is included into the bias.
+
+    Adapted from flax.linen.attention.dot_product_attention_weights
+    """
+    assert query.ndim == key.ndim, "q, k must have same rank."
+    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
+    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
+    assert query.shape[-1] == key.shape[-1], "q, k depths must match."
+
+    # attn weight shape is (batch..., num_heads, q_length, kv_length)
+    attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)
+
+    # divide by tau (used in Swin v2)
+    if tau is not None:
+        attn_weights = attn_weights / tau
+    else:
+        depth = query.shape[-1]
+        attn_weights = attn_weights / jnp.sqrt(depth).astype(dtype)
+
+    # apply attention bias: masking, dropout, proximity bias, etc.
+    if bias is not None:
+        attn_weights = attn_weights + bias
+
+    # add relative position
+    if embed_pos is not None:
+        attn_weights = attn_weights + embed_pos
+
+    # normalize the attention weights
+    if not is_encoder or sinkhorn_iters == 1:
+        # sinkhorn does not work for causal (leaks info of future tokens into past)
+        attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+    else:
+        # adapted from https://github.com/lucidrains/sinkhorn-transformer
+        for i in range(sinkhorn_iters):
+            # when causal, some attn_weights have been set to -inf through bias
+            if i % 2 == 0:
+                attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True)
+            else:
+                attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True)
+            if mask is not None:
+                attn_weights = jnp.where(mask, attn_weights, -jnp.inf)
+        attn_weights = jnp.exp(attn_weights).astype(dtype)
+
+    # apply attention dropout
+    if not deterministic and dropout_rate > 0.0:
+        keep_prob = 1.0 - dropout_rate
+        if broadcast_dropout:
+            # dropout is broadcast across the batch + head dimensions
+            dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
+            keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
+        else:
+            keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
+        multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
+        attn_weights = attn_weights * multiplier
+
+    return attn_weights
+
+
+class FlaxBartAttention(FlaxBartAttention):
+    """
+    Edits:
+    - causal mask is used only in decoder and considers image_length
+    - scale attention heads per NormFormer paper
+    """
+
+    is_encoder: bool = False
+    is_cross_attention: bool = False
+    q_length: int = None
+    k_length: int = None
+
+    def setup(self) -> None:
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        dense = partial(
+            nn.Dense,
+            self.embed_dim,
+            use_bias=self.bias,
+            dtype=self.dtype,
+        )
+
+        if self.config.use_deepnet_scaling:
+            gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](self.config)
+        elif self.config.use_subln_init and not self.is_cross_attention:
+            gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)
+
+        self.q_proj = dense(kernel_init=jax.nn.initializers.normal(self.config.init_std))
+        self.k_proj = dense(kernel_init=jax.nn.initializers.normal(self.config.init_std))
+        self.v_proj = dense(
+            kernel_init=(
+                deepnet_init(self.config.init_std, gain)
+                if (self.config.use_deepnet_scaling or (self.config.use_subln_init and not self.is_cross_attention))
+                else jax.nn.initializers.normal(self.config.init_std)
+            )
+        )
+        self.out_proj = dense(
+            kernel_init=(
+                deepnet_init(self.config.init_std, gain)
+                if (self.config.use_deepnet_scaling or (self.config.use_subln_init and not self.is_cross_attention))
+                else jax.nn.initializers.normal(self.config.init_std)
+            )
+        )
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+        if self.config.use_head_scale:
+            self.head_scale = self.param("head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1))
+
+        if self.config.use_cosine_attention:
+            # TODO: try using a learnt scale, somehow it immediately diverges in my experiments
+            self.tau = self.config.tau_init
+
+        if self.config.use_swin_position_embeddings:
+            self.rel_bias = nn.Embed(
+                self.q_length,
+                self.k_length * self.num_heads,
+                embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            )
+
+        if self.causal:
+            # used only in decoder
+            self.causal_mask = make_causal_mask(jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool")
+
+        if self.config.ln_positions in ["subln"] and not self.is_cross_attention:
+            self.mid_layernorm = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)
+
+    def __call__(
+        self,
+        hidden_states: jnp.ndarray,
+        key_value_states: Optional[jnp.ndarray] = None,
+        attention_mask: Optional[jnp.ndarray] = None,
+        init_cache: bool = False,
+        deterministic: bool = True,
+    ) -> Tuple[jnp.ndarray]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        batch_size = hidden_states.shape[0]
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        if is_cross_attention:
+            # cross_attentions
+            key_states = self.k_proj(key_value_states)
+            value_states = self.v_proj(key_value_states)
+        else:
+            # self_attention
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = self._split_heads(query_states)
+        key_states = self._split_heads(key_states)
+        value_states = self._split_heads(value_states)
+
+        # handle cache prepare causal attention mask
+        if self.causal:
+            query_length, key_length = query_states.shape[1], key_states.shape[1]
+            if self.has_variable("cache", "cached_key"):
+                mask_shift = self.variables["cache"]["cache_index"]
+                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+                causal_mask = lax.dynamic_slice(
+                    self.causal_mask,
+                    (0, 0, mask_shift, 0),
+                    (1, 1, query_length, max_decoder_length),
+                )
+            else:
+                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+        # combine masks if needed
+        if attention_mask is not None and self.causal:
+            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+            attention_mask = combine_masks(attention_mask, causal_mask)
+        elif self.causal:
+            attention_mask = causal_mask
+        elif attention_mask is not None:
+            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+        # During fast autoregressive decoding, we feed one position at a time,
+        # and cache the keys and values step by step.
+        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+            key_states, value_states, attention_mask = self._concatenate_to_cache(
+                key_states, value_states, query_states, attention_mask
+            )
+
+        # Convert the boolean attention mask to an attention bias.
+        if attention_mask is not None:
+            # attention mask in the form of attention bias
+            attention_bias = lax.select(
+                attention_mask > 0,
+                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+                jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
+            )
+        else:
+            attention_bias = None
+
+        dropout_rng = None
+        if not deterministic and self.dropout > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
+        if self.config.use_cosine_attention:
+            # normalize q and k
+            query_states = query_states / (jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8)
+            key_states = key_states / (jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8)
+
+        # relative position embeddings
+        if self.config.use_swin_position_embeddings:
+            position_ids = jnp.arange(self.q_length)
+            embed_pos = self.rel_bias(position_ids)
+            embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
+        else:
+            embed_pos = None
+
+        tau = self.tau if self.config.use_cosine_attention else None
+        attn_weights = dot_product_attention_weights(
+            query_states,
+            key_states,
+            bias=attention_bias,
+            mask=attention_mask,
+            embed_pos=embed_pos,
+            dropout_rng=dropout_rng,
+            dropout_rate=self.dropout,
+            broadcast_dropout=True,
+            deterministic=deterministic,
+            dtype=self.dtype,
+            precision=None,
+            sinkhorn_iters=self.config.sinkhorn_iters,
+            is_encoder=self.is_encoder,
+            tau=tau,
+        )
+
+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+        if self.config.use_head_scale:
+            # per Normformer
+            attn_output = attn_output * self.head_scale
+        attn_output = self._merge_heads(attn_output)
+
+        if self.config.ln_positions in ["subln"] and not self.is_cross_attention:
+            attn_output = self.mid_layernorm(attn_output)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights
+
+
+class GLU(nn.Module):
+    """From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202"""
+
+    config: DalleBartConfig
+    ffn_dim: int
+    embed_dim: int
+    dtype: jnp.dtype = jnp.float32
+    is_encoder: bool = False
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
+
+        if self.config.use_deepnet_scaling:
+            gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](self.config)
+        elif self.config.use_subln_init:
+            gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)
+
+        if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
+            x = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )(x)
+        w = nn.Dense(
+            self.ffn_dim,
+            dtype=self.dtype,
+            use_bias=self.config.use_bias,
+            kernel_init=(
+                deepnet_init(self.config.init_std, gain)
+                if (self.config.use_deepnet_scaling or self.config.use_subln_init)
+                else jax.nn.initializers.normal(self.config.init_std)
+            ),
+        )(x)
+        w = ACT2FN[self.config.activation_function](w)
+        v = nn.Dense(
+            self.ffn_dim,
+            dtype=self.dtype,
+            use_bias=self.config.use_bias,
+            kernel_init=(
+                deepnet_init(self.config.init_std, gain)
+                if (self.config.use_deepnet_scaling or self.config.use_subln_init)
+                else jax.nn.initializers.normal(self.config.init_std)
+            ),
+        )(x)
+        x = w * v
+        if self.config.ln_positions in ["normformer", "subln"]:
+            x = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )(x)
+        x = nn.Dropout(rate=self.config.activation_dropout)(x, deterministic=deterministic)
+
+        x = nn.Dense(
+            self.embed_dim,
+            dtype=self.dtype,
+            use_bias=self.config.use_bias,
+            kernel_init=(
+                deepnet_init(self.config.init_std, gain)
+                if (self.config.use_deepnet_scaling or self.config.use_subln_init)
+                else jax.nn.initializers.normal(self.config.init_std)
+            ),
+        )(x)
+        if self.config.ln_positions in ["swinv2", "cogview"]:
+            x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
+        x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
+        return x
+
+
+ class FFN(nn.Module):
+     """Simple FFN layer"""
+
+     config: DalleBartConfig
+     ffn_dim: int
+     embed_dim: int
+     dtype: jnp.dtype = jnp.float32
+     is_encoder: bool = False
+
+     @nn.compact
+     def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
+
+         if self.config.use_deepnet_scaling:
+             gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](self.config)
+         elif self.config.use_subln_init:
+             gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)
+         if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
+             x = norm(
+                 self.config.ln_type,
+                 dtype=self.dtype,
+                 epsilon=1e-05,
+                 use_scale=self.config.force_ln_scale,
+             )(x)
+         x = nn.Dense(
+             self.ffn_dim,
+             dtype=self.dtype,
+             use_bias=self.config.use_bias,
+             kernel_init=(
+                 deepnet_init(self.config.init_std, gain)
+                 if (self.config.use_deepnet_scaling or self.config.use_subln_init)
+                 else jax.nn.initializers.normal(self.config.init_std)
+             ),
+         )(x)
+         x = ACT2FN[self.config.activation_function](x)
+         if self.config.ln_positions in ["normformer", "subln"]:
+             x = norm(
+                 self.config.ln_type,
+                 dtype=self.dtype,
+                 epsilon=1e-05,
+                 use_scale=self.config.force_ln_scale,
+             )(x)
+         x = nn.Dropout(rate=self.config.activation_dropout)(x, deterministic=deterministic)
+         x = nn.Dense(
+             self.embed_dim,
+             dtype=self.dtype,
+             use_bias=self.config.use_bias,
+             kernel_init=(
+                 deepnet_init(self.config.init_std, gain)
+                 if (self.config.use_deepnet_scaling or self.config.use_subln_init)
+                 else jax.nn.initializers.normal(self.config.init_std)
+             ),
+         )(x)
+         if self.config.ln_positions in ["swinv2", "cogview"]:
+             x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
+         x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
+         return x
+
+
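Compared with GLU, FFN carries one up-projection instead of two. As a rough cost check with hypothetical sizes d_model = 1024 and ffn_dim = 4096: FFN has 2 × 1024 × 4096 ≈ 8.4M weights, GLU has 3 × 1024 × 4096 ≈ 12.6M, which is why GLU variants are commonly paired with a smaller ffn_dim.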
+ class FlaxBartEncoderLayer(nn.Module):
+     """
+     Edits:
+     - no bias
+     - use custom FlaxBartAttention
+     """
+
+     config: DalleBartConfig
+     dtype: jnp.dtype = jnp.float32
+     add_norm: bool = False
+     use_scale: bool = True
+
+     @nn.compact
+     def __call__(
+         self,
+         hidden_states: jnp.ndarray,
+         attention_mask: jnp.ndarray,
+         output_attentions: bool = True,
+         deterministic: bool = True,
+     ) -> Tuple[jnp.ndarray]:
+
+         if self.config.use_scan:
+             hidden_states = hidden_states[0]
+
+         res_gain = deepnet_gain["encoder"]["alpha"](self.config) if self.config.use_deepnet_scaling else 1
+
+         embed_dim = self.config.d_model
+         residual = hidden_states
+         if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
+             hidden_states = norm(
+                 self.config.ln_type,
+                 dtype=self.dtype,
+                 epsilon=1e-05,
+                 use_scale=self.config.force_ln_scale,
+             )(hidden_states)
+         hidden_states, attn_weights = FlaxBartAttention(
+             config=self.config,
+             embed_dim=embed_dim,
+             num_heads=self.config.encoder_attention_heads,
+             dropout=self.config.attention_dropout,
+             bias=self.config.use_bias,
+             dtype=self.dtype,
+             is_encoder=True,
+             is_cross_attention=False,
+             q_length=self.config.max_text_length,
+             k_length=self.config.max_text_length,
+         )(hidden_states=hidden_states, attention_mask=attention_mask)
+
+         if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
+             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
+         hidden_states = nn.Dropout(rate=self.config.dropout)(hidden_states, deterministic=deterministic)
+         hidden_states = residual * res_gain + hidden_states
+         if self.config.ln_positions in ["postln"]:
+             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
+
+         residual = hidden_states
+         ff_block = (
+             GLU(
+                 config=self.config,
+                 ffn_dim=self.config.encoder_ffn_dim,
+                 embed_dim=embed_dim,
+                 dtype=self.dtype,
+                 is_encoder=True,
+             )
+             if self.config.use_glu
+             else FFN(
+                 config=self.config,
+                 ffn_dim=self.config.encoder_ffn_dim,
+                 embed_dim=embed_dim,
+                 dtype=self.dtype,
+                 is_encoder=True,
+             )
+         )
+         hidden_states = ff_block(hidden_states, deterministic=deterministic)
+         hidden_states = residual * res_gain + hidden_states
+         if self.add_norm:
+             use_scale = self.use_scale or self.config.force_ln_scale
+             hidden_states = norm(
+                 self.config.ln_type,
+                 dtype=self.dtype,
+                 epsilon=1e-05,
+                 use_scale=use_scale,
+             )(hidden_states)
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         if self.config.use_scan:
+             outputs = (outputs, None)
+
+         return outputs
+
+
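The residual update above, residual * res_gain + hidden_states, is the DeepNet-style stabilization: with use_deepnet_scaling off, res_gain is 1 and it degenerates to a plain residual connection. A toy check with made-up numbers:

    import jax.numpy as jnp

    residual = jnp.array([1.0, -2.0])
    sublayer_out = jnp.array([0.5, 0.5])
    res_gain = 2.0  # alpha > 1 up-weights the residual stream
    print(residual * res_gain + sublayer_out)  # [ 2.5 -3.5]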
+ class FlaxBartDecoderLayer(nn.Module):
+     """
+     Edits:
+     - no bias
+     - use custom FlaxBartAttention
+     """
+
+     config: DalleBartConfig
+     dtype: jnp.dtype = jnp.float32
+     add_norm: bool = False
+     use_scale: bool = True
+
+     @nn.compact
+     def __call__(
+         self,
+         hidden_states: jnp.ndarray,
+         attention_mask: jnp.ndarray,
+         encoder_hidden_states: Optional[jnp.ndarray] = None,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         init_cache: bool = False,
+         output_attentions: bool = True,
+         deterministic: bool = True,
+     ) -> Tuple[jnp.ndarray]:
+
+         if self.config.use_scan:
+             hidden_states = hidden_states[0]
+
+         res_gain = deepnet_gain["decoder"]["alpha"](self.config) if self.config.use_deepnet_scaling else 1
+
+         embed_dim = self.config.d_model
+         residual = hidden_states
+
+         # Self Attention
+         if self.config.ln_positions in ["normformer", "cogview", "preln"]:
+             hidden_states = norm(
+                 self.config.ln_type,
+                 dtype=self.dtype,
+                 epsilon=1e-05,
+                 use_scale=self.config.force_ln_scale,
+             )(hidden_states)
+         hidden_states, attn_weights = FlaxBartAttention(
+             config=self.config,
+             embed_dim=embed_dim,
+             num_heads=self.config.decoder_attention_heads,
+             dropout=self.config.attention_dropout,
+             causal=True,
+             bias=self.config.use_bias,
+             dtype=self.dtype,
+             is_encoder=False,
+             is_cross_attention=False,
+             q_length=self.config.image_length,
+             k_length=self.config.image_length,
+         )(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             init_cache=init_cache,
+         )
+
+         if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
+             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
+         hidden_states = nn.Dropout(rate=self.config.dropout)(hidden_states, deterministic=deterministic)
+         hidden_states = residual * res_gain + hidden_states
+         if self.config.ln_positions in ["postln"]:
+             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
+
+         # Cross Attention
+         cross_attn_weights = None
+         if encoder_hidden_states is not None:
+             residual = hidden_states
+             if self.config.ln_positions in ["normformer", "cogview", "preln"]:
+                 hidden_states = norm(
+                     self.config.ln_type,
+                     dtype=self.dtype,
+                     epsilon=1e-05,
+                     use_scale=self.config.force_ln_scale,
+                 )(hidden_states)
+             hidden_states, cross_attn_weights = FlaxBartAttention(
+                 config=self.config,
+                 embed_dim=embed_dim,
+                 num_heads=self.config.decoder_attention_heads,
+                 dropout=self.config.attention_dropout,
+                 bias=self.config.use_bias,
+                 dtype=self.dtype,
+                 is_encoder=False,
+                 is_cross_attention=True,
+                 q_length=self.config.image_length,
+                 k_length=self.config.max_text_length,
+             )(
+                 hidden_states=hidden_states,
+                 key_value_states=encoder_hidden_states,
+                 attention_mask=encoder_attention_mask,
+             )
+             if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
+                 hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
+             hidden_states = nn.Dropout(rate=self.config.dropout)(hidden_states, deterministic=deterministic)
+             hidden_states = residual * res_gain + hidden_states
+             if self.config.ln_positions in ["postln"]:
+                 hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
+
+         # Feed forward
+         residual = hidden_states
+         ff_block = (
+             GLU(
+                 config=self.config,
+                 ffn_dim=self.config.decoder_ffn_dim,
+                 embed_dim=embed_dim,
+                 dtype=self.dtype,
+                 is_encoder=False,
+             )
+             if self.config.use_glu
+             else FFN(
+                 config=self.config,
+                 ffn_dim=self.config.decoder_ffn_dim,
+                 embed_dim=embed_dim,
+                 dtype=self.dtype,
+                 is_encoder=False,
+             )
+         )
+         hidden_states = ff_block(hidden_states, deterministic=deterministic)
+         hidden_states = residual * res_gain + hidden_states
+         if self.add_norm:
+             use_scale = self.use_scale or self.config.force_ln_scale
+             hidden_states = norm(
+                 self.config.ln_type,
+                 dtype=self.dtype,
+                 epsilon=1e-05,
+                 use_scale=use_scale,
+             )(hidden_states)
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attn_weights, cross_attn_weights)
+
+         if self.config.use_scan:
+             outputs = (outputs, None)
+
+         return outputs
+
+
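Note the asymmetric shapes in the cross-attention call: queries come from image positions and keys/values from text, so for a hypothetical batch of 4 with 16 heads, image_length = 256 and max_text_length = 64, cross_attn_weights would have shape (4, 16, 256, 64) under the usual (batch, heads, q_length, k_length) attention-weight convention.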
+ class FlaxBartEncoderLayerCollection(nn.Module):
+     config: DalleBartConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+     """
+     Edits:
+     - use custom FlaxBartEncoderLayer
+     - allow Gradient Checkpointing (nn.remat)
+     """
+
+     @nn.compact
+     def __call__(
+         self,
+         hidden_states,
+         attention_mask,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ):
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+
+         n_layers = self.config.encoder_layers
+         layer = (
+             remat(
+                 FlaxBartEncoderLayer,
+                 static_argnums=(2, 3),
+                 prevent_cse=not self.config.use_scan,
+             )
+             if self.config.gradient_checkpointing
+             else FlaxBartEncoderLayer
+         )
+
+         if self.config.use_scan:
+             # all blocks are the same so we use nn.scan
+             assert not output_attentions, "cannot scan with output_attentions"
+             assert not output_hidden_states, "cannot scan with output_hidden_states"
+             hidden_states = (hidden_states,)
+             # we use a scale on all norms (even last layer) to allow scanning
+             hidden_states, _ = nn.scan(
+                 layer,
+                 variable_axes={"params": 0, "cache": 0},
+                 split_rngs={"params": True, "dropout": True},
+                 in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
+                 length=n_layers,
+             )(
+                 self.config,
+                 dtype=self.dtype,
+                 add_norm=self.config.ln_positions == "postln",
+                 name="FlaxBartEncoderLayers",
+             )(
+                 hidden_states,
+                 attention_mask,
+                 output_attentions,
+                 deterministic,
+             )
+             hidden_states = hidden_states[0]
+         else:
+             for i in range(n_layers):
+                 if output_hidden_states:
+                     all_hidden_states += (hidden_states,)
+                 # final layernorm on the output of the last layer
+                 # or every 6 layers for Swin v2
+                 add_norm = self.config.ln_positions == "postln" or (
+                     self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0) and (i != n_layers - 1)
+                 )
+                 # we don't need to scale the norm for the last layer
+                 use_scale = i != n_layers - 1
+                 layer_outputs = layer(
+                     self.config,
+                     dtype=self.dtype,
+                     add_norm=add_norm,
+                     use_scale=use_scale,
+                     name=f"FlaxBartEncoderLayer_{i}",
+                 )(
+                     hidden_states,
+                     attention_mask,
+                     output_attentions,
+                     deterministic,
+                 )
+                 hidden_states = layer_outputs[0]
+                 if output_attentions:
+                     all_self_attns += (layer_outputs[1],)
+
+             # add hidden states from the last layer
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+         outputs = [
+             hidden_states,
+             all_hidden_states,
+             all_self_attns,
+         ]
+
+         if not return_dict:
+             return tuple(v for v in outputs if v is not None)
+
+         return FlaxBaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
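The nn.scan branch above stacks identical layers along a leading parameter axis so the whole stack compiles as one XLA loop instead of n_layers unrolled blocks. A minimal self-contained sketch of the same pattern (hypothetical module and sizes, modeled on the Flax scan-over-layers recipe):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Block(nn.Module):
        @nn.compact
        def __call__(self, x, _):
            return x + nn.Dense(x.shape[-1])(x), None  # (carry, per-step output)

    class Stack(nn.Module):
        n_layers: int = 4

        @nn.compact
        def __call__(self, x):
            x, _ = nn.scan(
                Block,
                variable_axes={"params": 0},  # one parameter slice per layer
                split_rngs={"params": True},  # fresh init RNG per layer
                length=self.n_layers,
            )()(x, None)
            return x

    x = jnp.ones((2, 8))
    y = Stack().apply(Stack().init(jax.random.PRNGKey(0), x), x)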
+ class FlaxBartDecoderLayerCollection(nn.Module):
+     config: DalleBartConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+     """
+     Edits:
+     - use custom FlaxBartDecoderLayer
+     - allow Gradient Checkpointing (nn.remat)
+     """
+
+     @nn.compact
+     def __call__(
+         self,
+         hidden_states,
+         attention_mask,
+         encoder_hidden_states: Optional[jnp.ndarray] = None,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         deterministic: bool = True,
+         init_cache: bool = False,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ):
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+         n_layers = self.config.decoder_layers
+         layer = (
+             remat(
+                 FlaxBartDecoderLayer,
+                 static_argnums=(4, 5, 6),
+                 prevent_cse=not self.config.use_scan,
+             )
+             if self.config.gradient_checkpointing
+             else FlaxBartDecoderLayer
+         )
+
+         if self.config.use_scan:
+             # all blocks are the same so we use nn.scan
+             assert not output_attentions, "cannot scan with output_attentions"
+             assert not output_hidden_states, "cannot scan with output_hidden_states"
+             hidden_states = (hidden_states,)
+             # we use a scale on all norms (even last layer) to allow scanning
+             hidden_states, _ = nn.scan(
+                 layer,
+                 variable_axes={"params": 0, "cache": 0},
+                 split_rngs={"params": True, "dropout": True},
+                 in_axes=(
+                     nn.broadcast,
+                     nn.broadcast,
+                     nn.broadcast,
+                     nn.broadcast,
+                     nn.broadcast,
+                     nn.broadcast,
+                 ),
+                 length=n_layers,
+             )(
+                 self.config,
+                 dtype=self.dtype,
+                 add_norm=self.config.ln_positions == "postln",
+                 name="FlaxBartDecoderLayers",
+             )(
+                 hidden_states,
+                 attention_mask,
+                 encoder_hidden_states,
+                 encoder_attention_mask,
+                 init_cache,
+                 output_attentions,
+                 deterministic,
+             )
+             hidden_states = hidden_states[0]
+
+         else:
+             for i in range(n_layers):
+                 if output_hidden_states:
+                     all_hidden_states += (hidden_states,)
+                 # final layernorm on the output of the last layer
+                 # or every 6 layers for Swin v2
+                 add_norm = self.config.ln_positions == "postln" or (
+                     self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0) and (i != n_layers - 1)
+                 )
+                 # we don't need to scale the norm for the last layer
+                 use_scale = i != n_layers - 1
+                 layer_outputs = layer(
+                     self.config,
+                     dtype=self.dtype,
+                     add_norm=add_norm,
+                     use_scale=use_scale,
+                     name=f"FlaxBartDecoderLayer_{i}",
+                 )(
+                     hidden_states,
+                     attention_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     init_cache,
+                     output_attentions,
+                     deterministic,
+                 )
+
+                 hidden_states = layer_outputs[0]
+                 if output_attentions:
+                     all_self_attns += (layer_outputs[1],)
+
+                     if encoder_hidden_states is not None:
+                         all_cross_attentions += (layer_outputs[2],)
+
+             # add hidden states from the last decoder layer
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+         outputs = [
+             hidden_states,
+             all_hidden_states,
+             all_self_attns,
+             all_cross_attentions,
+         ]
+
+         if not return_dict:
+             return tuple(v for v in outputs if v is not None)
+
+         return FlaxBaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+             cross_attentions=all_cross_attentions,
+         )
+
+
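The remat wrapper used by both collections trades compute for memory: activations inside each layer are recomputed during the backward pass rather than stored. A minimal sketch of the same wrapping (hypothetical layer; static_argnums indexes the call arguments after self, which is why the plain-bool flags sit at (4, 5, 6) above):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Layer(nn.Module):
        @nn.compact
        def __call__(self, x, deterministic: bool = True):
            return nn.Dense(x.shape[-1])(x)

    RematLayer = nn.remat(Layer, static_argnums=(1,))  # `deterministic` is a static Python bool

    x = jnp.ones((2, 8))
    params = RematLayer().init(jax.random.PRNGKey(0), x, True)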
+ class FlaxBartEncoder(nn.Module):
+     config: DalleBartConfig
+     embed_tokens: nn.Embed
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+     """
+     Edits:
+     - offset set to 0 (no padding token)
+     - use max_text_length instead of max_position_embeddings
+     - use custom FlaxBartEncoderLayerCollection
+     - embed_tokens cannot be None (issue at compile time)
+     """
+
+     def setup(self):
+         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+         embed_dim = self.config.d_model
+         self.padding_idx = self.config.pad_token_id
+         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
+
+         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
+         # and adjust num_embeddings appropriately. Other models don't have this hack
+         self.offset = 0
+         if self.config.use_absolute_position_embeddings:
+             self.embed_positions = nn.Embed(
+                 self.config.max_text_length + self.offset,
+                 embed_dim,
+                 embedding_init=jax.nn.initializers.normal(self.config.init_std),
+             )
+         self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
+         self.layernorm_embedding = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)
+
+         # postln is already applied in every layer
+         if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
+             self.final_ln = norm(
+                 self.config.ln_type,
+                 dtype=self.dtype,
+                 epsilon=1e-05,
+                 use_scale=self.config.force_ln_scale,
+             )
+         else:
+             self.final_ln = None
+
+     def __call__(
+         self,
+         input_ids,
+         attention_mask,
+         position_ids,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         deterministic: bool = True,
+     ):
+         input_shape = input_ids.shape
+         input_ids = input_ids.reshape(-1, input_shape[-1])
+
+         hidden_states = self.embed_tokens(input_ids) * self.embed_scale
+
+         if self.config.use_absolute_position_embeddings:
+             embed_pos = self.embed_positions(position_ids + self.offset)
+             hidden_states = hidden_states + embed_pos
+
+         hidden_states = self.layernorm_embedding(hidden_states)
+         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+         outputs = self.layers(
+             hidden_states,
+             attention_mask,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.final_ln is None:
+             final_output = outputs[0]
+         else:
+             final_output = self.final_ln(outputs[0])
+
+         if not return_dict:
+             return (final_output,) + outputs[1:]
+
+         return FlaxBaseModelOutput(
+             last_hidden_state=final_output,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class FlaxBartDecoder(nn.Module):
+     config: DalleBartConfig
+     embed_tokens: nn.Embed
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+     """
+     Edits:
+     - offset set to 0 (no padding token)
+     - use image_length instead of max_position_embeddings
+     - use custom FlaxBartDecoderLayerCollection
+     - embed_tokens cannot be None (issue at compile time)
+     """
+
+     def setup(self):
+         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+         embed_dim = self.config.d_model
+         self.padding_idx = self.config.pad_token_id
+         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
+
+         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
+         # and adjust num_embeddings appropriately. Other models don't have this hack
+         self.offset = 0
+         if self.config.use_absolute_position_embeddings:
+             self.embed_positions = nn.Embed(
+                 self.config.image_length + self.offset,  # image length for BOS
+                 embed_dim,
+                 embedding_init=jax.nn.initializers.normal(self.config.init_std),
+             )
+
+         self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
+         self.layernorm_embedding = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)
+
+         # postln is already applied in every layer
+         if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
+             self.final_ln = norm(
+                 self.config.ln_type,
+                 dtype=self.dtype,
+                 epsilon=1e-05,
+                 use_scale=self.config.force_ln_scale,
+             )
+         else:
+             self.final_ln = None
+
+     def __call__(
+         self,
+         input_ids,
+         attention_mask,
+         position_ids,
+         encoder_hidden_states: Optional[jnp.ndarray] = None,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         init_cache: bool = False,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         deterministic: bool = True,
+     ):
+         input_shape = input_ids.shape
+         input_ids = input_ids.reshape(-1, input_shape[-1])
+
+         hidden_states = self.embed_tokens(input_ids) * self.embed_scale
+
+         if self.config.use_absolute_position_embeddings:
+             embed_pos = self.embed_positions(position_ids + self.offset)
+             hidden_states = hidden_states + embed_pos
+
+         hidden_states = self.layernorm_embedding(hidden_states)
+         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+         outputs = self.layers(
+             hidden_states,
+             attention_mask,
+             encoder_hidden_states,
+             encoder_attention_mask,
+             deterministic=deterministic,
+             init_cache=init_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.final_ln is None:
+             final_output = outputs[0]
+         else:
+             final_output = self.final_ln(outputs[0])
+
+         if not return_dict:
+             return (final_output,) + outputs[1:]
+
+         return FlaxBaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=final_output,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             cross_attentions=outputs.cross_attentions,
+         )
+
+
+ class FlaxBartModule(FlaxBartModule):
+     """
+     Edits:
+     - use custom FlaxBartEncoder & FlaxBartDecoder
+     - use separate embeddings for Encoder & Decoder
+     """
+
+     def setup(self):
+         encoder_embed_tokens = nn.Embed(
+             self.config.encoder_vocab_size,
+             self.config.d_model,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+         )
+         decoder_embed_tokens = nn.Embed(
+             self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS
+             self.config.d_model,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+         )
+
+         self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=encoder_embed_tokens)
+         self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=decoder_embed_tokens)
+
+
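The two separate nn.Embed tables reflect the fact that the encoder reads text tokens while the decoder reads VQ image codes, so the vocabularies are disjoint; with a hypothetical image_vocab_size of 16384, the decoder table gets 16385 rows, the extra one for the BOS code.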
+ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
+     """
+     Edits:
+     - no bias
+     - lm_head set to image_vocab_size + 1 (for BOS)
+     - uses custom FlaxBartModule
+     """
+
+     def setup(self):
+         self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
+         self.lm_head = nn.Dense(
+             self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
+             use_bias=False,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(self.config.init_std),
+         )
+
+     def __call__(
+         self,
+         input_ids,
+         attention_mask,
+         decoder_input_ids,
+         decoder_attention_mask,
+         position_ids,
+         decoder_position_ids,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         deterministic: bool = True,
+     ):
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_input_ids=decoder_input_ids,
+             decoder_attention_mask=decoder_attention_mask,
+             position_ids=position_ids,
+             decoder_position_ids=decoder_position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=deterministic,
+         )
+
+         hidden_states = outputs[0]
+
+         if self.config.tie_word_embeddings:
+             shared_embedding = self.model.variables["params"]["shared"]["embedding"]
+             lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+         else:
+             lm_logits = self.lm_head(hidden_states)
+
+         if not return_dict:
+             output = (lm_logits,) + outputs[1:]
+             return output
+
+         return FlaxSeq2SeqLMOutput(
+             logits=lm_logits,
+             decoder_hidden_states=outputs.decoder_hidden_states,
+             decoder_attentions=outputs.decoder_attentions,
+             cross_attentions=outputs.cross_attentions,
+             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+             encoder_hidden_states=outputs.encoder_hidden_states,
+             encoder_attentions=outputs.encoder_attentions,
+         )
+
+
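When tie_word_embeddings is set, the module reuses the shared embedding matrix, transposed, as the kernel of lm_head instead of the head's own weights. A small standalone sketch of that apply trick with hypothetical sizes:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    vocab, d_model = 16, 8
    embed = nn.Embed(vocab, d_model)
    embedding = embed.init(jax.random.PRNGKey(0), jnp.arange(4))["params"]["embedding"]  # (vocab, d_model)

    lm_head = nn.Dense(vocab, use_bias=False)
    hidden = jnp.ones((1, 4, d_model))
    logits = lm_head.apply({"params": {"kernel": embedding.T}}, hidden)  # (1, 4, vocab)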
+ @flax.struct.dataclass
+ class SampleState:
+     cur_len: jnp.ndarray
+     sequences: jnp.ndarray
+     running_token: jnp.ndarray
+     is_sent_finished: jnp.ndarray
+     prng_key: jnp.ndarray
+     model_kwargs: Dict[str, jnp.ndarray]
+     model_kwargs_uncond: Dict[str, jnp.ndarray]
+
+
+ @flax.struct.dataclass
+ class FlaxSampleOutput(ModelOutput):
+     """
+     Flax base class for outputs of decoder-only generation models using sampling.
+
+     Args:
+         sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
+             The generated sequences.
+     """
+
+     sequences: jnp.ndarray = None
+
+
+ class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration):
+     """
+     Edits:
+     - renamed from FlaxBartForConditionalGeneration
+     - uses custom FlaxBartForConditionalGenerationModule
+     - no bias in decode method
+     - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
+       related to position embedding during model.generate()
+     - custom generate method to allow super conditioning
+     - num_params property
+     - unscan function
+     """
+
+     module_class = FlaxBartForConditionalGenerationModule
+     config_class = DalleBartConfig
+
+     def num_params(self, params=None):
+         if params is None:
+             params = self.params
+         num_params = jax.tree_util.tree_map(lambda param: param.size, flatten_dict(unfreeze(params))).values()
+         return sum(list(num_params))
+
+     def unscan(self, params):
+         if self.config.use_scan:
+             self.config.use_scan = False
+             params = flatten_dict(params)
+             scanned_keys = [k for k in params.keys() if "layers" in k]
+             for k in scanned_keys:
+                 v = params[k]
+                 name_idx = k.index("layers") + 1
+                 for i in range(len(v)):
+                     new_k = (
+                         *k[:name_idx],
+                         f"{k[name_idx][:-1]}_{i}",
+                         *k[name_idx + 1 :],
+                     )
+                     params[new_k] = v[i]
+                 del params[k]
+             params = unflatten_dict(params)
+         return params
+
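To make the key surgery in unscan concrete, here is the rewrite applied to one flattened key (the path is illustrative): the scanned stack stores every layer under a single entry whose name ends in "Layers", with a leading layer axis, and slice i of that stack moves to a "Layer_{i}" sibling:

    k = ("model", "encoder", "layers", "FlaxBartEncoderLayers", "GLU_0", "Dense_0", "kernel")
    name_idx = k.index("layers") + 1
    new_k = (*k[:name_idx], f"{k[name_idx][:-1]}_0", *k[name_idx + 1:])
    # -> ("model", "encoder", "layers", "FlaxBartEncoderLayer_0", "GLU_0", "Dense_0", "kernel")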
+     def decode(
+         self,
+         decoder_input_ids,
+         encoder_outputs,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         decoder_attention_mask: Optional[jnp.ndarray] = None,
+         decoder_position_ids: Optional[jnp.ndarray] = None,
+         past_key_values: dict = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         encoder_hidden_states = encoder_outputs[0]
+         if encoder_attention_mask is None:
+             batch_size, sequence_length = encoder_hidden_states.shape[:2]
+             encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+         batch_size, sequence_length = decoder_input_ids.shape
+         if decoder_attention_mask is None:
+             decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+         if decoder_position_ids is None:
+             if past_key_values is not None:
+                 raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+             decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         inputs = {"params": params or self.params}
+
+         # If past_key_values are passed, the cache is already initialized. A private flag, init_cache,
+         # has to be passed down to ensure the cache is used. The cache also has to be marked as mutable
+         # so that it can be changed by the FlaxBartAttention module.
+         if past_key_values:
+             inputs["cache"] = past_key_values
+             mutable = ["cache"]
+         else:
+             mutable = False
+
+         def _decoder_forward(
+             module,
+             decoder_input_ids,
+             decoder_attention_mask,
+             decoder_position_ids,
+             **kwargs,
+         ):
+             decoder_module = module._get_decoder_module()
+             outputs = decoder_module(
+                 decoder_input_ids,
+                 decoder_attention_mask,
+                 decoder_position_ids,
+                 **kwargs,
+             )
+             hidden_states = outputs[0]
+
+             if self.config.tie_word_embeddings:
+                 shared_embedding = module.model.variables["params"]["shared"]["embedding"]
+                 lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+             else:
+                 lm_logits = module.lm_head(hidden_states)
+
+             return lm_logits, outputs
+
+         outputs = self.module.apply(
+             inputs,
+             decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+             decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+             decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=not train,
+             rngs=rngs,
+             mutable=mutable,
+             method=_decoder_forward,
+         )
+
+         if past_key_values is None:
+             lm_logits, decoder_outputs = outputs
+         else:
+             (lm_logits, decoder_outputs), past = outputs
+
+         if return_dict:
+             outputs = FlaxCausalLMOutputWithCrossAttentions(
+                 logits=lm_logits,
+                 hidden_states=decoder_outputs.hidden_states,
+                 attentions=decoder_outputs.attentions,
+                 cross_attentions=decoder_outputs.cross_attentions,
+             )
+         else:
+             outputs = (lm_logits,) + decoder_outputs[1:]
+
+         # add updated cache to model output
+         if past_key_values is not None and return_dict:
+             outputs["past_key_values"] = unfreeze(past["cache"])
+             return outputs
+         elif past_key_values is not None and not return_dict:
+             outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+         return outputs
+
+     def prepare_inputs_for_generation(
+         self,
+         decoder_input_ids,
+         max_length,
+         attention_mask: Optional[jnp.DeviceArray] = None,
+         decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+         encoder_outputs=None,
+         **kwargs,
+     ):
+         # initializing the cache
+         batch_size, seq_length = decoder_input_ids.shape
+
+         past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
+         # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1]
+         # and x < cache_length. But since the decoder uses a causal mask, those positions are masked anyway.
+         # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
+         extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
+         if decoder_attention_mask is not None:
+             position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+             extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+         else:
+             position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+         return {
+             "past_key_values": past_key_values,
+             "encoder_outputs": encoder_outputs,
+             "encoder_attention_mask": attention_mask,
+             "decoder_attention_mask": extended_attention_mask,
+             "decoder_position_ids": position_ids,
+         }
+
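The max_length - 1 sizing lines up with the decoder's position table. With a hypothetical image_length of 256, generate() emits BOS plus 256 image codes (max_length = 257), but the final code is never fed back into the decoder, so the highest position actually embedded is 255; a cache and static mask of length max_length - 1 = 256 therefore suffice and keep position ids inside the 256-row embedding table.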
+     def generate(
+         self,
+         input_ids: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         bos_token_id: Optional[int] = None,
+         eos_token_id: Optional[int] = None,
+         decoder_start_token_id: Optional[int] = None,
+         do_sample: Optional[bool] = None,
+         prng_key: Optional[jnp.ndarray] = None,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         temperature: Optional[float] = None,
+         num_beams: Optional[int] = None,
+         no_repeat_ngram_size: Optional[int] = None,
+         min_length: Optional[int] = None,
+         forced_bos_token_id: Optional[int] = None,
+         forced_eos_token_id: Optional[int] = None,
+         length_penalty: Optional[float] = None,
+         early_stopping: Optional[bool] = None,
+         trace: bool = True,
+         params: Optional[Dict[str, jnp.ndarray]] = None,
+         condition_scale: Optional[float] = 1.0,
+         input_ids_uncond: Optional[jnp.ndarray] = None,
+         attention_mask_uncond: Optional[jnp.ndarray] = None,
+         **model_kwargs,
+     ):
+         """Edit: Allow super conditioning."""
+
+         # set init values
+         max_length = max_length if max_length is not None else self.config.max_length
+         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+         decoder_start_token_id = (
+             decoder_start_token_id if decoder_start_token_id else self.config.decoder_start_token_id
+         )
+         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
+
+         if decoder_start_token_id is None and self.config.is_encoder_decoder:
+             raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
+
+         do_sample = do_sample if do_sample is not None else self.config.do_sample
+         num_beams = num_beams if num_beams is not None else self.config.num_beams
+
+         if self.config.is_encoder_decoder:
+             # add encoder_outputs to model_kwargs
+             if model_kwargs.get("encoder_outputs") is None:
+                 model_kwargs_input = dict(model_kwargs)
+                 model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+                     input_ids,
+                     params,
+                     {"attention_mask": attention_mask, **model_kwargs_input},
+                 )
+             if condition_scale != 1.0:
+                 assert input_ids_uncond is not None, "`input_ids_uncond` has to be defined for super conditioning."
+                 assert do_sample is True, "`do_sample` has to be True for super conditioning."
+                 assert num_beams == 1, "`num_beams` has to be 1 for super conditioning."
+                 model_kwargs_uncond = self._prepare_encoder_decoder_kwargs_for_generation(
+                     input_ids_uncond,
+                     params,
+                     {
+                         "attention_mask": attention_mask_uncond,
+                         **model_kwargs_input,
+                     },
+                 )
+             else:
+                 model_kwargs_uncond = None
+             # prepare decoder_input_ids for generation
+             input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+         if not do_sample and num_beams == 1:
+             logits_processor = self._get_logits_processor(
+                 no_repeat_ngram_size,
+                 min_length,
+                 max_length,
+                 eos_token_id,
+                 forced_bos_token_id,
+                 forced_eos_token_id,
+             )
+             return self._greedy_search(
+                 input_ids,
+                 max_length,
+                 pad_token_id,
+                 eos_token_id,
+                 logits_processor=logits_processor,
+                 trace=trace,
+                 params=params,
+                 model_kwargs=model_kwargs,
+             )
+         elif do_sample and num_beams == 1:
+             try:
+                 logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
+                 logits_processor = self._get_logits_processor(
+                     no_repeat_ngram_size,
+                     min_length,
+                     max_length,
+                     eos_token_id,
+                     forced_bos_token_id,
+                     forced_eos_token_id,
+                 )
+             except TypeError:
+                 # newer transformers versions expect a GenerationConfig instead of bare arguments
+                 logits_warper = self._get_logits_warper(
+                     generation_config=GenerationConfig(top_k=top_k, top_p=top_p, temperature=temperature)
+                 )
+                 logits_processor = self._get_logits_processor(
+                     generation_config=GenerationConfig(
+                         no_repeat_ngram_size=no_repeat_ngram_size,
+                         min_length=min_length,
+                         max_length=max_length,
+                         eos_token_id=eos_token_id,
+                         forced_bos_token_id=forced_bos_token_id,
+                         forced_eos_token_id=forced_eos_token_id,
+                     )
+                 )
+
+             return self._sample(
+                 input_ids,
+                 max_length,
+                 pad_token_id,
+                 eos_token_id,
+                 prng_key,
+                 logits_warper=logits_warper,
+                 logits_processor=logits_processor,
+                 trace=trace,
+                 params=params,
+                 model_kwargs=model_kwargs,
+                 condition_scale=condition_scale,
+                 model_kwargs_uncond=model_kwargs_uncond,
+             )
+         elif not do_sample and num_beams > 1:
+             # broadcast input_ids & encoder_outputs
+             input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+
+             if "encoder_outputs" in model_kwargs:
+                 model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
+                     model_kwargs["encoder_outputs"]["last_hidden_state"],
+                     num_beams=num_beams,
+                 )
+
+             if "attention_mask" in model_kwargs:
+                 model_kwargs["attention_mask"] = self._expand_to_num_beams(
+                     model_kwargs["attention_mask"], num_beams=num_beams
+                 )
+
+             logits_processor = self._get_logits_processor(
+                 no_repeat_ngram_size,
+                 min_length,
+                 max_length,
+                 eos_token_id,
+                 forced_bos_token_id,
+                 forced_eos_token_id,
+             )
+
+             return self._beam_search(
+                 input_ids,
+                 max_length,
+                 pad_token_id,
+                 eos_token_id,
+                 length_penalty=length_penalty,
+                 early_stopping=early_stopping,
+                 logits_processor=logits_processor,
+                 trace=trace,
+                 params=params,
+                 model_kwargs=model_kwargs,
+             )
+         else:
+             raise NotImplementedError("Beam sampling is currently not implemented.")
+
+     def _sample(
+         self,
+         input_ids: jnp.ndarray,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[int] = None,
+         prng_key: Optional[jnp.ndarray] = None,
+         logits_processor=None,
+         logits_warper=None,
+         trace: bool = True,
+         params: Optional[Dict[str, jnp.ndarray]] = None,
+         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
+         condition_scale: float = 1.0,
+         model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
+     ):
+         # init values
+         max_length = max_length if max_length is not None else self.config.max_length
+         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
+
+         batch_size, cur_len = input_ids.shape
+
+         eos_token_id = jnp.array(eos_token_id)
+         pad_token_id = jnp.array(pad_token_id)
+         cur_len = jnp.array(cur_len)
+
+         # per batch-item holding current token in loop.
+         sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
+         sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
+
+         # per batch-item state bit indicating if sentence has finished.
+         is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
+
+         # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
+         # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
+         model = self.decode if self.config.is_encoder_decoder else self
+
+         # initialize model specific kwargs
+         model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
+         if condition_scale != 1.0:
+             model_kwargs_uncond = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs_uncond)
+
+         # initialize state
+         state = SampleState(
+             cur_len=cur_len,
+             sequences=sequences,
+             running_token=input_ids,
+             is_sent_finished=is_sent_finished,
+             prng_key=prng_key,
+             model_kwargs=model_kwargs,
+             model_kwargs_uncond=model_kwargs_uncond,
+         )
+
+         def sample_search_cond_fn(state):
+             """state termination condition fn."""
+             has_reached_max_length = state.cur_len == max_length
+             all_sequence_finished = jnp.all(state.is_sent_finished)
+             finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
+             return ~finish_generation
+
+         def sample_search_body_fn(state):
+             """state update fn."""
+             prng_key, prng_key_next = jax.random.split(state.prng_key)
+             model_outputs = model(state.running_token, params=params, **state.model_kwargs)
+
+             logits = model_outputs.logits[:, -1]
+
+             # perform super conditioning
+             # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
+             if condition_scale != 1.0:
+                 model_outputs_uncond = model(state.running_token, params=params, **state.model_kwargs_uncond)
+                 logits_uncond = model_outputs_uncond.logits[:, -1]
+                 logits = logits_uncond + condition_scale * (logits - logits_uncond)
+             else:
+                 model_outputs_uncond = None
+
+             # apply min_length, ...
+             logits = logits_processor(state.sequences, logits, state.cur_len)
+             # apply top_k, top_p, temperature
+             logits = logits_warper(logits, logits, state.cur_len)
+
+             next_token = jax.random.categorical(prng_key, logits, axis=-1)
+
+             next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
+             next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
+             next_token = next_token[:, None]
+
+             next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
+             next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
+             next_model_kwargs_uncond = (
+                 self.update_inputs_for_generation(model_outputs_uncond, state.model_kwargs_uncond)
+                 if condition_scale != 1.0
+                 else None
+             )
+
+             return SampleState(
+                 cur_len=state.cur_len + 1,
+                 sequences=next_sequences,
+                 running_token=next_token,
+                 is_sent_finished=next_is_sent_finished,
+                 model_kwargs=next_model_kwargs,
+                 model_kwargs_uncond=next_model_kwargs_uncond,
+                 prng_key=prng_key_next,
+             )
+
+         # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
+         if input_ids.shape[1] > 1:
+             state = sample_search_body_fn(state)
+
+         if not trace:
+             state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
+         else:
+             state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
+
+         return FlaxSampleOutput(sequences=state.sequences)
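The super-conditioning blend in sample_search_body_fn is classifier-free guidance applied to logits: it amplifies the gap between the text-conditioned and unconditioned predictions. A toy numeric check with made-up values:

    import jax.numpy as jnp

    logits_cond = jnp.array([2.0, 0.0])
    logits_uncond = jnp.array([1.0, 0.5])
    condition_scale = 3.0
    print(logits_uncond + condition_scale * (logits_cond - logits_uncond))  # [ 4. -1.]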