crfm-helm 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crfm-helm might be problematic. Click here for more details.

Files changed (333)
  1. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/METADATA +7 -77
  2. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/RECORD +315 -282
  3. helm/benchmark/adaptation/adapter_spec.py +10 -0
  4. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +11 -3
  5. helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +11 -8
  6. helm/benchmark/annotation/aci_bench_annotator.py +11 -22
  7. helm/benchmark/annotation/alrage_annotator.py +90 -0
  8. helm/benchmark/annotation/chw_care_plan_annotator.py +10 -21
  9. helm/benchmark/annotation/dischargeme_annotator.py +11 -22
  10. helm/benchmark/annotation/med_dialog_annotator.py +11 -22
  11. helm/benchmark/annotation/medalign_annotator.py +11 -22
  12. helm/benchmark/annotation/medi_qa_annotator.py +11 -22
  13. helm/benchmark/annotation/medication_qa_annotator.py +11 -22
  14. helm/benchmark/annotation/mental_health_annotator.py +11 -22
  15. helm/benchmark/annotation/mimic_bhc_annotator.py +11 -22
  16. helm/benchmark/annotation/mimic_rrs_annotator.py +11 -22
  17. helm/benchmark/annotation/model_as_judge.py +23 -18
  18. helm/benchmark/annotation/mtsamples_procedures_annotator.py +11 -22
  19. helm/benchmark/annotation/mtsamples_replicate_annotator.py +11 -22
  20. helm/benchmark/annotation/starr_patient_instructions_annotator.py +11 -22
  21. helm/benchmark/metrics/air_bench_metrics.py +3157 -1
  22. helm/benchmark/metrics/alrage_metric.py +35 -0
  23. helm/benchmark/metrics/basic_metrics.py +267 -2
  24. helm/benchmark/metrics/bbq_metrics.py +12 -0
  25. helm/benchmark/metrics/classification_metrics.py +19 -1
  26. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +12 -1
  27. helm/benchmark/metrics/dry_run_metrics.py +30 -1
  28. helm/benchmark/metrics/efficiency_metrics.py +74 -0
  29. helm/benchmark/metrics/ehr_sql_metrics.py +57 -1
  30. helm/benchmark/metrics/evaluate_reference_metrics.py +311 -0
  31. helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +13 -1
  32. helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +13 -1
  33. helm/benchmark/metrics/ifeval_metrics.py +13 -1
  34. helm/benchmark/metrics/instruction_following_critique_metrics.py +41 -1
  35. helm/benchmark/metrics/kpi_edgar_metrics.py +21 -0
  36. helm/benchmark/metrics/language_modeling_metrics.py +13 -1
  37. helm/benchmark/metrics/live_qa_metrics.py +13 -1
  38. helm/benchmark/metrics/llm_jury_metrics.py +13 -1
  39. helm/benchmark/metrics/medcalc_bench_metrics.py +14 -1
  40. helm/benchmark/metrics/medec_metrics.py +25 -2
  41. helm/benchmark/metrics/metric.py +25 -0
  42. helm/benchmark/metrics/mimiciv_billing_code_metrics.py +32 -1
  43. helm/benchmark/metrics/omni_math_metrics.py +13 -1
  44. helm/benchmark/metrics/safety_metrics.py +13 -1
  45. helm/benchmark/metrics/seahelm_metrics.py +14 -1
  46. helm/benchmark/metrics/summac/model_summac.py +2 -2
  47. helm/benchmark/metrics/summarization_metrics.py +129 -1
  48. helm/benchmark/metrics/toxicity_metrics.py +31 -1
  49. helm/benchmark/metrics/ultra_suite_asr_classification_metrics.py +52 -0
  50. helm/benchmark/metrics/wildbench_metrics.py +21 -1
  51. helm/benchmark/presentation/run_display.py +13 -3
  52. helm/benchmark/presentation/run_entry.py +2 -2
  53. helm/benchmark/presentation/schema.py +5 -22
  54. helm/benchmark/presentation/summarize.py +180 -11
  55. helm/benchmark/presentation/taxonomy_info.py +20 -0
  56. helm/benchmark/run.py +1 -1
  57. helm/benchmark/run_expander.py +4 -0
  58. helm/benchmark/run_specs/arabic_run_specs.py +140 -16
  59. helm/benchmark/run_specs/bluex_run_specs.py +1 -1
  60. helm/benchmark/run_specs/classic_run_specs.py +2 -2
  61. helm/benchmark/run_specs/long_context_run_specs.py +2 -2
  62. helm/benchmark/run_specs/medhelm/__init__.py +0 -0
  63. helm/benchmark/run_specs/medhelm/benchmark_config.py +219 -0
  64. helm/benchmark/run_specs/medhelm_run_specs.py +362 -52
  65. helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +6 -2
  66. helm/benchmark/scenarios/aci_bench_scenario.py +23 -0
  67. helm/benchmark/scenarios/air_bench_scenario.py +21 -0
  68. helm/benchmark/scenarios/alrage_scenario.py +54 -0
  69. helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +23 -1
  70. helm/benchmark/scenarios/anthropic_red_team_scenario.py +12 -1
  71. helm/benchmark/scenarios/arabic_exams_scenario.py +114 -0
  72. helm/benchmark/scenarios/arabic_mmlu_scenario.py +8 -4
  73. helm/benchmark/scenarios/aratrust_scenario.py +19 -0
  74. helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +24 -54
  75. helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +19 -48
  76. helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +22 -61
  77. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +21 -29
  78. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +21 -60
  79. helm/benchmark/scenarios/babi_qa_scenario.py +15 -0
  80. helm/benchmark/scenarios/banking77_scenario.py +21 -0
  81. helm/benchmark/scenarios/bbq_scenario.py +15 -0
  82. helm/benchmark/scenarios/best_chatgpt_prompts.yaml +473 -0
  83. helm/benchmark/scenarios/bird_sql_scenario.py +18 -0
  84. helm/benchmark/scenarios/bluex_scenario.py +6 -2
  85. helm/benchmark/scenarios/bold_scenario.py +15 -0
  86. helm/benchmark/scenarios/boolq_scenario.py +20 -0
  87. helm/benchmark/scenarios/chw_care_plan_scenario.py +23 -0
  88. helm/benchmark/scenarios/civil_comments_scenario.py +13 -0
  89. helm/benchmark/scenarios/clear_scenario.py +23 -0
  90. helm/benchmark/scenarios/cleva_scenario.py +479 -0
  91. helm/benchmark/scenarios/code_scenario.py +28 -0
  92. helm/benchmark/scenarios/commonsense_scenario.py +32 -0
  93. helm/benchmark/scenarios/compositional_instructions.yaml +70 -0
  94. helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +21 -0
  95. helm/benchmark/scenarios/copyright_scenario.py +35 -1
  96. helm/benchmark/scenarios/cti_to_mitre_scenario.py +21 -0
  97. helm/benchmark/scenarios/czech_bank_qa_scenario.py +18 -0
  98. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +22 -1
  99. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +23 -1
  100. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +22 -1
  101. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +21 -1
  102. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +13 -0
  103. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +13 -1
  104. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +13 -1
  105. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +13 -1
  106. helm/benchmark/scenarios/dischargeme_scenario.py +24 -0
  107. helm/benchmark/scenarios/disinformation_scenario.py +22 -0
  108. helm/benchmark/scenarios/dyck_language_scenario.py +15 -0
  109. helm/benchmark/scenarios/ehrshot_scenario.py +22 -0
  110. helm/benchmark/scenarios/enem_challenge_scenario.py +19 -0
  111. helm/benchmark/scenarios/entity_data_imputation_scenario.py +14 -0
  112. helm/benchmark/scenarios/entity_matching_scenario.py +14 -0
  113. helm/benchmark/scenarios/fin_qa_scenario.py +20 -0
  114. helm/benchmark/scenarios/financebench_scenario.py +21 -0
  115. helm/benchmark/scenarios/financial_phrasebank_scenario.py +21 -0
  116. helm/benchmark/scenarios/gold_commodity_news_scenario.py +21 -0
  117. helm/benchmark/scenarios/gpqa_scenario.py +18 -0
  118. helm/benchmark/scenarios/grammar_scenario.py +20 -1
  119. helm/benchmark/scenarios/gsm_scenario.py +21 -0
  120. helm/benchmark/scenarios/harm_bench_gcg_transfer_scenario.py +12 -1
  121. helm/benchmark/scenarios/harm_bench_scenario.py +12 -1
  122. helm/benchmark/scenarios/headqa_scenario.py +22 -0
  123. helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +13 -0
  124. helm/benchmark/scenarios/ice_scenario.py +21 -1
  125. helm/benchmark/scenarios/ifeval_scenario.py +18 -0
  126. helm/benchmark/scenarios/imdb_scenario.py +15 -0
  127. helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +21 -0
  128. helm/benchmark/scenarios/infinite_bench_en_sum_scenario.py +19 -0
  129. helm/benchmark/scenarios/koala_scenario.py +21 -1
  130. helm/benchmark/scenarios/kpi_edgar_scenario.py +21 -0
  131. helm/benchmark/scenarios/legal_contract_summarization_scenario.py +20 -0
  132. helm/benchmark/scenarios/legal_summarization_scenario.py +50 -0
  133. helm/benchmark/scenarios/legal_support_scenario.py +13 -0
  134. helm/benchmark/scenarios/legalbench_scenario.py +19 -0
  135. helm/benchmark/scenarios/lex_glue_scenario.py +11 -0
  136. helm/benchmark/scenarios/lextreme_scenario.py +11 -0
  137. helm/benchmark/scenarios/lsat_qa_scenario.py +14 -0
  138. helm/benchmark/scenarios/madinah_qa_scenario.py +73 -0
  139. helm/benchmark/scenarios/math_scenario.py +33 -0
  140. helm/benchmark/scenarios/mbzuai_human_translated_arabic_mmlu.py +68 -0
  141. helm/benchmark/scenarios/med_dialog_scenario.py +32 -1
  142. helm/benchmark/scenarios/med_mcqa_scenario.py +14 -0
  143. helm/benchmark/scenarios/med_qa_scenario.py +20 -0
  144. helm/benchmark/scenarios/medalign_scenario.py +23 -0
  145. helm/benchmark/scenarios/medbullets_scenario.py +22 -0
  146. helm/benchmark/scenarios/medcalc_bench_scenario.py +22 -0
  147. helm/benchmark/scenarios/medec_scenario.py +23 -0
  148. helm/benchmark/scenarios/medhallu_scenario.py +23 -0
  149. helm/benchmark/scenarios/medhelm/__init__.py +0 -0
  150. helm/benchmark/scenarios/medhelm/judges.yaml +14 -0
  151. helm/benchmark/scenarios/medhelm_configurable_scenario.py +101 -0
  152. helm/benchmark/scenarios/medi_qa_scenario.py +24 -1
  153. helm/benchmark/scenarios/medication_qa_scenario.py +31 -1
  154. helm/benchmark/scenarios/mental_health_scenario.py +23 -0
  155. helm/benchmark/scenarios/mimic_bhc_scenario.py +24 -0
  156. helm/benchmark/scenarios/mimic_rrs_scenario.py +23 -0
  157. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +22 -0
  158. helm/benchmark/scenarios/mmlu_pro_scenario.py +18 -0
  159. helm/benchmark/scenarios/mmlu_scenario.py +21 -0
  160. helm/benchmark/scenarios/msmarco_scenario.py +30 -0
  161. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +22 -0
  162. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +22 -0
  163. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +20 -0
  164. helm/benchmark/scenarios/narrativeqa_scenario.py +19 -0
  165. helm/benchmark/scenarios/natural_qa_scenario.py +32 -0
  166. helm/benchmark/scenarios/omni_math_scenario.py +18 -0
  167. helm/benchmark/scenarios/open_assistant_scenario.py +22 -0
  168. helm/benchmark/scenarios/openai_mrcr_scenario.py +15 -0
  169. helm/benchmark/scenarios/pubmed_qa_scenario.py +22 -0
  170. helm/benchmark/scenarios/quac_scenario.py +14 -0
  171. helm/benchmark/scenarios/race_based_med_scenario.py +23 -0
  172. helm/benchmark/scenarios/raft_scenario.py +15 -0
  173. helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +14 -1
  174. helm/benchmark/scenarios/ruler_qa_scenarios.py +40 -0
  175. helm/benchmark/scenarios/scenario.py +31 -0
  176. helm/benchmark/scenarios/seahelm_scenario.py +348 -0
  177. helm/benchmark/scenarios/self_instruct_scenario.py +29 -1
  178. helm/benchmark/scenarios/shc_bmt_scenario.py +22 -0
  179. helm/benchmark/scenarios/shc_cdi_scenario.py +20 -0
  180. helm/benchmark/scenarios/shc_conf_scenario.py +23 -0
  181. helm/benchmark/scenarios/shc_ent_scenario.py +21 -0
  182. helm/benchmark/scenarios/shc_gip_scenario.py +20 -0
  183. helm/benchmark/scenarios/shc_privacy_scenario.py +22 -0
  184. helm/benchmark/scenarios/shc_proxy_scenario.py +22 -0
  185. helm/benchmark/scenarios/shc_ptbm_scenario.py +23 -0
  186. helm/benchmark/scenarios/shc_sequoia_scenario.py +21 -0
  187. helm/benchmark/scenarios/simple_safety_tests_scenario.py +12 -1
  188. helm/benchmark/scenarios/situation_prompts.yaml +49 -0
  189. helm/benchmark/scenarios/spider_scenario.py +18 -0
  190. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +22 -0
  191. helm/benchmark/scenarios/summarization_scenario.py +37 -0
  192. helm/benchmark/scenarios/synthetic_efficiency_scenario.py +22 -1
  193. helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +13 -0
  194. helm/benchmark/scenarios/test_alrage_scenario.py +23 -0
  195. helm/benchmark/scenarios/test_arabic_exams_scenario.py +21 -0
  196. helm/benchmark/scenarios/test_aratrust_scenario.py +1 -1
  197. helm/benchmark/scenarios/test_bluex_scenario.py +2 -2
  198. helm/benchmark/scenarios/thai_exam_scenario.py +95 -0
  199. helm/benchmark/scenarios/the_pile_scenario.py +13 -1
  200. helm/benchmark/scenarios/truthful_qa_scenario.py +14 -0
  201. helm/benchmark/scenarios/twitter_aae_scenario.py +20 -1
  202. helm/benchmark/scenarios/vicuna_scenario.py +21 -1
  203. helm/benchmark/scenarios/wikifact_scenario.py +20 -0
  204. helm/benchmark/scenarios/wildbench_scenario.py +18 -0
  205. helm/benchmark/scenarios/wmt_14_scenario.py +19 -0
  206. helm/benchmark/static/schema_arabic.yaml +55 -12
  207. helm/benchmark/static/schema_long_context.yaml +11 -30
  208. helm/benchmark/static/schema_medhelm.yaml +36 -0
  209. helm/benchmark/static/schema_slp.yaml +219 -0
  210. helm/benchmark/static_build/assets/audio-table-Dn5NMMeJ.png +0 -0
  211. helm/benchmark/static_build/assets/index-oIeiQW2g.css +1 -0
  212. helm/benchmark/static_build/assets/index-qOFpOyHb.js +10 -0
  213. helm/benchmark/static_build/assets/react-BteFIppM.js +85 -0
  214. helm/benchmark/static_build/assets/recharts-DxuQtTOs.js +97 -0
  215. helm/benchmark/static_build/assets/tremor-DR4fE7ko.js +10 -0
  216. helm/benchmark/static_build/index.html +5 -6
  217. helm/clients/ai21_client.py +2 -0
  218. helm/clients/aleph_alpha_client.py +2 -0
  219. helm/clients/anthropic_client.py +7 -1
  220. helm/clients/audio_language/diva_llama_client.py +2 -0
  221. helm/clients/audio_language/llama_omni/arguments.py +61 -0
  222. helm/clients/audio_language/llama_omni/constants.py +9 -0
  223. helm/clients/audio_language/llama_omni/conversation.py +213 -0
  224. helm/clients/audio_language/llama_omni/model/__init__.py +0 -0
  225. helm/clients/audio_language/llama_omni/model/builder.py +88 -0
  226. helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py +190 -0
  227. helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py +118 -0
  228. helm/clients/audio_language/llama_omni/model/omni_speech_arch.py +249 -0
  229. helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py +9 -0
  230. helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py +27 -0
  231. helm/clients/audio_language/llama_omni/model/speech_generator/builder.py +9 -0
  232. helm/clients/audio_language/llama_omni/model/speech_generator/generation.py +622 -0
  233. helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py +104 -0
  234. helm/clients/audio_language/llama_omni/model/speech_projector/builder.py +9 -0
  235. helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py +27 -0
  236. helm/clients/audio_language/llama_omni/preprocess.py +295 -0
  237. helm/clients/audio_language/llama_omni/utils.py +202 -0
  238. helm/clients/audio_language/llama_omni_client.py +2 -1
  239. helm/clients/audio_language/qwen2_5_omni_client.py +2 -1
  240. helm/clients/audio_language/qwen2_audiolm_client.py +2 -1
  241. helm/clients/audio_language/qwen_audiolm_client.py +2 -1
  242. helm/clients/audio_language/qwen_omni/configuration_qwen2_5_omni.py +519 -0
  243. helm/clients/audio_language/qwen_omni/modeling_qwen2_5_omni.py +4308 -0
  244. helm/clients/audio_language/qwen_omni/processing_qwen2_5_omni.py +270 -0
  245. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/__init__.py +0 -0
  246. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/__init__.py +8 -0
  247. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/audio_process.py +56 -0
  248. helm/clients/audio_language/qwen_omni/qwen2_5_omni_utils/v2_5/vision_process.py +380 -0
  249. helm/clients/bedrock_client.py +2 -0
  250. helm/clients/cohere_client.py +3 -0
  251. helm/clients/google_client.py +2 -0
  252. helm/clients/http_model_client.py +2 -0
  253. helm/clients/huggingface_client.py +2 -1
  254. helm/clients/ibm_client.py +3 -1
  255. helm/clients/image_generation/adobe_vision_client.py +2 -0
  256. helm/clients/image_generation/aleph_alpha_image_generation_client.py +2 -0
  257. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +1 -1
  258. helm/clients/image_generation/cogview2_client.py +2 -1
  259. helm/clients/image_generation/dalle2_client.py +2 -0
  260. helm/clients/image_generation/dalle_mini_client.py +2 -1
  261. helm/clients/image_generation/deep_floyd_client.py +2 -0
  262. helm/clients/image_generation/huggingface_diffusers_client.py +2 -1
  263. helm/clients/image_generation/lexica_client.py +2 -0
  264. helm/clients/image_generation/mindalle/models/stage1/layers.py +2 -2
  265. helm/clients/image_generation/mindalle_client.py +2 -1
  266. helm/clients/image_generation/together_image_generation_client.py +2 -0
  267. helm/clients/megatron_client.py +2 -0
  268. helm/clients/mistral_client.py +2 -0
  269. helm/clients/moderation_api_client.py +2 -0
  270. helm/clients/openai_client.py +36 -20
  271. helm/clients/openai_responses_client.py +27 -3
  272. helm/clients/openrouter_client.py +31 -0
  273. helm/clients/palmyra_client.py +2 -1
  274. helm/clients/reka_client.py +2 -1
  275. helm/clients/stanfordhealthcare_azure_openai_client.py +2 -2
  276. helm/clients/stanfordhealthcare_http_model_client.py +2 -0
  277. helm/clients/test_openrouter_client.py +69 -0
  278. helm/clients/together_client.py +52 -11
  279. helm/clients/vertexai_client.py +12 -2
  280. helm/clients/vision_language/huggingface_vision2seq_client.py +2 -1
  281. helm/clients/vision_language/huggingface_vlm_client.py +2 -0
  282. helm/clients/vision_language/idefics_client.py +2 -1
  283. helm/clients/vision_language/open_flamingo_client.py +2 -1
  284. helm/clients/vision_language/paligemma_client.py +2 -1
  285. helm/clients/vision_language/palmyra_vision_client.py +2 -0
  286. helm/clients/vision_language/qwen2_vlm_client.py +2 -1
  287. helm/clients/vision_language/qwen_vlm_client.py +2 -1
  288. helm/clients/writer_client.py +2 -0
  289. helm/common/hierarchical_logger.py +20 -0
  290. helm/common/optional_dependencies.py +1 -1
  291. helm/common/test_general.py +4 -0
  292. helm/config/model_deployments.yaml +300 -1
  293. helm/config/model_metadata.yaml +302 -9
  294. helm/config/tokenizer_configs.yaml +92 -4
  295. helm/proxy/example_queries.py +8 -8
  296. helm/proxy/server.py +2 -1
  297. helm/proxy/static/index.css +4 -0
  298. helm/proxy/static/index.js +7 -1
  299. helm/benchmark/metrics/aci_bench_metrics.py +0 -14
  300. helm/benchmark/metrics/chw_care_plan_metrics.py +0 -14
  301. helm/benchmark/metrics/dischargeme_metrics.py +0 -14
  302. helm/benchmark/metrics/med_dialog_metrics.py +0 -14
  303. helm/benchmark/metrics/medalign_metrics.py +0 -14
  304. helm/benchmark/metrics/medi_qa_metrics.py +0 -14
  305. helm/benchmark/metrics/medication_qa_metrics.py +0 -14
  306. helm/benchmark/metrics/mental_health_metrics.py +0 -14
  307. helm/benchmark/metrics/mimic_bhc_metrics.py +0 -14
  308. helm/benchmark/metrics/mimic_rrs_metrics.py +0 -14
  309. helm/benchmark/metrics/mtsamples_procedures_metrics.py +0 -14
  310. helm/benchmark/metrics/mtsamples_replicate_metrics.py +0 -14
  311. helm/benchmark/metrics/starr_patient_instructions_metrics.py +0 -14
  312. helm/benchmark/static_build/assets/index-b9779128.css +0 -1
  313. helm/benchmark/static_build/assets/index-e439d5e1.js +0 -10
  314. helm/benchmark/static_build/assets/react-f82877fd.js +0 -85
  315. helm/benchmark/static_build/assets/recharts-4037aff0.js +0 -97
  316. helm/benchmark/static_build/assets/tremor-38a10867.js +0 -10
  317. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/WHEEL +0 -0
  318. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/entry_points.txt +0 -0
  319. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/licenses/LICENSE +0 -0
  320. {crfm_helm-0.5.7.dist-info → crfm_helm-0.5.9.dist-info}/top_level.txt +0 -0
  321. /helm/benchmark/static_build/assets/{air-overview-d2e6c49f.png → air-overview-DpBbyagA.png} +0 -0
  322. /helm/benchmark/static_build/assets/{crfm-logo-74391ab8.png → crfm-logo-Du4T1uWZ.png} +0 -0
  323. /helm/benchmark/static_build/assets/{heim-logo-3e5e3aa4.png → heim-logo-BJtQlEbV.png} +0 -0
  324. /helm/benchmark/static_build/assets/{helm-logo-simple-2ed5400b.png → helm-logo-simple-DzOhNN41.png} +0 -0
  325. /helm/benchmark/static_build/assets/{helm-safety-2907a7b6.png → helm-safety-COfndXuS.png} +0 -0
  326. /helm/benchmark/static_build/assets/{helmhero-28e90f4d.png → helmhero-D9TvmJsp.png} +0 -0
  327. /helm/benchmark/static_build/assets/{medhelm-overview-eac29843.png → medhelm-overview-CND0EIsy.png} +0 -0
  328. /helm/benchmark/static_build/assets/{medhelm-v1-overview-3ddfcd65.png → medhelm-v1-overview-Cu2tphBB.png} +0 -0
  329. /helm/benchmark/static_build/assets/{overview-74aea3d8.png → overview-BwypNWnk.png} +0 -0
  330. /helm/benchmark/static_build/assets/{process-flow-bd2eba96.png → process-flow-DWDJC733.png} +0 -0
  331. /helm/benchmark/static_build/assets/{vhelm-aspects-1437d673.png → vhelm-aspects-NiDQofvP.png} +0 -0
  332. /helm/benchmark/static_build/assets/{vhelm-framework-a1ca3f3f.png → vhelm-framework-NxJE4fdA.png} +0 -0
  333. /helm/benchmark/static_build/assets/{vhelm-model-8afb7616.png → vhelm-model-ypCL5Yvq.png} +0 -0
@@ -0,0 +1,622 @@
1
+ import torch
2
+ import inspect
3
+ import warnings
4
+ import torch.nn as nn
5
+ from typing import Optional, Union, List, Callable
6
+ import torch.distributed as dist
7
+
8
+ from transformers import PreTrainedModel
9
+ from transformers.generation.streamers import BaseStreamer
10
+ from transformers.generation.utils import (
11
+ GenerationConfig,
12
+ GenerationMode,
13
+ LogitsProcessorList,
14
+ StoppingCriteriaList,
15
+ GenerationMixin,
16
+ GenerateEncoderDecoderOutput,
17
+ GenerateDecoderOnlyOutput,
18
+ GenerateNonBeamOutput,
19
+ is_deepspeed_zero3_enabled,
20
+ is_torchdynamo_compiling,
21
+ NEED_SETUP_CACHE_CLASSES_MAPPING,
22
+ QUANT_BACKEND_CLASSES_MAPPING,
23
+ is_hqq_available,
24
+ QuantizedCacheConfig,
25
+ is_quanto_available,
26
+ DynamicCache,
27
+ EncoderDecoderCache,
28
+ logging,
29
+ )
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ class GenerationWithCTC(GenerationMixin):
35
+
36
+ @torch.no_grad()
37
+ def generate(
38
+ self,
39
+ inputs: Optional[torch.Tensor] = None,
40
+ speech: Optional[torch.Tensor] = None,
41
+ speech_lengths: Optional[torch.Tensor] = None,
42
+ generation_config: Optional[GenerationConfig] = None,
43
+ logits_processor: Optional[LogitsProcessorList] = None,
44
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
45
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
46
+ synced_gpus: Optional[bool] = None,
47
+ assistant_model: Optional["PreTrainedModel"] = None,
48
+ streamer: Optional["BaseStreamer"] = None,
49
+ streamer_unit: Optional["BaseStreamer"] = None,
50
+ streaming_unit_gen=False,
51
+ negative_prompt_ids: Optional[torch.Tensor] = None,
52
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
53
+ **kwargs,
54
+ ):
55
+
56
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
57
+ self._validate_model_class()
58
+ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
59
+ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
60
+ self._validate_model_kwargs(model_kwargs.copy())
61
+ self._validate_assistant(assistant_model)
62
+
63
+ # 2. Set generation parameters if not already defined
64
+ if synced_gpus is None:
65
+ if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
66
+ synced_gpus = True
67
+ else:
68
+ synced_gpus = False
69
+
70
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
71
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
72
+
73
+ accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
74
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
75
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
76
+
77
+ # 3. Define model inputs
78
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
79
+ inputs, generation_config.bos_token_id, model_kwargs
80
+ )
81
+ batch_size = inputs_tensor.shape[0]
82
+
83
+ device = inputs_tensor.device
84
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
85
+
86
+ # decoder-only models must use left-padding for batched generation.
87
+ if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
88
+ # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
89
+ # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
90
+ if (
91
+ generation_config._pad_token_tensor is not None
92
+ and batch_size > 1
93
+ and len(inputs_tensor.shape) == 2
94
+ and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
95
+ ):
96
+ logger.warning(
97
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
98
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
99
+ )
100
+
101
+ # 4. Define other model kwargs
102
+ # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
103
+ # generating the first new token or not, and we only want to use the embeddings for the first new token)
104
+ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
105
+ model_kwargs["use_cache"] = True
106
+ else:
107
+ model_kwargs["use_cache"] = generation_config.use_cache
108
+
109
+ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
110
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
111
+ inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
112
+ )
113
+
114
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
115
+ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
116
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
117
+ inputs_tensor, model_kwargs, model_input_name, generation_config
118
+ )
119
+
120
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
121
+ if self.config.is_encoder_decoder:
122
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
123
+ batch_size=batch_size,
124
+ model_input_name=model_input_name,
125
+ model_kwargs=model_kwargs,
126
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
127
+ device=inputs_tensor.device,
128
+ )
129
+ else:
130
+ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
131
+
132
+ if generation_config.token_healing:
133
+ input_ids = self.heal_tokens(input_ids, tokenizer)
134
+
135
+ if streamer is not None:
136
+ streamer.put(input_ids.cpu())
137
+
138
+ # 6. Prepare `max_length` depending on other stopping criteria.
139
+ input_ids_length = input_ids.shape[-1]
140
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
141
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
142
+ generation_config = self._prepare_generated_length(
143
+ generation_config=generation_config,
144
+ has_default_max_length=has_default_max_length,
145
+ has_default_min_length=has_default_min_length,
146
+ model_input_name=model_input_name,
147
+ inputs_tensor=inputs_tensor,
148
+ input_ids_length=input_ids_length,
149
+ )
150
+
151
+ # use_dynamic_cache_by_default = False
152
+ if "mamba" in self.__class__.__name__.lower():
153
+ cache_name = "cache_params"
154
+ else:
155
+ cache_name = "past_key_values"
156
+ if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
157
+ raise ValueError(
158
+ f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
159
+ "Cache object) is unsupported. Please use only one of the two."
160
+ )
161
+ elif generation_config.cache_implementation is not None:
162
+ if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
163
+ if generation_config.cache_implementation == "static" and not self._supports_static_cache:
164
+ raise ValueError(
165
+ "This model does not support `cache_implementation='static'`. Please check the following "
166
+ "issue: https://github.com/huggingface/transformers/issues/28981"
167
+ )
168
+ model_kwargs[cache_name] = self._get_cache(
169
+ generation_config.cache_implementation,
170
+ getattr(generation_config, "num_beams", 1) * batch_size,
171
+ generation_config.max_length,
172
+ model_kwargs,
173
+ )
174
+ elif generation_config.cache_implementation == "quantized":
175
+ if not self._supports_quantized_cache:
176
+ raise ValueError(
177
+ "This model does not support the quantized cache. If you want your model to support quantized "
178
+ "cache, please open an issue."
179
+ )
180
+
181
+ cache_config = (
182
+ generation_config.cache_config
183
+ if generation_config.cache_config is not None
184
+ else QuantizedCacheConfig()
185
+ )
186
+ cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
187
+
188
+ if cache_config.backend == "quanto" and not is_quanto_available():
189
+ raise ImportError(
190
+ "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
191
+ "Please install it via with `pip install quanto`"
192
+ )
193
+ elif cache_config.backend == "HQQ" and not is_hqq_available():
194
+ raise ImportError(
195
+ "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
196
+ "Please install it via with `pip install hqq`"
197
+ )
198
+
199
+ model_kwargs[cache_name] = cache_class(cache_config)
200
+ # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
201
+ # keeps copying the cache thus using much more memory
202
+ elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
203
+ past = model_kwargs.get(cache_name, None)
204
+ requires_cross_attention_cache = (
205
+ self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
206
+ )
207
+ if past is None:
208
+ model_kwargs[cache_name] = (
209
+ DynamicCache()
210
+ if not requires_cross_attention_cache
211
+ else EncoderDecoderCache(DynamicCache(), DynamicCache())
212
+ )
213
+ # use_dynamic_cache_by_default = True
214
+ elif isinstance(past, tuple):
215
+ model_kwargs[cache_name] = (
216
+ DynamicCache.from_legacy_cache(past)
217
+ if not requires_cross_attention_cache
218
+ else EncoderDecoderCache.from_legacy_cache(past)
219
+ )
220
+ # use_dynamic_cache_by_default = True
221
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
222
+
223
+ # 7. determine generation mode
224
+ generation_mode = generation_config.get_generation_mode(assistant_model)
225
+
226
+ if (streamer is not None or streamer_unit is not None) and (generation_config.num_beams > 1):
227
+ raise ValueError(
228
+ "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
229
+ )
230
+
231
+ if self.device.type != input_ids.device.type:
232
+ warnings.warn(
233
+ "You are calling .generate() with the `input_ids` being on a device type different"
234
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
235
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
236
+ " Please make sure that you have put `input_ids` to the"
237
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
238
+ " running `.generate()`.",
239
+ UserWarning,
240
+ )
241
+
242
+ # 8. prepare distribution pre_processing samplers
243
+ prepared_logits_processor = self._get_logits_processor(
244
+ generation_config=generation_config,
245
+ input_ids_seq_length=input_ids_length,
246
+ encoder_input_ids=inputs_tensor,
247
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
248
+ logits_processor=logits_processor,
249
+ device=inputs_tensor.device,
250
+ model_kwargs=model_kwargs,
251
+ negative_prompt_ids=negative_prompt_ids,
252
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
253
+ )
254
+
255
+ # 9. prepare stopping criteria
256
+ prepared_stopping_criteria = self._get_stopping_criteria(
257
+ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
258
+ )
259
+ # 10. go into different generation modes
260
+
261
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
262
+ # 11. prepare logits warper
263
+ prepared_logits_warper = (
264
+ self._get_logits_warper(generation_config, device=input_ids.device)
265
+ if generation_config.do_sample
266
+ else None
267
+ )
268
+
269
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
270
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
271
+ input_ids=input_ids,
272
+ expand_size=generation_config.num_return_sequences,
273
+ is_encoder_decoder=self.config.is_encoder_decoder,
274
+ **model_kwargs,
275
+ )
276
+
277
+ # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
278
+ if streaming_unit_gen:
279
+ return self._sample_streaming_unit(
280
+ input_ids,
281
+ logits_processor=prepared_logits_processor,
282
+ logits_warper=prepared_logits_warper,
283
+ stopping_criteria=prepared_stopping_criteria,
284
+ generation_config=generation_config,
285
+ synced_gpus=synced_gpus,
286
+ streamer=streamer,
287
+ streamer_unit=streamer_unit,
288
+ **model_kwargs,
289
+ )
290
+ else:
291
+ return self._sample(
292
+ input_ids,
293
+ logits_processor=prepared_logits_processor,
294
+ logits_warper=prepared_logits_warper,
295
+ stopping_criteria=prepared_stopping_criteria,
296
+ generation_config=generation_config,
297
+ synced_gpus=synced_gpus,
298
+ streamer=streamer,
299
+ **model_kwargs,
300
+ )
301
+ else:
302
+ raise NotImplementedError
303
+
304
def _sample(
    self,
    input_ids: torch.Tensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    logits_warper: Optional[LogitsProcessorList],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    """Decode one token per step until every sequence is reported finished.

    Each step runs the model on the inputs prepared from ``input_ids`` and
    ``model_kwargs``, passes the last-position logits through
    ``logits_processor`` (and through ``logits_warper`` when sampling), and
    picks the next token with ``torch.multinomial`` when
    ``generation_config.do_sample`` is truthy or ``torch.argmax`` otherwise.

    Args:
        input_ids: Token ids to continue; the chosen token is appended on
            every step.
        logits_processor: Callable applied to the raw next-token logits.
        stopping_criteria: Called with ``(input_ids, scores)`` after each
            step; its result marks the sequences that are done.
        generation_config: Source of the output flags, ``do_sample`` and the
            padding token tensor.
        synced_gpus: Forwarded to ``_has_unfinished_sequences``; when set, a
            finished peer keeps running forward passes but skips the rest of
            the step.
        streamer: Optional sink that receives every new token (moved to CPU)
            and is ended once the loop exits.
        logits_warper: Must be a ``LogitsProcessorList`` when ``do_sample``
            is ``True``.
        **model_kwargs: Extra model inputs, updated after every step.

    Returns:
        ``input_ids`` when ``return_dict_in_generate`` is falsy; otherwise a
        ``GenerateEncoderDecoderOutput`` or ``GenerateDecoderOnlyOutput``.
        Fields that were not requested are returned as empty tuples.

    Raises:
        ValueError: If ``do_sample`` is ``True`` and ``logits_warper`` is not
            a ``LogitsProcessorList``.
    """
    # Snapshot the generation settings once.
    pad_token_id = generation_config._pad_token_tensor
    want_attentions = generation_config.output_attentions
    want_hidden_states = generation_config.output_hidden_states
    want_scores = generation_config.output_scores
    want_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criterion, "eos_token_id") for criterion in stopping_criteria)
    do_sample = generation_config.do_sample

    if do_sample is True:
        if not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )

    is_encoder_decoder = self.config.is_encoder_decoder

    # The per-step accumulators always start as empty tuples (never None),
    # whether or not the corresponding output was requested.
    scores = raw_logits = decoder_attentions = cross_attentions = decoder_hidden_states = ()

    # Encoder-side outputs are only looked up for encoder-decoder models that
    # return a structured result.
    if return_dict_in_generate and is_encoder_decoder:
        encoder_attentions = None
        encoder_hidden_states = None
        if want_attentions:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions")
        if want_hidden_states:
            encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states")

    # One flag per sequence: 1 while still generating, 0 once finished.
    this_peer_finished = False
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # Only pass the output controls that were requested, since some
        # models won't accept all of them.
        if want_attentions:
            model_inputs.update({"output_attentions": want_attentions})
        if want_hidden_states:
            model_inputs.update({"output_hidden_states": want_hidden_states})

        outputs = self(**model_inputs, return_dict=True)

        if synced_gpus and this_peer_finished:
            # This peer is done; skip the rest of the step.
            continue

        # Clone so the small last-position slice does not keep the whole
        # logits tensor alive (it may be very large on the first iteration).
        next_token_logits = outputs.logits[:, -1, :].clone()

        next_token_scores = logits_processor(input_ids, next_token_logits)
        if do_sample and logits_warper is not None:
            next_token_scores = logits_warper(input_ids, next_token_scores)

        if return_dict_in_generate:
            if want_scores:
                scores = (*scores, next_token_scores)
            if want_logits:
                raw_logits = (*raw_logits, next_token_logits)
            if want_attentions:
                step_attentions = outputs.decoder_attentions if is_encoder_decoder else outputs.attentions
                decoder_attentions = (*decoder_attentions, step_attentions)
                if is_encoder_decoder:
                    cross_attentions = (*cross_attentions, outputs.cross_attentions)
            if want_hidden_states:
                step_hidden_states = outputs.decoder_hidden_states if is_encoder_decoder else outputs.hidden_states
                decoder_hidden_states = (*decoder_hidden_states, step_hidden_states)

        # Choose the next token: greedy argmax unless sampling was requested.
        if not do_sample:
            next_tokens = torch.argmax(next_token_scores, dim=-1)
        else:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Sequences that already finished keep emitting the padding token.
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Append the new token and advance the model state.
        input_ids = torch.cat((input_ids, next_tokens.unsqueeze(1)), dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
        )

        stop_mask = stopping_criteria(input_ids, scores)
        unfinished_sequences = unfinished_sequences & ~stop_mask
        this_peer_finished = int(unfinished_sequences.max()) == 0

        # Drop the reference so the logits are not kept alive into the next
        # iteration.
        del outputs

    if streamer is not None:
        streamer.end()

    if not return_dict_in_generate:
        return input_ids

    past_key_values = model_kwargs.get("past_key_values")
    if is_encoder_decoder:
        return GenerateEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=past_key_values,
        )
    return GenerateDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        logits=raw_logits,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=past_key_values,
    )
452
+
453
def _sample_streaming_unit(
    self,
    input_ids: torch.Tensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    streamer_unit: Optional["BaseStreamer"],
    logits_warper: Optional[LogitsProcessorList],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    """Sample or greedy-decode text tokens while streaming speech units.

    Runs the same per-token loop as ``_sample``. In addition, after every
    step the last-layer decoder hidden states collected so far are
    concatenated and passed to ``self.speech_generator.predict``; the
    prediction is collapsed by ``ctc_postprocess`` and any units beyond those
    already emitted are pushed to ``streamer_unit``.

    Because the speech generator reads the accumulated decoder hidden states,
    this method only works when both
    ``generation_config.return_dict_in_generate`` and
    ``generation_config.output_hidden_states`` are truthy. That precondition
    is checked up front, before any forward pass.

    Args:
        input_ids: Token ids to continue; the chosen token is appended on
            every step.
        logits_processor: Callable applied to the raw next-token logits.
        stopping_criteria: Called with ``(input_ids, scores)`` after each
            step; its result marks the sequences that are done.
        generation_config: Source of the output flags, ``do_sample`` and the
            padding token tensor.
        synced_gpus: Forwarded to ``_has_unfinished_sequences``; when set, a
            finished peer keeps running forward passes but skips the rest of
            the step.
        streamer: Optional sink for text tokens; ended when the loop exits.
        streamer_unit: Optional sink for newly decoded speech units.
        logits_warper: Must be a ``LogitsProcessorList`` when ``do_sample``
            is ``True``.
        **model_kwargs: Extra model inputs, updated after every step.

    Returns:
        A ``GenerateEncoderDecoderOutput`` or ``GenerateDecoderOnlyOutput``.
        Fields that were not requested are returned as empty tuples.

    Raises:
        ValueError: If ``do_sample`` is ``True`` and ``logits_warper`` is not
            a ``LogitsProcessorList``, or if the generation config does not
            request ``return_dict_in_generate`` and ``output_hidden_states``.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample
    if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
        raise ValueError(
            "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
            f"{logits_warper})."
        )

    # The speech-unit step below indexes the accumulated decoder hidden
    # states, which are only collected when both flags are set. Fail early
    # with a clear message instead of an IndexError after the first forward
    # pass.
    if not (return_dict_in_generate and output_hidden_states):
        raise ValueError(
            "Streaming unit generation reads the decoder hidden states, so it requires "
            "`return_dict_in_generate=True` and `output_hidden_states=True` (got "
            f"return_dict_in_generate={return_dict_in_generate}, output_hidden_states={output_hidden_states})."
        )

    is_encoder_decoder = self.config.is_encoder_decoder

    # Per-step accumulators; they start as empty tuples (never None).
    scores: tuple = ()
    raw_logits: tuple = ()
    decoder_attentions: tuple = ()
    cross_attentions: tuple = ()
    decoder_hidden_states: tuple = ()

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states")

    # keep track of which sequences are already finished
    batch_size = input_ids.shape[0]
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    # Length of the unit sequence decoded on the previous step; only units
    # past this index are pushed to `streamer_unit`.
    num_generated_units = 0
    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states})

        # forward pass to get next token
        outputs = self(**model_inputs, return_dict=True)

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        # Clone is needed to avoid keeping a hanging ref to outputs.logits
        # which may be very large for first iteration (the clone itself is always small)
        next_token_logits = outputs.logits[:, -1, :].clone()

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        if do_sample and logits_warper is not None:
            next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if output_scores:
            scores += (next_token_scores,)
        if output_logits:
            raw_logits += (next_token_logits,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if is_encoder_decoder else (outputs.attentions,)
            )
            if is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)
        decoder_hidden_states += (
            (outputs.decoder_hidden_states,) if is_encoder_decoder else (outputs.hidden_states,)
        )

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # speechgen: take the last position of the first step's last layer,
        # followed by the full last-layer output of every later step, and
        # re-run the speech generator over the concatenation.
        first_step_state = decoder_hidden_states[0][-1][:, -1:, :]
        later_step_states = [step_states[-1] for step_states in decoder_hidden_states[1:]]
        hidden_states = torch.cat([first_step_state] + later_step_states, dim=1)
        # NOTE(review): `squeeze(0)` presumably assumes a batch size of 1 - confirm with callers.
        ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
        cur_units = ctc_postprocess(ctc_pred, blank=self.model.config.unit_vocab_size)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        if streamer_unit is not None:
            # Emit only the units that extend past what the previous step produced.
            for unit_index in range(num_generated_units, len(cur_units)):
                streamer_unit.put(cur_units[unit_index].unsqueeze(0))
        num_generated_units = len(cur_units)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
        )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = bool(int(unfinished_sequences.max()) == 0)

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()
    # NOTE(review): `streamer_unit.end()` is never called here - confirm that unit consumers do not
    # wait on an end-of-stream signal.

    if is_encoder_decoder:
        return GenerateEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    return GenerateDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        logits=raw_logits,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=model_kwargs.get("past_key_values"),
    )
616
+
617
+
618
def ctc_postprocess(tokens, blank):
    """Collapse a CTC prediction into its final token sequence.

    Consecutive repeats are merged into a single token first, then every
    occurrence of ``blank`` is dropped, so equal tokens separated by a blank
    are both kept.

    Args:
        tokens: Tensor of predicted token ids. A leading dimension of size 1
            is squeezed away before processing.
        blank: Id of the CTC blank symbol to remove.

    Returns:
        A new 1-D tensor built with ``torch.tensor`` from the surviving
        tokens. When nothing survives, an empty ``torch.long`` tensor is
        returned so the dtype matches the non-empty integer case.
    """
    sequence = tokens.squeeze(0)
    if sequence.dim() == 0:
        # A single-element 1-D input squeezes down to a scalar, whose
        # ``tolist()`` is a bare number rather than a list; restore the axis.
        sequence = sequence.unsqueeze(0)
    token_list = sequence.tolist()

    # Merge runs of identical neighbours, then strip the blank symbol.
    collapsed = [tok for idx, tok in enumerate(token_list) if idx == 0 or tok != token_list[idx - 1]]
    kept = [tok for tok in collapsed if tok != blank]

    if not kept:
        # ``torch.tensor([])`` would default to float32; keep ids integral.
        return torch.empty(0, dtype=torch.long)
    return torch.tensor(kept)